How to set up A3TGCN2 module using batches?

59 views Asked by At

I'm trying to use A3TGCN2 for traffic prediction, but I've had a hard time understanding how to set it up. I thought that defining this module would follow the same logic as other modules from this library, such as GConvLSTM, but apparently A3TGCN2 doesn't accept batched edge_index tensors.

The documentation states that this implementation specifically allows for the use of batched tensors, but when I try to run it with batched tensors, it simply breaks my code:

class MyModel(nn.Module):
    """Wrap an ``A3TGCN2`` layer so it can consume PyG-style mini-batches.

    A PyG mini-batch stacks ``batch_size`` graphs into one disjoint graph:
    node features are concatenated along the node axis, and ``edge_index``
    holds node ids up to ``batch_size * num_nodes``. ``A3TGCN2`` instead takes
    ``X`` shaped ``(batch_size, num_nodes, features, periods)``. Judging from
    the ``index out of bounds`` assertion raised inside ``gcn_norm``, it also
    needs an ``edge_index`` whose ids are all ``< num_nodes``.

    ``forward`` therefore reshapes ``x`` and keeps only the first graph's
    edges. This assumes every graph in the batch shares the same topology
    and that each mini-batch holds exactly ``batch_size`` graphs.
    """

    def __init__(self,
                 features: int,
                 out_dim: int,
                 batch_size: int,
                 periods: int,
                 device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
                 ):
        super().__init__()
        self.features = features
        self.out_dim = out_dim
        self.batch_size = batch_size
        self.device = device  # kept as an attribute; not used inside this class
        self.periods = periods

        self.tgnn = A3TGCN2(
            in_channels=self.features,
            out_channels=self.out_dim,
            periods=self.periods,
            batch_size=self.batch_size,
        )

    @staticmethod
    def _first_graph_edge_index(edge_index, num_nodes):
        """Return only the edges of the first graph in a batched ``edge_index``.

        Keeps the edges whose two endpoints are both ``< num_nodes``. With
        PyG's batching (graph 0 owns node ids ``0..num_nodes-1``), these are
        exactly the edges of graph 0. If ``edge_index`` already describes a
        single graph, it is returned unchanged.

        Parameters
        ----------
        edge_index : torch.Tensor
            (2, E) edge indices, possibly of a whole mini-batch.
        num_nodes : int
            Number of nodes in one graph.

        Returns
        -------
        torch.Tensor
            (2, E0) edge indices of the first graph.
        """
        mask = (edge_index[0] < num_nodes) & (edge_index[1] < num_nodes)
        return edge_index[:, mask]

    def forward(self, x, edge_index):
        '''
        Run the wrapped A3TGCN2 layer on a PyG mini-batch.

        Parameters
        ------------
        x: torch.Tensor
            node features, of shape (seq_len, num_nodes*batch_size, features);
            seq_len presumably has to equal ``periods`` (A3TGCN2 indexes the
            last axis once per period) - TODO confirm
        edge_index: torch.Tensor
            edge indices, of shape (2, batch_size*num_edges); a single-graph
            edge_index of shape (2, num_edges) is also accepted

        Returns
        -------
        Whatever ``self.tgnn`` returns - presumably
        (batch_size, num_nodes, out_dim); TODO confirm
        '''
        seq_len, _, features = x.shape
        # (seq_len, N*B, F) -> (N*B, F, seq_len)
        x = torch.movedim(x, 0, -1)
        # Split the stacked node axis back into (batch_size, num_nodes).
        x = x.reshape(self.batch_size, -1, features, seq_len)
        num_nodes = x.shape[1]

        # A3TGCN2 shares one edge_index across the batch dimension, so its ids
        # must index into num_nodes, not into num_nodes*batch_size.
        edge_index = self._first_graph_edge_index(edge_index, num_nodes)

        H = self.tgnn(X=x, edge_index=edge_index)
        return H

The documentation states that x must be of shape (batch_size, num_nodes, features, seq_len), but this doesn't seem to work, as I'm getting the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[49], line 80
     78     features = encoder(x, edge_index)
     79     features = GRLayer(features)
