import torch
from gym import Env
from torch import optim

from model import ActorCritic

# Parameters for learning
gamma = 0.99
lr = 0.02
betas = (0.9, 0.999)


def train(environment: Env, policy: ActorCritic):
    optimizer = optim.Adam(policy.parameters(), lr=lr, betas=betas)
    running_reward = 0
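
    # Main training loop: run up to 5000 episodes, accumulating rewards for
    # the solve check and the periodic logging below.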
    for i_episode in range(5000):
        state = environment.reset()
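
        # Roll out a single episode. The ActorCritic model is assumed to cache
        # whatever calculate_loss() needs (e.g. log-probabilities and value
        # estimates) when it is called; only the rewards are appended here.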
        for t in range(10000):
            action = policy(state)
            state, reward, done, _ = environment.step(action)
            policy.rewards.append(reward)
            running_reward += reward
            if done:
                break

        # Updating the policy at the end of each episode:
        optimizer.zero_grad()
        loss = policy.calculate_loss(gamma)
        loss.backward()
        optimizer.step()
        policy.clear_memory()
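
        # 4000 accumulated over the 20-episode logging window below amounts to
        # an average reward of 200 per episode, roughly LunarLander's "solved"
        # threshold.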
        if running_reward > 4000:
            torch.save(policy.state_dict(), './LunarLander.pth')
            print("########## Solved! ##########")
            break

        if i_episode % 20 == 0:
            running_reward = running_reward / 20
            print('Episode {}\tlength: {}\treward: {}'.format(i_episode, t, running_reward))
            running_reward = 0

    # Save a checkpoint anyway so a trained model is available even if the
    # reward threshold was never reached.
    torch.save(policy.state_dict(), './LunarLander.pth')
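

# Minimal usage sketch (not part of the original script): it assumes the
# LunarLander-v2 environment, which matches the checkpoint name above, and
# that ActorCritic can be constructed without arguments; adjust the call to
# match the actual signature in model.py.
if __name__ == '__main__':
    import gym

    env = gym.make('LunarLander-v2')
    agent = ActorCritic()  # assumption: constructor arguments depend on model.py
    train(env, agent)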