lander/train.py

43 lines
1.3 KiB
Python
Raw Normal View History

2022-07-22 11:49:30 +02:00
import torch
from gym import Env
from torch import optim
from model import ActorCritic
# Learning hyper-parameters read by train().
gamma: float = 0.99  # passed to policy.calculate_loss; presumably the discount factor
lr: float = 2e-2  # Adam learning rate
betas = (0.9, 0.999)  # Adam (beta1, beta2) coefficients
def train(
    environment: Env,
    policy: ActorCritic,
    *,
    max_episodes: int = 5000,
    max_steps: int = 10000,
    solved_reward: float = 4000,
    log_interval: int = 20,
    checkpoint_path: str = './LunarLander.pth',
) -> None:
    """Train ``policy`` on ``environment`` with Adam, one update per episode.

    Each episode is rolled out for at most ``max_steps`` steps. Every reward
    is appended to ``policy.rewards``. After the episode, a single optimiser
    step is taken on ``policy.calculate_loss(gamma)``, followed by
    ``policy.clear_memory()`` (presumably resetting the per-episode buffers).

    Rewards are also summed into a window that is reset after every progress
    line. As soon as that partial sum exceeds ``solved_reward``, the weights
    are saved to ``checkpoint_path`` and training stops. With the defaults,
    4000 over a 20-episode window means an average of roughly 200 per episode.
    If training never reaches that point, a checkpoint is still written after
    ``max_episodes`` episodes.

    Args:
        environment: the gym environment. It is assumed to use the pre-0.26 API:
            ``reset()`` returns the state and ``step()`` returns the 4-tuple
            ``(state, reward, done, info)``.
        policy: the actor-critic model; calling it on a state yields the
            action passed to ``environment.step``.
        max_episodes: maximum number of training episodes.
        max_steps: maximum number of environment steps per episode.
        solved_reward: threshold on the reward accumulated in the current
            logging window above which training stops early.
        log_interval: a progress line is printed every ``log_interval``
            episodes (starting at episode 0). Must be >= 1.
        checkpoint_path: file that ``policy.state_dict()`` is saved to.

    Raises:
        ValueError: if ``log_interval`` is smaller than 1.
    """
    if log_interval < 1:
        raise ValueError("log_interval must be >= 1, got {}".format(log_interval))

    optimizer = optim.Adam(policy.parameters(), lr=lr, betas=betas)

    running_reward = 0
    episodes_in_window = 0
    for i_episode in range(max_episodes):
        state = environment.reset()
        episode_length = 0
        for _ in range(max_steps):
            action = policy(state)
            state, reward, done, _info = environment.step(action)
            episode_length += 1
            policy.rewards.append(reward)
            running_reward += reward
            if done:
                break
        episodes_in_window += 1

        # Updating the policy: one gradient step on this episode's rollout.
        optimizer.zero_grad()
        loss = policy.calculate_loss(gamma)
        loss.backward()
        optimizer.step()
        policy.clear_memory()

        if running_reward > solved_reward:
            torch.save(policy.state_dict(), checkpoint_path)
            print("########## Solved! ##########")
            # Return rather than break: the weights were just saved, so the
            # fallback checkpoint below would only rewrite the same file.
            return

        if i_episode % log_interval == 0:
            # Divide by the episodes actually accumulated. The first window
            # (episode 0) holds one episode, not log_interval of them.
            average_reward = running_reward / episodes_in_window
            print('Episode {}\tlength: {}\treward: {}'.format(i_episode, episode_length, average_reward))
            running_reward = 0
            episodes_in_window = 0

    # Never solved: save a checkpoint anyway so there is always something to load.
    torch.save(policy.state_dict(), checkpoint_path)