---> 80     y_pred   = pred(features, edge_index, batch)
     83     loss = criterion(y_pred, y)
     84 total_loss += loss

File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

Cell In[48], line 32, in Predictor.forward(self, x, edge_index, batch)
     29 print(x.shape)
     31 # edge_index = self.unbatch_edge_index(edge_index, batch)
---> 32 H = self.tgnn(X=x, edge_index=edge_index)
     33 print(H.shape)
     35 return H

File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\attentiontemporalgcn.py:155, in A3TGCN2.forward(self, X, edge_index, edge_weight, H)
    152 probs = torch.nn.functional.softmax(self._attention, dim=0)
    153 for period in range(self.periods):
--> 155     H_accum = H_accum + probs[period] * self._base_tgcn( X[:, :, :, period], edge_index, edge_weight, H) #([32, 207, 32]
    157 return H_accum

File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\temporalgcn.py:229, in TGCN2.forward(self, X, edge_index, edge_weight, H)
    214 """
    215 Making a forward pass. If edge weights are not present the forward pass
    216 defaults to an unweighted graph. If the hidden state matrix is not present
   (...)
    226     * **H** *(PyTorch Float Tensor)* - Hidden state matrix for all nodes.
    227 """
    228 H = self._set_hidden_state(X, H)
--> 229 Z = self._calculate_update_gate(X, edge_index, edge_weight, H)
    230 R = self._calculate_reset_gate(X, edge_index, edge_weight, H)
    231 H_tilde = self._calculate_candidate_state(X, edge_index, edge_weight, H, R)

File ~\anaconda3\envs\t4c\lib\site-packages\torch_geometric_temporal\nn\recurrent\temporalgcn.py:188, in TGCN2._calculate_update_gate(self, X, edge_index, edge_weight, H)
    187 def _calculate_update_gate(self, X, edge_index, edge_weight, H):
--> 188     Z = torch.cat([self.conv_z(X, edge_index, edge_weight), H], axis=2) # (b, 207, 64)
    189     Z = self.linear_z(Z) # (b, 207, 32)
    190     Z = torch.sigmoid(Z)

File ~\anaconda3\envs\t4c\lib\site-packages\torch\nn\modules\module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\nn\conv\gcn_conv.py:210, in GCNConv.forward(self, x, edge_index, edge_weight)
    208 cache = self._cached_edge_index
    209 if cache is None:
--> 210     edge_index, edge_weight = gcn_norm(  # yapf: disable
    211         edge_index, edge_weight, x.size(self.node_dim),
    212         self.improved, self.add_self_loops, self.flow, x.dtype)
    213     if self.cached:
    214         self._cached_edge_index = (edge_index, edge_weight)

File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\nn\conv\gcn_conv.py:100, in gcn_norm(edge_index, edge_weight, num_nodes, improved, add_self_loops, flow, dtype)
     98 row, col = edge_index[0], edge_index[1]
     99 idx = col if flow == 'source_to_target' else row
--> 100 deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum')
    101 deg_inv_sqrt = deg.pow_(-0.5)
    102 deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)

File ~\AppData\Roaming\Python\Python38\site-packages\torch_geometric\utils\scatter.py:74, in scatter(src, index, dim, dim_size, reduce)
     72 if reduce == 'sum' or reduce == 'add':
     73     index = broadcast(index, src, dim)
---> 74     return src.new_zeros(size).scatter_add_(dim, index, src)
     76 if reduce == 'mean':
     77     count = src.new_zeros(dim_size)

RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Additionally, with TORCH_USE_CUDA_DSA enabled, the following error is also displayed on my console:

