graph/Environment.py

127 lines
3.1 KiB
Python
Raw Normal View History

2022-07-22 13:44:29 +02:00
import typing
import networkx as nx
import matplotlib.pyplot as plt
def build_network(*, show: bool = True) -> "nx.Graph":
    """Build the fixed 11-node undirected graph the agent moves on.

    Nodes are the integers 0-10; the edges are the hard-coded list below.

    Args:
        show: if True (the default, matching the previous behaviour), draw
            the graph with node labels and call ``plt.show()`` before
            returning. Pass False to build the graph without plotting.

    Returns:
        The constructed ``networkx.Graph``.
    """
    # Start from an empty graph instead of building a path graph and then
    # removing its edges: removing edges while iterating the graph's own live
    # edge view mutates the adjacency dict mid-iteration.
    G = nx.Graph()
    # Add nodes explicitly so node order is 0..10; otherwise it would follow
    # first appearance in the edge list, reordering the drawing colour lists.
    G.add_nodes_from(range(11))
    G.add_edges_from([(0, 3), (0, 4),
                      (1, 2), (1, 4), (1, 5), (1, 8), (1, 9),
                      (2, 3), (2, 5), (2, 6),
                      (5, 6), (5, 7),
                      (7, 8),
                      (8, 9), (8, 10),
                      (9, 10)])
    if show:
        nx.draw(G, with_labels=True)
        plt.show()
    return G
class Env:
    """Graph environment in which an agent travels from ``start`` to ``goal``.

    Attributes:
        G: networkx graph produced by ``build_network()``.
        start: start node for the agent.
        goal: goal node for the agent.
        blocked: list of currently blocked nodes (a node blocked twice is
            listed twice).
        pos: node layout used by ``plot_graph``; computed once so repeated
            plots place the nodes identically.
    """

    def __init__(self, start: int, goal: int):
        self.G = build_network()
        self.start = start
        self.goal = goal
        self.blocked = []
        # plot_graph() draws with this layout, so it must exist before the
        # first plot (it was previously never assigned).
        self.pos = nx.spring_layout(self.G)

    def reset(self) -> None:
        """Reset the environment: unblock every node."""
        self.blocked = []

    def get_state_reward(self, state: int, action: int) -> typing.Tuple[bool, int, int]:
        """Return ``(done, reward, next_state)`` for a state-action pair.

        The action is the node the agent tries to move to. Checks run in
        this order:

        - ``action`` is the goal: ``(True, 30, action)``. The goal counts as
          reached even if it is blocked.
        - ``action`` is blocked: the move is refused and the agent stays
          put, ``(False, -5, state)``.
        - Otherwise the move succeeds at a cost of one step:
          ``(False, -1, action)``.

        Adjacency between ``state`` and ``action`` is not checked here.
        """
        if action == self.goal:
            return True, 30, action
        if self.is_node_blocked(action):
            return False, -5, state
        return False, -1, action

    def block_node(self, node: int) -> None:
        """Mark ``node`` as blocked (appends to ``self.blocked``)."""
        self.blocked.append(node)

    def release_node(self, node: int) -> None:
        """Remove one blocking entry for ``node``.

        Raises:
            ValueError: if ``node`` is not currently blocked.
        """
        self.blocked.remove(node)

    def is_node_blocked(self, node: int) -> bool:
        """Return True if ``node`` is currently blocked."""
        return node in self.blocked

    def print_shortest_path(self, Q) -> list:
        """Follow the greedy policy from ``Q`` from start to goal, plot and return it.

        At each state the candidate actions are the state's neighbours plus
        the state itself (staying put). The candidate with the highest
        ``Q[state][a]`` is taken; ties go to the last candidate.

        Args:
            Q: table indexed as ``Q[state][action]`` (e.g. nested dicts or a
                2-D array; any object supporting that indexing).

        Returns:
            The visited nodes, from ``self.start`` to ``self.goal``
            inclusive. The path is also drawn via ``plot_graph``.

        Raises:
            RuntimeError: if the greedy walk revisits a node before reaching
                the goal. Each step depends only on the current node, so a
                revisit means the walk would loop forever (e.g. when the best
                action is a blocked node or staying put).
        """
        path = [self.start]
        visited = {self.start}
        state = self.start
        while state != self.goal:
            candidates = list(self.G.neighbors(state)) + [state]
            best_q = float('-inf')
            action = None
            for a in candidates:
                # '>=' keeps the original tie-break: last maximal candidate.
                if Q[state][a] >= best_q:
                    best_q = Q[state][a]
                    action = a
            _, _, state = self.get_state_reward(state, action)
            if state in visited:
                raise RuntimeError(
                    f"greedy policy loops at node {state} and never reaches "
                    f"goal {self.goal}; path so far: {path}"
                )
            visited.add(state)
            path.append(state)
        self.plot_graph(path)
        return path

    def plot_graph(self, path):
        """Draw the graph with ``path`` highlighted and show the figure.

        Path nodes are light green, blocked nodes not on the path are red,
        and all other nodes are light blue. Only edges between consecutive
        path nodes are highlighted (light green, width 5); the rest are black,
        width 1.
        """
        node_color, edge_color, edge_width = self._path_styles(path)
        nx.draw(self.G, pos=self.pos, with_labels=True, node_color=node_color,
                edge_color=edge_color, width=edge_width)
        plt.show()

    def _path_styles(self, path):
        """Return ``(node_color, edge_color, edge_width)`` lists for plot_graph.

        The lists are ordered like ``self.G.nodes()`` and ``self.G.edges()``.
        """
        on_path = set(path)
        # Treat edges as undirected: each traversed step (u, v) is stored as
        # a frozenset so it matches the edge whichever way round it is listed.
        steps = {frozenset(step) for step in zip(path, path[1:])}

        node_color = []
        for node in self.G.nodes():
            if node in on_path:
                node_color.append('lightgreen')
            elif self.is_node_blocked(node):
                node_color.append('red')
            else:
                node_color.append('lightblue')

        edge_color = []
        edge_width = []
        for u, v in self.G.edges():
            if frozenset((u, v)) in steps:
                edge_color.append('lightgreen')
                edge_width.append(5)
            else:
                edge_color.append('black')
                edge_width.append(1)
        return node_color, edge_color, edge_width