import torch
from PIL import Image
from gym import Env

from model import ActorCritic


def run(env: Env, policy: ActorCritic = None, n_episodes=5,
        name='LunarLander.pth', *, render=True, save_gif=False,
        max_steps=10000):
    """Roll out a policy in ``env`` for several episodes and print each return.

    The environment is driven through the classic gym interface:
    ``env.reset()`` yields the initial observation and ``env.step(action)``
    yields a 4-tuple ``(state, reward, done, info)``.

    Args:
        env: Environment to interact with. It is closed before this function
            returns, including when the rollout is interrupted by an
            exception.
        policy: Callable mapping an observation to an action. When ``None``,
            a fresh ``ActorCritic`` is built and its weights are loaded from
            the checkpoint ``./<name>``.
        n_episodes: Number of episodes to run.
        name: Checkpoint file name, looked up in the current working
            directory. Only used when ``policy`` is ``None``.
        render: Call ``env.render()`` after every step.
        save_gif: Save every step's RGB frame to ``./<t>.jpg`` where ``t`` is
            the step index within the episode. Frames of later episodes
            overwrite those of earlier ones because the file name does not
            include the episode number.
        max_steps: Upper bound on the number of steps per episode.

    Returns:
        None. The per-episode cumulative reward is printed to stdout.
    """
    if policy is None:
        policy = ActorCritic()
        # map_location='cpu' lets a checkpoint that was saved on a GPU be
        # read on a CPU-only host; load_state_dict copies the values into
        # the policy's existing parameters wherever those live.
        checkpoint = torch.load('./{}'.format(name), map_location='cpu')
        policy.load_state_dict(checkpoint)

    try:
        for i_episode in range(1, n_episodes + 1):
            state = env.reset()
            running_reward = 0
            for t in range(max_steps):
                action = policy(state)
                state, reward, done, _ = env.step(action)
                running_reward += reward
                if render:
                    env.render()
                if save_gif:
                    frame = Image.fromarray(env.render(mode='rgb_array'))
                    frame.save('./{}.jpg'.format(t))
                if done:
                    break
            print('Episode {}\tReward: {}'.format(i_episode, running_reward))
    finally:
        # Release the environment (and any render window) even if the
        # rollout raised or was interrupted.
        env.close()