import typing

import networkx as nx
import matplotlib.pyplot as plt


'''
Build network with networkx library
'''

def build_network():
    # Create nodes 0..10, drop the path edges, then wire up the custom topology.
    G = nx.path_graph(11)
    G.remove_edges_from(list(G.edges()))  # list() avoids mutating the view while iterating
    G.add_edges_from([(0, 3), (0, 4),
                      (1, 2), (1, 4), (1, 5), (1, 8), (1, 9),
                      (2, 3), (2, 5), (2, 6),
                      (5, 6), (5, 7),
                      (7, 8),
                      (8, 9), (8, 10),
                      (9, 10)])
    nx.draw(G, with_labels=True)
    plt.show()

    return G

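
# Quick sanity check of the topology above (a hedged sketch, not part of the
# original file): build_network() returns an undirected graph on nodes 0..10
# with 16 edges, e.g.
#
#     >>> G = build_network()           # also opens a matplotlib window
#     >>> sorted(G.neighbors(1))
#     [2, 4, 5, 8, 9]
#     >>> G.number_of_edges()
#     16
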
'''
Define the Environment class
- G: networkx graph
- start: start node for the agent
- goal: goal node for the agent
- blocked: list of blocked nodes
'''


class Env:

    def __init__(self, start: int, goal: int):
        self.G = build_network()
        self.start = start
        self.goal = goal
        self.blocked = []
        # Fixed node layout for plotting; plot_graph() below relies on self.pos.
        self.pos = nx.spring_layout(self.G)

    '''
    Reset the environment --> no blocked nodes
    '''

    def reset(self) -> None:
        self.blocked = []

    '''
    Return (done, reward, next_state) for a state-action pair:
    reaching the goal ends the episode with reward 30, moving onto a
    blocked node keeps the agent in place with reward -5, and any
    other move costs -1.
    '''

    def get_state_reward(self, state: int, action: int) -> typing.Tuple[bool, int, int]:
        if action == self.goal:
            return True, 30, action
        elif self.is_node_blocked(action):
            return False, -5, state
        else:
            return False, -1, action

    '''
    Block a node so the agent cannot move onto it
    '''

    def block_node(self, node: int) -> None:
        self.blocked.append(node)

    '''
    Unblock a previously blocked node
    '''

    def release_node(self, node: int) -> None:
        self.blocked.remove(node)

    '''
    Return True if the node is blocked
    '''

    def is_node_blocked(self, node: int) -> bool:
        return node in self.blocked

    '''
    Follow the greedy policy in the learned Q-table from start to goal
    and plot the resulting shortest path
    '''

    def print_shortest_path(self, Q):
        path = [self.start]
        state = self.start
        while state != self.goal:
            # Candidate actions: move to a neighbour or stay in place.
            action_list = list(self.G.neighbors(state)) + [state]
            max_q = float('-inf')
            for a in action_list:
                if Q[state][a] >= max_q:
                    max_q = Q[state][a]
                    action = a
            (_, _, next_state) = self.get_state_reward(state, action)
            state = next_state
            path.append(state)
        self.plot_graph(path)

    '''
    Plot the graph: path nodes and edges in green, blocked nodes in red,
    all other nodes in blue
    '''

    def plot_graph(self, path):
        node_color = []
        edge_color = []
        edge_width = []
        for node in self.G.nodes():
            if node in path:
                node_color.append('lightgreen')
            elif self.is_node_blocked(node):
                node_color.append('red')
            else:
                node_color.append('lightblue')
        for e in self.G.edges():
            # Highlight edges whose endpoints both lie on the path.
            if e[0] in path and e[1] in path:
                edge_color.append('lightgreen')
                edge_width.append(5)
            else:
                edge_color.append('black')
                edge_width.append(1)

        nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color,
                edge_color=edge_color, width=edge_width)
        plt.show()
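

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal tabular Q-learning
# loop on top of Env, using the standard update
#     Q[s][a] += alpha * (r + gamma * max_a' Q[s'][a'] - Q[s][a]).
# The hyperparameters, the blocked node (4) and the start/goal choice (0 -> 10)
# are illustrative assumptions; only Env itself comes from the code above.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import random

    ALPHA = 0.1      # learning rate (assumed)
    GAMMA = 0.9      # discount factor (assumed)
    EPSILON = 0.2    # exploration probability (assumed)
    EPISODES = 500
    MAX_STEPS = 50

    env = Env(start=0, goal=10)
    env.block_node(4)  # example obstacle

    # One Q-value per (node, node) pair; invalid actions are simply never chosen.
    n = env.G.number_of_nodes()
    Q = [[0.0] * n for _ in range(n)]

    for _ in range(EPISODES):
        state = env.start
        for _ in range(MAX_STEPS):
            # Valid actions: move to a neighbour or stay in place.
            actions = list(env.G.neighbors(state)) + [state]
            if random.random() < EPSILON:
                action = random.choice(actions)                    # explore
            else:
                action = max(actions, key=lambda a: Q[state][a])   # exploit
            done, reward, next_state = env.get_state_reward(state, action)
            next_actions = list(env.G.neighbors(next_state)) + [next_state]
            best_next = max(Q[next_state][a] for a in next_actions)
            Q[state][action] += ALPHA * (reward + GAMMA * best_next - Q[state][action])
            state = next_state
            if done:
                break

    # Follow the greedy policy and plot the learned route.
    env.print_shortest_path(Q)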