import typing

import networkx as nx
import matplotlib.pyplot as plt


def build_network(show: bool = True) -> nx.Graph:
    """Build the fixed 11-node example graph used by :class:`Env`.

    Nodes are the integers 0..10 and the edges are the hard-coded list below.

    Args:
        show: If True (the default, as before), draw the graph with node
            labels and call ``plt.show()``. Pass False to build it silently.

    Returns:
        The undirected ``networkx.Graph``.
    """
    # 11 isolated nodes 0..10 -- equivalent to the former path_graph(11)
    # followed by removing every edge.
    G = nx.empty_graph(11)
    G.add_edges_from([
        (0, 3), (0, 4), (1, 2), (1, 4), (1, 5), (1, 8), (1, 9), (2, 3),
        (2, 5), (2, 6), (5, 6), (5, 7), (7, 8), (8, 9), (8, 10), (9, 10),
    ])
    if show:
        nx.draw(G, with_labels=True)
        plt.show()
    return G


class Env:
    """Graph environment in which an agent moves from ``start`` to ``goal``.

    Attributes:
        G: networkx graph returned by :func:`build_network`.
        start: start node for the agent.
        goal: goal node for the agent.
        blocked: list of currently blocked nodes.
        pos: node positions used by :meth:`plot_graph`.
    """

    def __init__(self, start: int, goal: int, show: bool = True) -> None:
        """Create the environment; ``show`` is forwarded to build_network."""
        self.G = build_network(show=show)
        self.start = start
        self.goal = goal
        self.blocked: typing.List[int] = []
        # plot_graph draws with pos=self.pos, which was previously never set
        # (AttributeError). Compute one layout so every plot of this
        # environment places the nodes identically.
        self.pos = nx.spring_layout(self.G)

    def reset(self) -> None:
        """Reset the environment so that no node is blocked."""
        self.blocked = []

    def get_state_reward(self, state: int, action: int) -> typing.Tuple[bool, int, int]:
        """Return ``(done, reward, next_state)`` for a state-action pair.

        The action is the node the agent tries to move to:

        - moving onto the goal ends the episode with reward 30 (checked
          before the blocked test, so it applies even if the goal is blocked);
        - moving onto a blocked node gives -5 and the agent stays in ``state``;
        - any other move gives -1 and the agent moves to ``action``.

        Adjacency between ``state`` and ``action`` is not checked here.
        """
        if action == self.goal:
            return True, 30, action
        elif self.is_node_blocked(action):
            return False, -5, state
        else:
            return False, -1, action

    def block_node(self, node: int) -> None:
        """Block ``node``; blocking an already-blocked node is a no-op."""
        # Avoid duplicates so a single release_node() fully unblocks the node.
        if node not in self.blocked:
            self.blocked.append(node)

    def release_node(self, node: int) -> None:
        """Unblock ``node``.

        Raises:
            ValueError: if ``node`` is not currently blocked.
        """
        self.blocked.remove(node)

    def is_node_blocked(self, node: int) -> bool:
        """Return True if ``node`` is blocked."""
        return node in self.blocked

    def print_shortest_path(self, Q) -> typing.List[int]:
        """Follow the greedy policy of ``Q`` from ``start`` to ``goal`` and plot it.

        At each step the candidate actions are the neighbours of the current
        node followed by the node itself. The candidate with the highest
        ``Q[state][a]`` is taken; ties go to the last candidate in that order.

        Args:
            Q: indexable as ``Q[state][action]`` -- presumably a 2-D array or
                nested mapping of learned Q-values (TODO confirm with caller).

        Returns:
            The visited nodes, from ``start`` to ``goal`` inclusive.

        Raises:
            ValueError: if every candidate Q-value in some state is NaN, so
                no action can be selected.
            RuntimeError: if the greedy walk revisits a node. Transitions are
                deterministic, so a revisit means the goal is never reached
                (this previously looped forever).
        """
        path = [self.start]
        visited = {self.start}
        state = self.start
        while state != self.goal:
            action_list = list(self.G.neighbors(state)) + [state]
            max_q = float('-inf')
            action = None
            for a in action_list:
                if Q[state][a] >= max_q:
                    max_q = Q[state][a]
                    action = a
            if action is None:
                raise ValueError(
                    f"No comparable Q-value for any action in state {state}")
            (_, _, next_state) = self.get_state_reward(state, action)
            if next_state in visited:
                raise RuntimeError(
                    f"Greedy policy revisits node {next_state}; "
                    f"goal {self.goal} is unreachable (path so far: {path})")
            visited.add(next_state)
            state = next_state
            path.append(state)
        self.plot_graph(path)
        return path

    def plot_graph(self, path: typing.Sequence[int]) -> None:
        """Draw the graph, highlighting ``path`` and blocked nodes.

        Nodes on ``path`` are light green, blocked nodes not on the path are
        red, and all others are light blue. An edge is highlighted (light
        green, width 5) only if it joins two consecutive nodes of ``path``;
        every other edge is black with width 1.
        """
        steps = list(path)
        on_path = set(steps)
        # Only edges actually traversed; previously any edge between two path
        # nodes was highlighted, e.g. 8-10 for the path 8 -> 9 -> 10.
        path_edges = {frozenset(pair) for pair in zip(steps, steps[1:])}

        node_color = []
        for node in self.G.nodes():
            if node in on_path:
                node_color.append('lightgreen')
            elif self.is_node_blocked(node):
                node_color.append('red')
            else:
                node_color.append('lightblue')

        edge_color = []
        edge_width = []
        for u, v in self.G.edges():
            if frozenset((u, v)) in path_edges:
                edge_color.append('lightgreen')
                edge_width.append(5)
            else:
                edge_color.append('black')
                edge_width.append(1)

        nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color,
                edge_color=edge_color, width=edge_width)
        plt.show()