# File-viewer metadata: Python source, 31 lines, 725 B.
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch.distributions import Categorical
|
|
|
|
|
|
class ActorCritic(nn.Module):
    """Skeleton of an actor-critic module.

    The network layers, the forward pass and the loss computation are still
    placeholders. What exists today is the RNG seeding plus three list
    buffers (``logprobs``, ``state_values``, ``rewards``) and a helper that
    empties them.
    """

    def __init__(self, seed: int = 12345):
        """Initialise the module, seed torch and create empty buffers.

        Args:
            seed: Value passed to ``torch.random.manual_seed``.
        """
        super().__init__()

        # Seeds torch's default generator; this affects the whole process,
        # not just this instance.
        torch.random.manual_seed(seed)

        # TODO: define the actor-critic network here.

        # Buffers emptied in place by clear_memory(); presumably filled by
        # the training loop during rollouts -- confirm against the caller.
        self.logprobs = []
        self.state_values = []
        self.rewards = []

    def forward(self, state):
        """Placeholder forward pass; currently ignores ``state``.

        Returns:
            ``None`` until the network evaluation is implemented.
        """
        # TODO: evaluate the actor-critic network on ``state``.
        return None

    def calculate_loss(self, gamma: float = 0.99):
        """Placeholder loss computation; currently ignores ``gamma``.

        Returns:
            ``None`` until the loss is implemented.
        """
        # TODO: compute the discounted rewards (and from them the loss).
        return None

    def clear_memory(self):
        """Empty all three buffers in place, keeping the same list objects."""
        for buffer in (self.logprobs, self.state_values, self.rewards):
            buffer.clear()
|