import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        # CNN architecture here
        # 3@12x16 -> 4@12x16 -> 4@6x8 (kernel_size=1 keeps the spatial size)
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=4, stride=1, kernel_size=1)
        self.pool = nn.MaxPool2d(2, 2)
        # 4@6x8 -> 8@6x8 -> 8@3x4
        self.conv2 = nn.Conv2d(in_channels=4, out_channels=8, stride=1, kernel_size=1)
        self.fc3 = nn.Linear(8 * 4 * 3, 106)
        # fc4 expects 120 inputs: 106 from fc3 plus (presumably) 14 concatenated status features
        self.fc4 = nn.Linear(120, 64)
        self.fc5 = nn.Linear(64, 32)
    def forward(self, x, trafficLight_status, road_status):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = F.relu(self.fc3(x))
        x = torch.cat((x, trafficLight_status, road_status), dim=1)
        x = F.relu(self.fc4(x))
        x = self.fc5(x)
        return x
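As a sanity check, the layer sizes above imply a 12x16 input image (8 * 4 * 3 = 96 into fc3) and 14 combined status features (106 + 14 = 120 into fc4); a quick smoke test, where the input size and the 7/7 split between the two status vectors are inferred from those sizes, not stated anywhere in the original:

cnn = ConvNet()
frames = torch.randn(1, 3, 12, 16)   # assumed input size, inferred from fc3's 8 * 4 * 3
tl_status = torch.randn(1, 7)        # illustrative split: the two status vectors must
rd_status = torch.randn(1, 7)        # total 14 features for fc4's 120 inputs
out = cnn(frames, tl_status, rd_status)
print(out.shape)                     # torch.Size([1, 32])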
Define your DQN model
class QActor(nn.Module):
    def __init__(self, state_size, action_size, action_parameter_size, hidden_layers=(100,),
                 action_input_layer=0, output_layer_init_std=None, activation="relu", **kwargs):
        super(QActor, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.action_parameter_size = action_parameter_size
        self.activation = activation

        # create layers
        self.layers = nn.ModuleList()
        inputSize = self.state_size + self.action_parameter_size
        lastHiddenLayerSize = inputSize
        if hidden_layers is not None:
            nh = len(hidden_layers)
            self.layers.append(nn.Linear(inputSize, hidden_layers[0]))
            for i in range(1, nh):
                self.layers.append(nn.Linear(hidden_layers[i - 1], hidden_layers[i]))
            lastHiddenLayerSize = hidden_layers[nh - 1]
        self.layers.append(nn.Linear(lastHiddenLayerSize, self.action_size))

        # initialise layer weights
        for i in range(0, len(self.layers) - 1):
            nn.init.kaiming_normal_(self.layers[i].weight, nonlinearity=activation)
            nn.init.zeros_(self.layers[i].bias)
        if output_layer_init_std is not None:
            nn.init.normal_(self.layers[-1].weight, mean=0., std=output_layer_init_std)
        # else:
        #     nn.init.zeros_(self.layers[-1].weight)
        nn.init.zeros_(self.layers[-1].bias)
    def forward(self, state, trafficLight_status, road_status, action_parameters):
        # Device configuration
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # NOTE: this builds a fresh, randomly initialised ConvNet on every call
        model = ConvNet().to(device)
        state = model(state, trafficLight_status, road_status)

        # implement forward
        negative_slope = 0.01
        x = torch.cat((state, action_parameters), dim=1)
        num_layers = len(self.layers)
        for i in range(0, num_layers - 1):
            if self.activation == "relu":
                x = F.relu(self.layers[i](x))
            elif self.activation == "leaky_relu":
                x = F.leaky_relu(self.layers[i](x), negative_slope)
            else:
                raise ValueError("Unknown activation function " + str(self.activation))
        Q = self.layers[-1](x)
        return Q, state
Create instances of the CNN and DQN models
cnn = ConvNet()
dqn = QActor(state_size, action_size, action_parameter_size, hidden_layers=(100,),
             action_input_layer=0, output_layer_init_std=None, activation="relu")
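Since the CNN is supposed to be trained together with the Q-network, a single optimizer has to see both parameter sets; a minimal sketch (the optimizer choice and learning rate are illustrative):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cnn.to(device)
dqn.to(device)
# Without cnn.parameters() here, the CNN weights would never be updated
optimizer = torch.optim.Adam(list(cnn.parameters()) + list(dqn.parameters()), lr=1e-3)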
Pass the CNN output to the DQN
cnn_output = cnn(x, trafficLight_status, road_status)  # pass the input through the CNN instance
Detach the flattened output from the computation graph
flattened_output = cnn_output.detach().requires_grad_(True)
Pass the detached flattened output to the DQN and compute the Q-values
q_values, _ = dqn(flattened_output, trafficLight_status, road_status, action_parameters)
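Note that detach() is exactly what cuts the computation graph between the two networks; a tiny illustration of the effect:

a = torch.ones(2, requires_grad=True)
b = (a * 3).detach().requires_grad_(True)  # b becomes a new leaf tensor, unlinked from a
b.sum().backward()
print(a.grad)                              # None: the gradient stops at the detach point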
I tried to update the forward method of the QActor network using the piece of code below, but optimization and backpropagation through the ConvNet do not seem to work: its parameters are never updated.
def forward(self, state, trafficLight_status, road_status, action_parameters):
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ConvNet().to(device)
    state = model(state, trafficLight_status, road_status)
    model = ConvNet().to(device)
    # Enable gradient tracking for the CNN parameters
    for param in model.parameters():
        param.requires_grad = True
    state = model(state, trafficLight_status, road_status)
    # Enable gradient tracking for the flattened output
    state = state.detach().requires_grad_(True)
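For comparison, the wiring that should let gradients reach the CNN weights is to build the CNN once (here as a submodule, so its parameters are registered with the module) and to drop the detach(). A minimal sketch, reusing ConvNet from above; the class name QActorWithCNN and the single linear head are illustrative, not the original architecture:

class QActorWithCNN(nn.Module):
    def __init__(self, action_size, action_parameter_size):
        super(QActorWithCNN, self).__init__()
        self.cnn = ConvNet()  # built once in __init__, so the weights persist across calls
        self.head = nn.Linear(32 + action_parameter_size, action_size)  # 32 = ConvNet's fc5 output

    def forward(self, frames, trafficLight_status, road_status, action_parameters):
        state = self.cnn(frames, trafficLight_status, road_status)  # no detach(): graph stays intact
        x = torch.cat((state, action_parameters), dim=1)
        return self.head(x), state

qactor = QActorWithCNN(action_size, action_parameter_size)
optimizer = torch.optim.Adam(qactor.parameters(), lr=1e-3)  # now covers CNN and head together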