Plot graph

This commit is contained in:
drothschedl 2022-07-25 11:13:46 +02:00
parent 502c246f30
commit 86a73b8433

View File

@ -79,3 +79,48 @@ class Env:
def is_node_blocked(self, node: int) -> int:
return node in self.blocked
'''
Get shortest calcualted path and print it
'''
def print_shortest_path(self, Q):
path = [self.start]
state = self.start
while state!=self.goal:
action_list = list(self.G.neighbors(state))+[state]
max_q = float('-inf')
for a in action_list:
if Q[state][a] >= max_q:
max_q = Q[state][a]
action = a
(_, _, next_state) = self.get_state_reward(state, action)
state = next_state
path.append(state)
self.plot_graph(path)
'''
Plot graph
'''
def plot_graph(self, path):
node_color = []
edge_color = []
edge_width = []
for node in self.G.nodes():
if node in path:
node_color.append('lightgreen')
elif self.is_node_blocked(node):
node_color.append('red')
else:
node_color.append('lightblue')
for e in self.G.edges():
if e[0] in path and e[1] in path:
edge_color.append('lightgreen')
edge_width.append(5)
else:
edge_color.append('black')
edge_width.append(1)
nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color, edge_color=edge_color, width=edge_width)
plt.show()