82 lines
1.7 KiB
Python
82 lines
1.7 KiB
Python
|
import typing
|
||
|
|
||
|
import networkx as nx
|
||
|
import matplotlib.pyplot as plt
|
||
|
|
||
|
def build_network(*, draw: bool = True) -> nx.Graph:
    """Build the 11-node undirected graph with the networkx library.

    Nodes are the integers 0..10 and the edge set is fixed below.

    Args:
        draw: When True (the default), render the graph with node labels
            and call ``plt.show()``. This may block until the plot window is
            closed, depending on the matplotlib backend. Pass False to only
            build the graph, e.g. in tests or headless runs.

    Returns:
        The constructed ``networkx.Graph``.
    """
    G = nx.Graph()
    # Start from isolated nodes 0..10, in the same order nx.path_graph(11)
    # would create them. Stripping a path graph with
    # G.remove_edges_from(G.edges()) would delete edges while the live edge
    # view is still being iterated, which networkx reports as a RuntimeError.
    G.add_nodes_from(range(11))
    G.add_edges_from([(0, 3), (0, 4),
                      (1, 2), (1, 4), (1, 5), (1, 8), (1, 9),
                      (2, 3), (2, 5), (2, 6),
                      (5, 6), (5, 7),
                      (7, 8),
                      (8, 9), (8, 10),
                      (9, 10)])

    if draw:
        nx.draw(G, with_labels=True)
        plt.show()

    return G
|
||
|
|
||
|
|
||
|
class Env:
    """Environment in which an agent moves between nodes of a graph.

    Attributes:
        G: networkx graph returned by ``build_network()``.
        start: Start node for the agent.
        goal: Goal node for the agent.
        blocked: List of currently blocked nodes; ``block_node`` never
            adds the same node twice.
    """

    def __init__(self, start: int, goal: int):
        # NOTE: build_network() draws the graph and calls plt.show() by
        # default, so constructing an Env may open a plot window.
        self.G = build_network()
        self.start = start
        self.goal = goal
        self.blocked = []

    def reset(self) -> None:
        """Reset the environment so that no nodes are blocked."""
        self.blocked = []

    def get_state_reward(self, state: int, action: int) -> typing.Tuple[bool, int, int]:
        """Return the outcome of moving from ``state`` to node ``action``.

        Returns:
            A ``(done, reward, next_state)`` tuple:

            * ``action == self.goal``: ``(True, 30, action)``. This is checked
              first, so the goal is reached even if it is in ``blocked``.
            * ``action`` is blocked: ``(False, -5, state)``, and the agent
              stays where it is.
            * otherwise: ``(False, -1, action)``, a per-step cost.

        NOTE(review): whether ``action`` is adjacent to ``state`` in
        ``self.G`` is not checked here; presumably the caller only proposes
        neighbouring nodes. Confirm this.
        """
        if action == self.goal:
            return True, 30, action
        if self.is_node_blocked(action):
            return False, -5, state
        return False, -1, action

    def block_node(self, node: int) -> None:
        """Block ``node``. Blocking an already-blocked node is a no-op."""
        # Avoid duplicates: otherwise a single release_node() call would
        # leave a doubly-blocked node still blocked.
        if node not in self.blocked:
            self.blocked.append(node)

    def release_node(self, node: int) -> None:
        """Unblock ``node``.

        Raises:
            ValueError: if ``node`` is not currently blocked.
        """
        self.blocked.remove(node)

    def is_node_blocked(self, node: int) -> bool:
        """Return True if ``node`` is currently blocked."""
        return node in self.blocked
|