I am very new to PyTorch and PyTorch Geometric. I need to load a dataset and then attach an attribute to every graph in it, which I will use later in the script. However, I can't figure out how to do it.
I load the dataset like this:
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='data/TUDataset', name='PROTEINS')
Then I add the attribute. I tried the following (the value 3 is only an example; in the real script it will come from a DB query):
for data in dataset:
    data.keys.append('szemeredi_id')
    data.szemeredi_id = 3
or
for data in dataset:
    data['szemeredi_id'] = 3
or
for i, s in enumerate(dataset):
    dataset[i]['szemeredi_id'] = 3
or
for data in dataset:
    setattr(data, 'szemeredi_id', 3)
but the attribute is always empty. I even tried writing a subclass of the Data class:
class SzeData(Data):
    def __init__(self, x=None, edge_index=None, edge_attr=None, y=None,
                 pos=None, normal=None, face=None, **kwargs):
        super(SzeData, self).__init__(x, edge_index, edge_attr, y, pos, normal, face)
        self.szemeredi_id = None
But if I try to replace the Data objects in the dataset, it raises TypeError: 'TUDataset' object does not support item assignment, or, if I use this solution, it does nothing.
Any suggestion is much appreciated. Thank you.
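You can write the modification you apply to each data sample as a transform function, and then pass it to the transform or pre_transform parameter when constructing the dataset. Which one to use depends on your needs: transform is applied every time a sample is accessed, while pre_transform is applied once, before the processed dataset is saved to disk. See the documentation of torch_geometric.data.Dataset.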
Edit:
The above method is unaware of the index of data in the dataset, so it won't help if you want to add index-related attributes. To add an index-related attribute (e.g. simply index), I use the following approach, which is less elegant but more general: