Add attribute to object of dataset


I am very new to PyTorch and PyTorch Geometric. I need to load a dataset and then attach an attribute to every object (graph) in it, which I will use later in the script. However, I can't figure out how to do it.

I load the dataset as follows:

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='PROTEINS')

Then I try to add the attribute. I tried the following (the value 3 is only an example; in practice it will come from a database query):

for data in dataset:
    data.keys.append('szemeredi_id')
    data.szemeredi_id = 3

or

for data in dataset:
    data['szemeredi_id'] = 3

or

for i, s in enumerate(dataset):
    dataset[i]['szemeredi_id'] = 3

or

for data in dataset:
    setattr(data, 'szemeredi_id', 3)

but the attribute never shows up when I access the dataset afterwards. I even tried writing a subclass of the Data class:

from torch_geometric.data import Data

class SzeData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, normal=None, face=None, **kwargs):
        super(SzeData, self).__init__(x, edge_index, edge_attr, y, pos, normal, face, **kwargs)
        self.szemeredi_id = None

but if I try to replace the Data objects in the dataset, it raises TypeError: 'TUDataset' object does not support item assignment, and otherwise this approach has no effect either.

Any suggestion is much appreciated. Thank you.

Accepted answer by Vvvvvv:

You can express the modification you want to apply to each sample as a transform function and pass it to the transform or pre_transform parameter (whichever fits your needs) when constructing the dataset:

from torch_geometric.datasets import TUDataset

def transform(data):
    data.szemeredi_id = 3
    return data

dataset = TUDataset(root='data/TUDataset', name='PROTEINS', transform=transform)
# or dataset = TUDataset(root='data/TUDataset', name='PROTEINS', pre_transform=transform)
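A likely reason the loops in the question have no effect: in-memory datasets such as TUDataset appear to build a new Data object from their internal storage on every access, so attributes set on the object you get back are never written back. The sketch below models that behaviour in pure Python with a hypothetical ToyInMemoryDataset class (an assumption for illustration, not the PyG API):

```python
import copy
from types import SimpleNamespace


class ToyInMemoryDataset:
    """Hypothetical stand-in for an in-memory graph dataset.

    Assumption: like PyG's in-memory datasets, it keeps samples in internal
    storage and hands out a *new* object on every access.
    """

    def __init__(self, samples):
        self._storage = list(samples)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, idx):
        # Every access returns a fresh copy; changes to it are not stored.
        return copy.copy(self._storage[idx])

    def __iter__(self):
        return (self[i] for i in range(len(self)))


dataset = ToyInMemoryDataset([SimpleNamespace(y=0), SimpleNamespace(y=1)])

# Same pattern as the attempts in the question: modify each yielded object.
for data in dataset:
    data.szemeredi_id = 3

# The next access copies from storage again, so the attribute is gone.
print(hasattr(dataset[0], "szemeredi_id"))  # False
```

A transform avoids the problem because it is applied to the object on its way out of the dataset (or, with pre_transform, before the samples are stored), rather than to a throwaway copy.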

See the documentation of torch_geometric.data.Dataset:

  • transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before every access. (default: None)
  • pre_transform (callable, optional) – A function/transform that takes in a Data object and returns a transformed version. The data object will be transformed before being saved to disk. (default: None)
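To make the difference between the two parameters concrete, here is a pure-Python sketch with a hypothetical ToyDataset class (not the PyG API) that mimics the documented semantics: transform runs on every access, pre_transform runs once while the samples are prepared. Since the value in the question comes from a database query, pre_transform would avoid repeating that query on every access.

```python
class ToyDataset:
    """Hypothetical toy dataset mimicking the documented semantics of
    transform (every access) and pre_transform (once, before storage)."""

    def __init__(self, raw_samples, transform=None, pre_transform=None):
        self.transform = transform
        samples = [dict(s) for s in raw_samples]
        if pre_transform is not None:
            # Applied once; in PyG this happens before saving to disk.
            samples = [pre_transform(s) for s in samples]
        self._storage = samples

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, idx):
        data = dict(self._storage[idx])  # fresh copy on every access
        if self.transform is not None:
            data = self.transform(data)  # applied on every access
        return data


calls = {"transform": 0, "pre_transform": 0}


def make_counting(kind):
    """Build a function that sets the attribute and counts its calls."""
    def fn(data):
        calls[kind] += 1
        data["szemeredi_id"] = 3  # placeholder for the database query
        return data
    return fn


raw = [{"y": 0}, {"y": 1}]
ds_t = ToyDataset(raw, transform=make_counting("transform"))
ds_p = ToyDataset(raw, pre_transform=make_counting("pre_transform"))

for _epoch in range(3):  # three passes over the data
    for i in range(len(raw)):
        assert ds_t[i]["szemeredi_id"] == 3
        assert ds_p[i]["szemeredi_id"] == 3

print(calls)  # {'transform': 6, 'pre_transform': 2}
```

In PyG itself, pre_transform only runs when the processed files are generated, so if the dataset was already processed under root, that directory may need to be deleted before a new pre_transform takes effect.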

Edit:

The above method is unaware of each sample's index in the dataset, so it won't help if you want to add attributes that depend on the index.

To add an index-dependent attribute (e.g. simply the index itself), I use the following less elegant but more general approach:

from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='PROTEINS')

def add_attributes(dataset):
    data_list = []
    for i, data in enumerate(dataset):
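        # Avoid attribute names containing "index": PyG treats them as node
        # indices and may offset them when batching graphs together.
        data.idx = i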
        data_list.append(data)
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset

dataset = add_attributes(dataset)
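The same materialize, modify, rebuild pattern can be sketched in pure Python with a hypothetical toy dataset (not the PyG API), here with the per-index values taken from a hypothetical lookup that stands in for the database query in the question. Replacing the internal storage plays the role that re-collating into dataset.data and dataset.slices plays above.

```python
import copy
from types import SimpleNamespace


class ToyInMemoryDataset:
    """Hypothetical toy dataset: hands out copies of its stored samples."""

    def __init__(self, samples):
        self._storage = list(samples)

    def __len__(self):
        return len(self._storage)

    def __getitem__(self, idx):
        return copy.copy(self._storage[idx])

    def __iter__(self):
        return (self[i] for i in range(len(self)))


def add_attributes(dataset, lookup):
    """Attach an index-dependent attribute by rebuilding the storage.

    `lookup` is a hypothetical mapping from sample index to value, e.g. the
    result of the database query mentioned in the question.
    """
    data_list = []
    for i, data in enumerate(dataset):
        data.szemeredi_id = lookup[i]
        data_list.append(data)
    # Write the modified copies back; without this step they would be lost.
    dataset._storage = data_list
    return dataset


dataset = ToyInMemoryDataset([SimpleNamespace(y=0), SimpleNamespace(y=1)])
dataset = add_attributes(dataset, lookup={0: 17, 1: 42})
print([d.szemeredi_id for d in dataset])  # [17, 42]
```

The write-back step is what makes the change persist: iterating alone only touches temporary copies.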