C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\aten\src\ATen\native\cuda\ScatterGatherKernel.cu:145: block: [73,0,0], thread: [63,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "index out of bounds"` failed.

What's exactly the problem?

1

There is 1 answer

0
olenscki On

Analyzing the error output, the `scatter` call inside `gcn_norm` indexes out of bounds: `edge_index` contains node indices that are larger than the number of nodes `x` now has along its node dimension. Since we changed the way the `x` tensor is passed (we separated `batch_size` and `num_nodes`), we should also modify the `edge_index` tensor so that it makes sense for the reshaped `x`, i.e. it should only refer to the `num_nodes` nodes of a single graph. Assuming that `edge_index` is static (meaning it doesn't change for the different elements of the batch), we can unbatch it using the method:

    def unbatch_edge_index(self, edge_index, batch, graph_idx=0):
        """Extract one graph's edges from a batched ``edge_index``.

        Useful when every graph in the mini-batch shares the same topology,
        so a single graph's edges can stand in for all of them.

        Parameters
        ----------
        edge_index : torch.Tensor
            (2, total_num_edges) edge indices of the whole mini-batch.
        batch : torch.Tensor
            Node-to-graph assignment vector (one entry per node);
            ``torch.bincount`` of it gives the number of nodes per graph.
        graph_idx : int, optional
            Index of the graph to extract. Defaults to 0, the first graph.

        Returns
        -------
        torch.Tensor
            (2, num_edges_of_graph) edge indices relabelled to start at 0.
            The input ``edge_index`` is not modified.
        """
        # Number of nodes in each graph of the mini-batch.
        num_nodes_per_graph = torch.bincount(batch)

        # Node-id offsets: graph g owns ids [offsets[g], offsets[g + 1]).
        # The leading zero is created on batch's own device/dtype, so this
        # works regardless of where the model lives.
        offsets = torch.cat([num_nodes_per_graph.new_zeros(1),
                             torch.cumsum(num_nodes_per_graph, dim=0)])
        start, end = offsets[graph_idx], offsets[graph_idx + 1]

        # Select edges whose source node belongs to the chosen graph; this
        # assumes edges never cross graphs (true for disjoint-union batching).
        mask = (edge_index[0] >= start) & (edge_index[0] < end)

        # Boolean indexing copies, and the out-of-place subtraction relabels
        # node ids so that the selected graph starts at 0.
        return edge_index[:, mask] - start

In this case, we should then add batch as a parameter for the forward method of the model:

    def forward(self, x, edge_index, batch):
        '''
        Run the wrapped A3TGCN2 layer on a PyG mini-batch.

        Parameters
        ------------
        x: torch.Tensor
            node features, of shape (seq_len, num_nodes*batch_size, features)
        edge_index: torch.Tensor
            edge indices of the whole mini-batch, of shape (2, batch_size*num_edges)
        batch: torch.Tensor
            node-to-graph assignment vector (one entry per node, e.g. PyG's
            ``data.batch``), of shape (num_nodes*batch_size,). It is
            node-level, not edge-level: counting its values gives the number
            of nodes in each graph.

        Returns
        -------
        Whatever ``self.tgnn`` returns - presumably
        (batch_size, num_nodes, out_dim); TODO confirm
        '''
        seq_len, _, features = x.shape
        # (seq_len, N*B, F) -> (N*B, F, seq_len)
        x = torch.movedim(x, 0, -1)
        # Split the stacked node axis back into (batch_size, num_nodes).
        x = x.reshape(self.batch_size, -1, features, seq_len)

        # now x is shaped (batch_size, num_nodes, features, seq_len)

        # The layer shares one edge_index across the batch dimension, so it
        # needs a single graph's edges (ids < num_nodes), not the whole batch's.
        edge_index = self.unbatch_edge_index(edge_index, batch)

        H = self.tgnn(X=x, edge_index=edge_index)

        return H

Finally, retrieve the batch tensor from each mini-batch in the training loop and pass it to the model:

# Training-loop excerpt: unpack each mini-batch, including its `batch`
# vector, so that it can be passed to forward(x, edge_index, batch).
for databatch, i, _ in source_dataloader:
    # databatch appears to be a dict of per-city mini-batches - TODO confirm
    for city, data in databatch.items():
        # data.batch presumably maps every node to its graph index (PyG Batch)
        x, edge_index, y, batch = data.x, data.edge_index, data.y, data.batch
        ...