How to add a new attribute to a torch_geometric.data Data object element?


I am trying to extend the elements of a TUDataset dataset. In particular, I have a dataset obtained via

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root="PROTEINS", name="PROTEINS", use_node_attr=True)

I want to add a new vector-like feature to every entry of the dataset.

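import networkx as nx
import torch_geometric.utils

for i, current_g in enumerate(dataset):
    # Dense adjacency matrix of the graph as a NumPy array
    adjacency = nx.to_numpy_array(torch_geometric.utils.to_networkx(current_g))
    feature = do_something(adjacency)  # do_something: my own feature computation
    dataset[i].new_feature = feature   # this assignment has no lasting effect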

However, this code doesn't work: as you can verify yourself, attributes cannot be added to an element of dataset:

In [80]: dataset[2].test = 1

In [81]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'

In [82]: dataset[2].__setattr__('test', 1)

In [83]: dataset[2].test
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
~/workspace/grouptheoretical/new-experiments/HGP-SL-myfork/main.py in <cell line: 1>()
----> 1 dataset[2].test

AttributeError: 'Data' object has no attribute 'test'
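
One plausible explanation for this transcript: indexing the dataset does not hand back the stored object but a fresh copy (to my understanding, in PyG 2.x InMemoryDataset.get returns a copy of the element on every call), so dataset[2].test = 1 sets the attribute on a temporary object that is discarded right away. The sketch below reproduces the effect with two hypothetical stand-in classes, Record (playing the role of Data) and CopyOnGetDataset; they are not PyG classes.

```python
import copy


class Record:
    """Hypothetical stand-in for torch_geometric.data.Data: a plain bag of attributes."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


class CopyOnGetDataset:
    """Hypothetical stand-in for a dataset whose indexing returns a new
    shallow copy of the stored element on every access."""

    def __init__(self, records):
        self._records = list(records)

    def __len__(self):
        return len(self._records)

    def __getitem__(self, idx):
        # A new object on every call; the stored record is never exposed.
        return copy.copy(self._records[idx])


dataset = CopyOnGetDataset([Record(x=[1, 2]), Record(x=[3, 4]), Record(x=[5, 6])])

dataset[2].test = 1                  # lands on a temporary copy
print(hasattr(dataset[2], "test"))   # fresh copy, so: False

element = dataset[2]                 # keep a reference to one copy...
element.test = 1
print(element.test)                  # ...and the attribute stays on it: 1
```

Keeping a reference to the returned object works, but the stored element is unchanged, which is why the modified elements have to be collected somewhere else (a list, or a new dataset).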

Each element of dataset is a torch_geometric.data.Data object.

I can create a new Data element with all the features I want by using:

tmp = dataset[i].to_dict()
tmp['new_feature'] = feature
new_dataset.append(torch_geometric.data.Data.from_dict(tmp))  # new_dataset = [] before the loop

However, I don't know how to create a TUDataset dataset (or its parent class) from a list of Data elements. Do you know how?

Any idea on how to solve this problem? Thanks.


Answer by Renyi (best answer)

One elegant way to reach your goal is to define your own transform:

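from torch_geometric.data import Data
from torch_geometric.transforms import BaseTransform

class Add_Node_Feature(BaseTransform):
    def __init__(self, parameters=None):
        self.parameters = parameters  # any parameters your feature computation needs

    def __call__(self, data: Data) -> Data:
        # do_something is your own function that computes the new node features;
        # to keep x intact, store the result as a new attribute instead,
        # e.g. data.new_feature = do_something(data)
        data.x = do_something(data.x)
        return data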

Then, apply this transform when loading the dataset. Every element you access is passed through the transform, so it carries the new features (the stored data itself is not modified).

import torch_geometric.transforms as T
from torch_geometric.datasets import TUDataset

dataset = TUDataset("PROTEINS", name="PROTEINS", use_node_attr=True)
dataset.transform = T.Compose([Add_Node_Feature()])
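
To see what assigning dataset.transform buys you, here is a stdlib-only sketch of the same idea: the transform runs on a copy every time an element is accessed, so the stored elements stay untouched while every returned element carries the new feature. Record, TransformedDataset, and AddNumNodes are hypothetical stand-ins, not PyG classes; PyG's Dataset.__getitem__ applies self.transform in a similar spirit.

```python
import copy


class Record:
    """Hypothetical stand-in for torch_geometric.data.Data."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


class TransformedDataset:
    """Hypothetical stand-in: applies `transform` to a copy on every access."""

    def __init__(self, records, transform=None):
        self._records = list(records)
        self.transform = transform

    def __len__(self):
        return len(self._records)

    def __getitem__(self, idx):
        item = copy.copy(self._records[idx])
        return item if self.transform is None else self.transform(item)


class AddNumNodes:
    """Hypothetical transform: adds a graph-level feature computed from `x`."""

    def __init__(self):
        self.calls = 0  # counts how often the feature is (re)computed

    def __call__(self, data):
        self.calls += 1
        data.num_nodes_feature = len(data.x)
        return data


records = [Record(x=[1, 2, 3]), Record(x=[4, 5])]
dataset = TransformedDataset(records)
dataset.transform = AddNumNodes()

print(dataset[0].num_nodes_feature)                # 3
print(dataset[1].num_nodes_feature)                # 2
print(hasattr(records[0], "num_nodes_feature"))    # False: storage untouched
print(dataset.transform.calls)                     # 2
```

Because the feature is recomputed on every access, an expensive do_something may be better passed as pre_transform to TUDataset, which, as far as I know, PyG applies only once, when the raw files are processed and saved under root.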

Answer by asdf

The solution was very easy. In most cases you just need a list (DataLoader works with lists of Data objects just fine):

dataset = [change_element(element) for element in dataset]

where change_element returns a new Data element as described in the question.
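
A minimal sketch of this approach, assuming change_element builds a modified copy rather than mutating the original (the to_dict()/from_dict() round trip from the question has the same effect). Record is a hypothetical stand-in for Data, and compute_feature is a hypothetical placeholder for do_something. Once the list is materialized, indexing returns the stored objects themselves, so new attributes persist.

```python
import copy


class Record:
    """Hypothetical stand-in for torch_geometric.data.Data."""

    def __init__(self, **attrs):
        self.__dict__.update(attrs)


def compute_feature(element):
    """Hypothetical placeholder for do_something: sum of node features."""
    return sum(element.x)


def change_element(element):
    """Return a modified copy of `element`, leaving the original untouched."""
    new_element = copy.copy(element)
    new_element.new_feature = compute_feature(element)
    return new_element


original = [Record(x=[1, 2]), Record(x=[3, 4, 5])]
dataset = [change_element(element) for element in original]

print([d.new_feature for d in dataset])     # [3, 12]
print(hasattr(original[0], "new_feature"))  # False
dataset[0].extra = "kept"                   # a list returns the stored object...
print(dataset[0].extra)                     # ...so this persists: kept
```

The trade-off versus the transform approach: the feature is computed once up front and kept in memory, but a plain list loses Dataset conveniences such as dataset.num_features.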