import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
# ------------------------------------ #
# Actor
# ------------------------------------ #
class PolicyNet(nn.Module):
    def __init__(self, n_states, n_hiddens, action_dims):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens[0])
        self.fc2 = nn.Linear(n_hiddens[0], n_hiddens[1])
        self.fc3 = nn.Linear(n_hiddens[1], sum(action_dims))
        self.action_dims = action_dims
    def forward(self, x):
        x = self.fc1(x)   # [b, n_states] --> [b, n_hiddens[0]]
        x = F.relu(x)
        x = self.fc2(x)   # [b, n_hiddens[0]] --> [b, n_hiddens[1]]
        x = F.relu(x)
        x = self.fc3(x)   # [b, n_hiddens[1]] --> [b, sum(action_dims)]
        # Split the logits per action dimension and apply a softmax to each slice,
        # so the output is a list with one probability distribution per dimension.
        dim_starts = [0] + list(np.cumsum(self.action_dims[:-1]))
        out = [F.softmax(x[:, start:start + dim], dim=-1)
               for start, dim in zip(dim_starts, self.action_dims)]
        return out
# ------------------------------------ #
# Critic
# ------------------------------------ #
class ValueNet(nn.Module):
    def __init__(self, n_states, n_hiddens):
        super(ValueNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens[0])
        self.fc2 = nn.Linear(n_hiddens[0], n_hiddens[1])
        self.fc3 = nn.Linear(n_hiddens[1], 1)
    def forward(self, x):
        x = self.fc1(x)   # [b, n_states] --> [b, n_hiddens[0]]
        x = F.relu(x)
        x = self.fc2(x)   # [b, n_hiddens[0]] --> [b, n_hiddens[1]]
        x = F.relu(x)
        x = self.fc3(x)   # [b, n_hiddens[1]] --> [b, 1]
        return x
# ------------------------------------ #
# Actor-Critic
# ------------------------------------ #
class ActorCritic:
    def __init__(self, n_states, n_hiddens, n_actions,
                 actor_lr, critic_lr, gamma):
        self.gamma = gamma
        self.actor = PolicyNet(n_states, n_hiddens, n_actions)
        self.critic = ValueNet(n_states, n_hiddens)
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters(), lr=critic_lr)

    def take_action(self, state):
        state = torch.tensor(state[np.newaxis, :]).to(torch.float32)
        action_probs = self.actor(state)
        # sample one option per action dimension
        actions = [torch.multinomial(prob, 1, replacement=True).item() for prob in action_probs]
        return actions

    # model update
    def update(self, transition_dict):
        # training batch
        states = torch.tensor(transition_dict['states'], dtype=torch.float)
        actions = torch.tensor(transition_dict['actions'])                                   # [b, num_action_dims]
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1)    # [b, 1]
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1)
        # predicted state value
        td_value = self.critic(states)
        # target state value
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)
        # td-error
        td_delta = td_target - td_value
        # policy gradient loss (todo!!!!)
        output_probs = self.actor(states)
        log_probs = [torch.log(prob) for prob in output_probs]
        selected_log_probs = []
        for i in range(len(output_probs)):
            # log-probability of the option actually taken in action dimension i
            selected_log_probs.append(log_probs[i].gather(1, actions[:, i].unsqueeze(1)))
        actor_loss = torch.mean(-sum(selected_log_probs) * td_delta.detach())
        # policy gradient loss for the original single-action case
        # log_probs = torch.log(self.actor(states).gather(1, actions))
        # actor_loss = torch.mean(-log_probs * td_delta.detach())
        # value function loss between predicted value and target value
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))
        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()
        actor_loss.backward()
        critic_loss.backward()
        self.actor_optimizer.step()
        self.critic_optimizer.step()
As shown in the code above, I construct a simple actor-critic network and apply a separate softmax to each slice of the actor's output, so that it produces one probability distribution per action dimension. However, I am not sure how to update the policy network (where the todo is marked). I have come up with an approach (shown in the code), but I am not sure whether it is correct. Can anybody help me check it?
To be clear, the action space is something like [2, 2], meaning there are two action dimensions and each dimension has two options. The actor therefore outputs 4 values: the first two are the probability distribution of the first action dimension, and the last two are the probability distribution of the second.
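For illustration only, here is a minimal, self-contained sketch of what I mean by treating each action dimension as its own categorical distribution and summing the per-dimension log-probabilities, which is what the gather/sum in my update() tries to do. The names probs_per_dim and joint_log_prob are placeholders invented for this example, not part of my actual code:

import torch
from torch.distributions import Categorical

action_dims = [2, 2]   # two action dimensions, two options each
batch_size = 5

# stand-in for the actor output: one softmax distribution per dimension, each of shape [b, 2]
probs_per_dim = [torch.softmax(torch.randn(batch_size, d), dim=-1) for d in action_dims]

# sample one option per dimension --> shape [b, len(action_dims)]
actions = torch.stack([Categorical(p).sample() for p in probs_per_dim], dim=1)

# assuming the dimensions are independent, the joint log-probability of the
# sampled action vector is the sum of the per-dimension log-probabilities
joint_log_prob = sum(Categorical(p).log_prob(actions[:, i])
                     for i, p in enumerate(probs_per_dim))   # shape [b]
print(joint_log_prob.shape)   # torch.Size([5])

If this factorized view is right, then multiplying the summed log-probability by the detached TD error (as my actor_loss does) should be the multi-dimensional analogue of the commented-out single-action version, but I would appreciate confirmation.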