2
0
Fork 0
graph/Environment.py

82 lines
1.7 KiB
Python

import typing
import networkx as nx
import matplotlib.pyplot as plt
def build_network(draw: bool = True) -> nx.Graph:
    """Build the fixed 11-node network with the networkx library.

    The graph is undirected, has the nodes 0..10 and the hard-coded edge
    list below.

    Args:
        draw: When True (the default, which matches the historical
            behaviour) the graph is rendered with ``nx.draw`` and displayed
            with ``plt.show()``. Pass False to build the graph without
            opening a figure, e.g. in headless runs or training loops.

    Returns:
        The constructed ``networkx.Graph``.
    """
    edges = [
        (0, 3), (0, 4),
        (1, 2), (1, 4), (1, 5), (1, 8), (1, 9),
        (2, 3), (2, 5), (2, 6),
        (5, 6), (5, 7),
        (7, 8),
        (8, 9), (8, 10),
        (9, 10),
    ]

    G = nx.Graph()
    # Register the nodes first so they exist in ascending order 0..10,
    # independent of the order in which they appear in the edge list.
    G.add_nodes_from(range(11))
    G.add_edges_from(edges)

    if draw:
        nx.draw(G, with_labels=True)
        # NOTE: depending on the matplotlib backend this call may block
        # until the figure window is closed.
        plt.show()
    return G
class Env:
    """Environment in which an agent moves over the network graph.

    Attributes:
        G: networkx graph returned by ``build_network()``.
        start: start node for the agent.
        goal: goal node for the agent.
        blocked: list of currently blocked nodes (no duplicates).
    """

    def __init__(self, start: int, goal: int):
        """Create the environment with a fresh graph and no blocked nodes."""
        self.G = build_network()
        self.start = start
        self.goal = goal
        self.blocked: typing.List[int] = []

    def reset(self) -> None:
        """Reset the environment so that no nodes are blocked."""
        self.blocked = []

    def get_state_reward(self, state: int, action: int) -> typing.Tuple[bool, int, int]:
        """Return the outcome of a state-action pair.

        ``action`` is the node the agent tries to move to. This method does
        not check that ``action`` is adjacent to ``state`` in the graph.

        Returns:
            A tuple ``(done, reward, next_state)``:

            * ``(True, 30, action)`` when ``action`` is the goal node,
            * ``(False, -5, state)`` when ``action`` is blocked (the agent
              stays where it is),
            * ``(False, -1, action)`` otherwise.
        """
        # The goal check comes first, so the goal counts as reached even
        # if it is currently in the blocked list.
        if action == self.goal:
            return True, 30, action
        if self.is_node_blocked(action):
            return False, -5, state
        return False, -1, action

    def block_node(self, node: int) -> None:
        """Block ``node``; blocking an already blocked node is a no-op."""
        # Avoid duplicates: otherwise a node blocked twice would still be
        # blocked after a single release_node() call.
        if node not in self.blocked:
            self.blocked.append(node)

    def release_node(self, node: int) -> None:
        """Unblock ``node``.

        Raises:
            ValueError: if ``node`` is not currently blocked.
        """
        self.blocked.remove(node)

    def is_node_blocked(self, node: int) -> bool:
        """Return True if ``node`` is blocked."""
        return node in self.blocked