Not converging: Simple Actor-Critic for a Multi-Discrete Action Space

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np

# ------------------------------------ #
# Actor
# ------------------------------------ #

class PolicyNet(nn.Module):
    def __init__(self, n_states, n_hiddens, action_dims):
        super(PolicyNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens[0])
        self.fc2 = nn.Linear(n_hiddens[0], n_hiddens[1])
        self.fc3 = nn.Linear(n_hiddens[1], sum(action_dims))
        self.action_dims = action_dims

    def forward(self, x):
        x = self.fc1(x)  # [b, n_states] --> [b, n_hiddens[0]]
        x = F.relu(x)
        x = self.fc2(x)  # [b, n_hiddens[0]] --> [b, n_hiddens[1]]
        x = F.relu(x)
        x = self.fc3(x)  # [b, n_hiddens[1]] --> [b, sum(action_dims)]

        # split the logits into one slice per action dimension and apply a
        # separate softmax to each slice
        dim_starts = np.cumsum([0] + list(self.action_dims[:-1]))
        out = [F.softmax(x[:, start:start + dim], dim=-1)
               for start, dim in zip(dim_starts, self.action_dims)]
        return out


# ------------------------------------ #
# Critic
# ------------------------------------ #

class ValueNet(nn.Module):
    def __init__(self, n_states, n_hiddens):
        super(ValueNet, self).__init__()
        self.fc1 = nn.Linear(n_states, n_hiddens[0])
        self.fc2 = nn.Linear(n_hiddens[0], n_hiddens[1])
        self.fc3 = nn.Linear(n_hiddens[1], 1)

    def forward(self, x):
        x = self.fc1(x)  # [b, n_states] --> [b, n_hiddens[0]]
        x = F.relu(x)
        x = self.fc2(x)  # [b, n_hiddens[0]] --> [b, n_hiddens[1]]
        x = F.relu(x)
        x = self.fc3(x)  # [b, n_hiddens[1]] --> [b, 1]
        return x


# ------------------------------------ #
# Actor-Critic
# ------------------------------------ #

class ActorCritic:
    def __init__(self, n_states, n_hiddens, action_dims,
                 actor_lr, critic_lr, gamma):
        self.gamma = gamma
        self.actor = PolicyNet(n_states, n_hiddens, action_dims)
        self.critic = ValueNet(n_states, n_hiddens)
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=actor_lr)
        self.critic_optimizer = optim.Adam(self.critic.parameters(), lr=critic_lr)

    def take_action(self, state):
        state = torch.tensor(state[np.newaxis, :]).to(torch.float32)
        action_probs = self.actor(state)
        # sample one index per action dimension from its own categorical distribution
        actions = [torch.multinomial(prob, 1).item() for prob in action_probs]
        return actions

    # model update
    def update(self, transition_dict):
        # train set
        states = torch.tensor(transition_dict['states'], dtype=torch.float)
        actions = torch.tensor(transition_dict['actions'])  # [b, n_action_dims], one sampled index per dimension
        rewards = torch.tensor(transition_dict['rewards'], dtype=torch.float).view(-1, 1)  # [b, 1]
        next_states = torch.tensor(transition_dict['next_states'], dtype=torch.float)
        dones = torch.tensor(transition_dict['dones'], dtype=torch.float).view(-1, 1)

        # predicted state value
        td_value = self.critic(states)

        # target state_value
        td_target = rewards + self.gamma * self.critic(next_states) * (1 - dones)

        # td-error
        td_delta = td_target - td_value  
        
        # policy gradient loss (TODO: is this correct?)
        output_probs = self.actor(states)                       # one softmax per action dimension
        log_probs = [torch.log(prob) for prob in output_probs]
        selected_log_probs = []
        for i in range(len(output_probs)):
            # log-probability of the action actually taken in dimension i
            selected_log_probs.append(log_probs[i].gather(1, actions[:, i].unsqueeze(1)))
        # joint log-probability = sum of the per-dimension log-probabilities
        actor_loss = torch.mean(-sum(selected_log_probs) * td_delta.detach())

        # policy gradient loss for a single discrete action (kept for reference)
        # log_probs = torch.log(self.actor(states).gather(1, actions))
        # actor_loss = torch.mean(-log_probs * td_delta.detach())

        # value function loss between predict value and target value
        critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))

        self.actor_optimizer.zero_grad()
        self.critic_optimizer.zero_grad()

        actor_loss.backward()
        critic_loss.backward()

        self.actor_optimizer.step()
        self.critic_optimizer.step()

As written in the code above, I construct a simple actor-critic network and apply a separate softmax to slices of the actor's output so that it can produce a probability distribution for each action dimension. However, I don't know how to update the policy network (where the TODO is marked). I have come up with an approach (as the code shows), but I'm not sure whether it is correct. Can anybody help me check it?
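
For what it's worth, the sum-of-log-probs approach for a factorized (multi-discrete) policy can also be written with torch.distributions.Categorical, which avoids calling torch.log on softmax outputs directly and tends to be numerically safer. The sketch below is only an illustration, assuming probs_per_dim is the list the actor returns, actions is a [b, n_dims] tensor of sampled indices, and advantage is the TD error; the helper name multi_discrete_actor_loss is made up:

import torch
from torch.distributions import Categorical

def multi_discrete_actor_loss(probs_per_dim, actions, advantage):
    # probs_per_dim: list of [b, dim_i] probability tensors, one per action dimension
    # actions:       [b, n_dims] long tensor of sampled action indices
    # advantage:     [b, 1] tensor, e.g. the TD error (detached below)
    log_prob_sum = 0.0
    for i, probs in enumerate(probs_per_dim):
        dist = Categorical(probs=probs)
        # log-probability of the index chosen in dimension i, shape [b]
        log_prob_sum = log_prob_sum + dist.log_prob(actions[:, i])
    # the joint log-probability of independent action dimensions is the sum
    return -(log_prob_sum.unsqueeze(1) * advantage.detach()).mean()

This computes the same quantity as the selected_log_probs loop in the update method, so it can serve as a cross-check.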

To be clear, the action space dimension is something like [2, 2], which means there are 2 actions and each action has two options. The actor outputs 4 values: the first two are the probability distribution of the first action, and the last two are the probability distribution of the second action.
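
As a concrete illustration of that layout (the numbers are made up, action_dims = [2, 2]):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5, 0.5]])  # 4 raw actor outputs for one state
probs_a0 = F.softmax(logits[:, 0:2], dim=-1)   # distribution over the 2 options of the first action
probs_a1 = F.softmax(logits[:, 2:4], dim=-1)   # distribution over the 2 options of the second action
a0 = torch.multinomial(probs_a0, 1).item()     # sampled option for the first action
a1 = torch.multinomial(probs_a1, 1).item()     # sampled option for the second action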
