# lander/model.py
# (Text extracted from a repository web view: 31 lines, 725 B, Python;
# last listed change 2022-07-22 11:49:30 +02:00. Page navigation text removed.)

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
class ActorCritic(nn.Module):
    """Skeleton of an actor-critic model with per-episode rollout buffers.

    Only the RNG seeding and the three rollout buffers are implemented.
    The network definition, the forward pass and the loss computation are
    still placeholders (see the TODO notes in the respective methods).
    """

    def __init__(self, seed: int = 12345) -> None:
        """Seed torch's default RNG and create empty rollout buffers.

        Args:
            seed: Value handed to ``torch.random.manual_seed``. This seeds
                the process-wide default generator, so it also affects any
                other code drawing from it.
        """
        super().__init__()
        torch.random.manual_seed(seed)

        # TODO: here we need the AC Network

        # Per-step records; emptied in place by clear_memory().
        self.logprobs: list = []
        self.state_values: list = []
        self.rewards: list = []

    def forward(self, state):
        """Placeholder for evaluating the network on ``state``.

        Not implemented yet: currently does nothing and returns ``None``.
        """
        # TODO: Here we need to evaluate the AC Network
        return None

    def calculate_loss(self, gamma: float = 0.99):
        """Placeholder for the loss computation.

        Not implemented yet: currently does nothing and returns ``None``.

        Args:
            gamma: Presumably the reward discount factor (the original note
                mentions "calculating discounted rewards") -- confirm once
                the loss is implemented.
        """
        # TODO: calculating discounted rewards
        return None

    def clear_memory(self) -> None:
        """Empty all rollout buffers in place.

        The list objects themselves are kept, so external references to
        them observe the cleared state as well.
        """
        self.logprobs.clear()
        self.state_values.clear()
        self.rewards.clear()