Plot graph
This commit is contained in:
parent
352d981f98
commit
32209b3cd6
@ -79,3 +79,46 @@ class Env:
|
|||||||
|
|
||||||
def is_node_blocked(self, node: int) -> int:
|
def is_node_blocked(self, node: int) -> int:
|
||||||
return node in self.blocked
|
return node in self.blocked
|
||||||
|
|
||||||
|
'''
|
||||||
|
Get shortest calcualted path and print it
|
||||||
|
'''
|
||||||
|
def print_shortest_path(self, Q):
|
||||||
|
path = [self.start]
|
||||||
|
state = self.start
|
||||||
|
while state!=self.goal:
|
||||||
|
action_list = list(self.G.neighbors(state))+[state]
|
||||||
|
max_q = float('-inf')
|
||||||
|
for a in action_list:
|
||||||
|
if Q[state][a] >= max_q:
|
||||||
|
max_q = Q[state][a]
|
||||||
|
action = a
|
||||||
|
(_, _, next_state) = self.get_state_reward(state, action)
|
||||||
|
state = next_state
|
||||||
|
path.append(state)
|
||||||
|
self.plot_graph(path)
|
||||||
|
|
||||||
|
'''
|
||||||
|
Plot graph
|
||||||
|
'''
|
||||||
|
def plot_graph(self, path):
|
||||||
|
node_color = []
|
||||||
|
edge_color = []
|
||||||
|
edge_width = []
|
||||||
|
for node in self.G.nodes():
|
||||||
|
if node in path:
|
||||||
|
node_color.append('lightgreen')
|
||||||
|
elif self.is_node_blocked(node):
|
||||||
|
node_color.append('red')
|
||||||
|
else:
|
||||||
|
node_color.append('lightblue')
|
||||||
|
for e in self.G.edges():
|
||||||
|
if e[0] in path and e[1] in path:
|
||||||
|
edge_color.append('lightgreen')
|
||||||
|
edge_width.append(5)
|
||||||
|
else:
|
||||||
|
edge_color.append('black')
|
||||||
|
edge_width.append(1)
|
||||||
|
|
||||||
|
nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color, edge_color=edge_color, width=edge_width)
|
||||||
|
plt.show()
|
||||||
|
Loading…
Reference in New Issue
Block a user