import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical


class ActorCritic(nn.Module):
    """Actor-critic network for discrete actions, with per-episode memory.

    Intended use per episode:

    1. Call the module on each state to sample an action. This records the
       action's log-probability and the critic's value estimate.
    2. Append the environment's reward for that step to ``self.rewards``.
    3. Call ``calculate_loss`` and backpropagate the result.
    4. Call ``clear_memory`` before the next episode.
    """

    def __init__(
        self,
        seed: int = 12345,
        *,
        state_dim: int = 8,
        action_dim: int = 4,
        hidden_dim: int = 128,
    ):
        """Build the shared trunk and the actor and critic heads.

        Args:
            seed: Passed to ``torch.random.manual_seed``. This seeds torch's
                global RNG, so it affects the whole process. It is applied
                before the layers are created, so initial weights are
                reproducible.
            state_dim: Length of a flat observation vector. The default is an
                assumption (it fits an 8-dim observation); pass the real
                size for your environment.
            action_dim: Number of discrete actions. The default is likewise an
                assumption.
            hidden_dim: Width of the shared hidden layer.
        """
        super().__init__()
        torch.random.manual_seed(seed)

        # One shared hidden layer feeds two heads: action logits (actor) and
        # a scalar state-value estimate (critic).
        self.affine = nn.Linear(state_dim, hidden_dim)
        self.action_layer = nn.Linear(hidden_dim, action_dim)
        self.value_layer = nn.Linear(hidden_dim, 1)

        # Per-episode memory.
        # - logprobs and state_values are filled by forward().
        # - rewards is filled by the caller.
        # - clear_memory() empties all three in place.
        self.logprobs = []
        self.state_values = []
        self.rewards = []

    def forward(self, state):
        """Sample an action for one state and record what the loss needs.

        Side effect: appends one entry to ``self.logprobs`` and one to
        ``self.state_values``. Call with autograd enabled if the result will
        be trained on.

        Args:
            state: A single observation of length ``state_dim``. Anything
                ``torch.as_tensor`` accepts works (NumPy array, list, tensor).

        Returns:
            The sampled action index as a Python ``int``.
        """
        x = torch.as_tensor(
            state, dtype=torch.float32, device=self.affine.weight.device
        )
        x = F.relu(self.affine(x))
        state_value = self.value_layer(x)
        # Building the distribution from logits is numerically safer than
        # softmax followed by log.
        dist = Categorical(logits=self.action_layer(x))
        action = dist.sample()

        self.logprobs.append(dist.log_prob(action))
        self.state_values.append(state_value)
        return action.item()

    def calculate_loss(self, gamma: float = 0.99):
        """Return the actor-critic loss for the episode held in memory.

        Rewards are turned into discounted returns
        ``G_t = r_t + gamma * G_{t+1}``. The returns are standardized to zero
        mean and unit std when there are at least two steps. The loss summed
        over time steps is then:

        * actor term ``-log pi(a_t | s_t) * (G_t - V(s_t))``. The advantage
          is detached, so this term does not train the critic.
        * critic term ``smooth_l1_loss(V(s_t), G_t)``.

        Args:
            gamma: Discount factor applied per step.

        Returns:
            A scalar tensor. If memory is empty this is ``tensor(0.)``, which
            carries no gradient.

        Raises:
            ValueError: If the three memory lists have different lengths,
                e.g. a reward was not appended for every sampled action.
        """
        n_steps = len(self.rewards)
        if not (n_steps == len(self.logprobs) == len(self.state_values)):
            raise ValueError(
                "memory is out of sync: "
                f"{len(self.logprobs)} log-probs, "
                f"{len(self.state_values)} state values, "
                f"{n_steps} rewards"
            )
        if n_steps == 0:
            return torch.zeros(())

        # Discounted returns, accumulated from the last step backwards.
        returns = []
        running = 0.0
        for reward in reversed(self.rewards):
            running = reward + gamma * running
            returns.append(running)
        returns.reverse()

        logprobs = torch.stack(self.logprobs).reshape(-1)
        values = torch.stack(self.state_values).reshape(-1)
        returns = torch.tensor(returns, dtype=values.dtype, device=values.device)

        # Standardizing stabilizes training. std() of a single element is
        # NaN, so a one-step episode is left unscaled. eps guards against
        # division by zero when all returns are equal.
        if n_steps > 1:
            eps = torch.finfo(returns.dtype).eps
            returns = (returns - returns.mean()) / (returns.std() + eps)

        advantages = returns - values.detach()
        actor_loss = -(logprobs * advantages).sum()
        critic_loss = F.smooth_l1_loss(values, returns, reduction="sum")
        return actor_loss + critic_loss

    def clear_memory(self):
        """Empty the episode buffers in place; the list objects are kept."""
        for buffer in (self.logprobs, self.state_values, self.rewards):
            buffer.clear()