You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
34 lines
957 B
Python
34 lines
957 B
Python
import torch
|
|
from PIL import Image
|
|
from gym import Env
|
|
|
|
from model import ActorCritic
|
|
|
|
|
|
def run(env: Env, policy: ActorCritic = None, n_episodes=5, name='LunarLander.pth',
        *, render=True, save_gif=False, max_steps=10000):
    """Roll out ``policy`` in ``env`` and print the total reward of each episode.

    Args:
        env: Environment to drive. It is reset at the start of every episode,
            stepped with the policy's action, and closed before returning
            (also when an episode raises).
        policy: Callable mapping a state to an action. When ``None``, a fresh
            ``ActorCritic`` is created and its weights are loaded from
            ``./<name>``.
        n_episodes: Number of episodes to play.
        name: Checkpoint file name; only used when ``policy`` is ``None``.
        render: Call ``env.render()`` after every step.
        save_gif: Save each step's ``rgb_array`` frame as ``./<t>.jpg``, where
            ``t`` is the step index within the episode.
        max_steps: Upper bound on the number of steps per episode.

    Returns:
        None. One ``Episode <i>\\tReward: <total>`` line is printed per episode.
    """
    if policy is None:
        policy = ActorCritic()
        # map_location keeps a checkpoint that was saved on a GPU loadable on
        # a CPU-only host; load_state_dict then copies into the model's params.
        policy.load_state_dict(
            torch.load('./{}'.format(name), map_location='cpu'))

    try:
        for i_episode in range(1, n_episodes + 1):
            state = env.reset()
            running_reward = 0
            for t in range(max_steps):
                action = policy(state)
                state, reward, done, _ = env.step(action)
                running_reward += reward
                if render:
                    env.render()
                if save_gif:
                    # NOTE: the file name depends only on the step index, so a
                    # later episode overwrites the frames of an earlier one.
                    frame = env.render(mode='rgb_array')
                    Image.fromarray(frame).save('./{}.jpg'.format(t))
                if done:
                    break
            print('Episode {}\tReward: {}'.format(i_episode, running_reward))
    finally:
        # Release the environment (and any render window) even on failure.
        env.close()
|
|
|