From 86a73b8433c8c4c6ef8154892e02ce2f1d5652cd Mon Sep 17 00:00:00 2001 From: drothschedl Date: Mon, 25 Jul 2022 11:13:46 +0200 Subject: [PATCH] Plot graph --- Environment.py | 45 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/Environment.py b/Environment.py index 78f6d6d..8722e3f 100644 --- a/Environment.py +++ b/Environment.py @@ -79,3 +79,48 @@ class Env: def is_node_blocked(self, node: int) -> int: return node in self.blocked + + ''' + Get shortest calcualted path and print it + ''' + + def print_shortest_path(self, Q): + path = [self.start] + state = self.start + while state!=self.goal: + action_list = list(self.G.neighbors(state))+[state] + max_q = float('-inf') + for a in action_list: + if Q[state][a] >= max_q: + max_q = Q[state][a] + action = a + (_, _, next_state) = self.get_state_reward(state, action) + state = next_state + path.append(state) + self.plot_graph(path) + + ''' + Plot graph + ''' + + def plot_graph(self, path): + node_color = [] + edge_color = [] + edge_width = [] + for node in self.G.nodes(): + if node in path: + node_color.append('lightgreen') + elif self.is_node_blocked(node): + node_color.append('red') + else: + node_color.append('lightblue') + for e in self.G.edges(): + if e[0] in path and e[1] in path: + edge_color.append('lightgreen') + edge_width.append(5) + else: + edge_color.append('black') + edge_width.append(1) + + nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color, edge_color=edge_color, width=edge_width) + plt.show()