diff --git a/Environment.py b/Environment.py index 78f6d6d..c205ca6 100644 --- a/Environment.py +++ b/Environment.py @@ -79,3 +79,46 @@ class Env: def is_node_blocked(self, node: int) -> int: return node in self.blocked + + ''' + Get shortest calcualted path and print it + ''' + def print_shortest_path(self, Q): + path = [self.start] + state = self.start + while state!=self.goal: + action_list = list(self.G.neighbors(state))+[state] + max_q = float('-inf') + for a in action_list: + if Q[state][a] >= max_q: + max_q = Q[state][a] + action = a + (_, _, next_state) = self.get_state_reward(state, action) + state = next_state + path.append(state) + self.plot_graph(path) + + ''' + Plot graph + ''' + def plot_graph(self, path): + node_color = [] + edge_color = [] + edge_width = [] + for node in self.G.nodes(): + if node in path: + node_color.append('lightgreen') + elif self.is_node_blocked(node): + node_color.append('red') + else: + node_color.append('lightblue') + for e in self.G.edges(): + if e[0] in path and e[1] in path: + edge_color.append('lightgreen') + edge_width.append(5) + else: + edge_color.append('black') + edge_width.append(1) + + nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color, edge_color=edge_color, width=edge_width) + plt.show()