I'm trying to use A3TGCN2 for traffic prediction, but I've had a hard time understanding how to set it up. I thought that defining this module would follow the same logic as other modules from this library, such as GConvLSTM, but apparently the A3TGCN2 module doesn't accept batched edge_index tensors.
The documentation states that this implementation specifically allows for the use of batched tensors, yet when I try to run it with batched tensors, it simply breaks my code:
import torch
import torch.nn as nn
from torch_geometric_temporal.nn.recurrent import A3TGCN2


class MyModel(nn.Module):
    def __init__(self,
                 features: int,
                 out_dim: int,
                 batch_size: int,
                 periods: int,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 ):
        super().__init__()
        self.features = features
        self.out_dim = out_dim
        self.batch_size = batch_size
        self.device = device
        self.periods = periods
        self.tgnn = A3TGCN2(
            in_channels=self.features,
            out_channels=self.out_dim,
            periods=self.periods,
            batch_size=self.batch_size,
        )

    def forward(self, x, edge_index):
        '''
        Parameters
        ------------
        x: torch.Tensor
            node features, of shape (seq_len, num_nodes*batch_size, features)
        edge_index: torch.Tensor
            edge indices, of shape (2, batch_size*num_edges)
        '''
        seq_len, num_nodes, features = x.shape
        x = torch.movedim(x, 0, -1)
        x = x.reshape(self.batch_size, -1, features, seq_len)
        # now x is shaped (batch_size, num_nodes, features, seq_len)
        H = self.tgnn(X=x, edge_index=edge_index)
        return H
The documentation states that x must be of shape (batch_size, num_nodes, features, seq_len), but this doesn't seem to work properly, as I get the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[49], line 80
78 features = encoder(x, edge_index)
79 features = GRLayer(features)
---> 80 y_pred = pred(features, edge_index, batch)
83 loss = criterion(y_pred, y)
84 total_loss += loss
File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
Cell In[48], line 32, in Predictor.forward(self, x, edge_index, batch)
29 print(x.shape)
31 # edge_index = self.unbatch_edge_index(edge_index, batch)
---> 32 H = self.tgnn(X=x, edge_index=edge_index)
33 print(H.shape)
35 return H
File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\attentiontemporalgcn.py:155, in A3TGCN2.forward(self, X, edge_index, edge_weight, H)
152 probs = torch.nn.functional.softmax(self._attention, dim=0)
153 for period in range(self.periods):
--> 155 H_accum = H_accum + probs[period] * self._base_tgcn( X[:, :, :, period], edge_index, edge_weight, H) #([32, 207, 32]
157 return H_accum
File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\temporalgcn.py:229, in TGCN2.forward(self, X, edge_index, edge_weight, H)
214 """
215 Making a forward pass. If edge weights are not present the forward pass
216 defaults to an unweighted graph. If the hidden state matrix is not present
(...)
226 * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
227 """
228 H = self._set_hidden_state(X, H)
--> 229 Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
230 R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
231 H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)
File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\temporalgcn.py:188, in TGCN2._calculate_update_gate(self, X, edge_index, edge_weight, H)
187 def _calculate_update_gate(self, X, edge_index, edge_weight, H):
--> 188 Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=2) # (b, 207, 64)
189 Z = self.linear_z(Z) # (b, 207, 32)
190 Z = torch.sigmoid(Z)
File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\nn\conv\gcn_conv.py:210, in GCNConv.forward(self, x, edge_index, edge_weight)
208 cache = self._cached_edge_index
209 if cache is None:
--> 210 edge_index, edge_weight = gcn_norm( # yapf: disable
211 edge_index, edge_weight, x.size(self.node_dim),
212 self.improved, self.add_self_loops, self.flow, x.dtype)
213 if self.cached:
214 self._cached_edge_index = (edge_index, edge_weight)
File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\nn\conv\gcn_conv.py:100, in gcn_norm(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)
98 row, col = edge_index[0], edge_index[1]
99 idx = col if flow == 'source_to_target' else row
--> 100 deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
101 deg_inv_sqrt = deg.pow_(-0.5)
102 deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\utils\scatter.py:74, in scatter(src, index, dim, dim_size, reduce)
72 if reduce == 'sum' or reduce == 'add':
73 index = broadcast(index, src, dim)
---> 74 return src.new_zeros(size).scatter_add_(dim, index, src)
76 if reduce == 'mean':
77 count = src.new_zeros(dim_size)
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Additionally, using TORCH_USE_CUDA_DSA, the following error is also displayed on my console:
C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\ScatterGatherKernel.cu:145: block: [73,0,0], thread: [63,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.
What exactly is the problem?
Analyzing the error output, it seems that we are indexing out of a tensor's bounds inside the scatter call of gcn_norm: the scatter is told that there are only num_nodes nodes (the size of the node dimension of the reshaped x), while edge_index still contains node indices up to batch_size*num_nodes - 1. What I have realized is that, since we changed the way the x tensor is passed (separating batch_size and num_nodes into their own dimensions), we should also modify the edge_index tensor so that its indices make sense for the reshaped x. Assuming that edge_index is static (meaning the graph doesn't change for every element of the batch), we can unbatch it and keep a single copy of the graph with a small helper method.
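A minimal sketch of such a helper (it relies on the usual PyG batching convention, where the nodes of the first graph in the batch occupy the lowest indices; the method name mirrors the commented-out call visible in the traceback):

def unbatch_edge_index(self, edge_index, batch):
    # With PyG-style batching, the nodes of the first graph occupy the
    # indices [0, n0), so its edges are exactly those whose endpoints fall
    # below n0, and they already use the correct (unshifted) node indices.
    n0 = int((batch == 0).sum())
    mask = (edge_index[0] < n0) & (edge_index[1] < n0)
    return edge_index[:, mask]

(Recent versions of torch_geometric also ship torch_geometric.utils.unbatch_edge_index(edge_index, batch), which returns one edge_index per graph; keeping the first element of that list achieves the same thing.)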
In this case, we should then add batch as a parameter of the forward method of the model:
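A sketch of the updated forward, reusing the reshaping from the question and unbatching edge_index before handing it to A3TGCN2 (the shape comments follow the docstring above):

def forward(self, x, edge_index, batch):
    # x:          (seq_len, batch_size*num_nodes, features)
    # edge_index: (2, batch_size*num_edges)
    # batch:      (batch_size*num_nodes,) graph index of every node
    seq_len, _, features = x.shape
    x = torch.movedim(x, 0, -1)                             # (batch*nodes, features, seq_len)
    x = x.reshape(self.batch_size, -1, features, seq_len)   # (batch, nodes, features, seq_len)
    # A3TGCN2 applies one static graph to every sample in the batch,
    # so it must receive the edge_index of a single graph
    edge_index = self.unbatch_edge_index(edge_index, batch)
    H = self.tgnn(X=x, edge_index=edge_index)
    return H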
And also build the batch tensor (the graph index of every node) when calling the model:
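A sketch matching the training-loop lines from the traceback; batch_size and num_nodes are assumed to be known, and encoder, GRLayer and pred are the modules already used there. If the data comes from a PyG DataLoader, data.batch can be passed directly instead:

# graph index of every node: [0]*num_nodes + [1]*num_nodes + ...
batch = torch.arange(batch_size, device=x.device).repeat_interleave(num_nodes)

features = encoder(x, edge_index)
features = GRLayer(features)
y_pred = pred(features, edge_index, batch)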