Plot graph
This commit is contained in:
parent
352d981f98
commit
32209b3cd6
@ -79,3 +79,46 @@ class Env:
|
||||
|
||||
def is_node_blocked(self, node: int) -> int:
|
||||
return node in self.blocked
|
||||
|
||||
'''
|
||||
Get shortest calcualted path and print it
|
||||
'''
|
||||
def print_shortest_path(self, Q):
|
||||
path = [self.start]
|
||||
state = self.start
|
||||
while state!=self.goal:
|
||||
action_list = list(self.G.neighbors(state))+[state]
|
||||
max_q = float('-inf')
|
||||
for a in action_list:
|
||||
if Q[state][a] >= max_q:
|
||||
max_q = Q[state][a]
|
||||
action = a
|
||||
(_, _, next_state) = self.get_state_reward(state, action)
|
||||
state = next_state
|
||||
path.append(state)
|
||||
self.plot_graph(path)
|
||||
|
||||
'''
|
||||
Plot graph
|
||||
'''
|
||||
def plot_graph(self, path):
|
||||
node_color = []
|
||||
edge_color = []
|
||||
edge_width = []
|
||||
for node in self.G.nodes():
|
||||
if node in path:
|
||||
node_color.append('lightgreen')
|
||||
elif self.is_node_blocked(node):
|
||||
node_color.append('red')
|
||||
else:
|
||||
node_color.append('lightblue')
|
||||
for e in self.G.edges():
|
||||
if e[0] in path and e[1] in path:
|
||||
edge_color.append('lightgreen')
|
||||
edge_width.append(5)
|
||||
else:
|
||||
edge_color.append('black')
|
||||
edge_width.append(1)
|
||||
|
||||
nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color, edge_color=edge_color, width=edge_width)
|
||||
plt.show()
|
||||
|
Loading…
Reference in New Issue
Block a user