Process and raw directory in PyTorch Geometric

I have two graph representations for the same image file. One for local information and another for global.

In PyTorch Geometric, the processed data is stored in the processed directory. Is it possible to save the two representations to different processed directories? What would be the best way to save and load the processed data in this case?

I use this for facial feature extraction and classification problems. While the global representation captures overall features, the local representation captures region-specific information. I have tried using a DataList, but I am not sure how it works with the 'processed' directory that PyTorch Geometric uses.

There is 1 answer

Serge de Gosson de Varennes

You did not explain why you want to save both the local and global representations. In many cases (IMHO) it is redundant, but it may well not be in yours. Anyhow, I have had cases where this helped me compare the performance of models trained on different representations and gave insights into how the models captured different aspects of the data. So kudos for that!

Here is how I went about doing this (adapt it to how you want to save them):

import os
import torch
from torch_geometric.data import InMemoryDataset

class LocalRepresentationDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        # The parent constructor calls process() when the processed file is missing.
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['local_data.pt']

    def download(self):
        pass

    def process(self):
        data_list = [...]  # Placeholder: plug in your list of local-representation Data objects here
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

class GlobalRepresentationDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        # The parent constructor calls process() when the processed file is missing.
        super().__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return []

    @property
    def processed_file_names(self):
        return ['global_data.pt']

    def download(self):
        pass

    def process(self):
        data_list = [...]  # Placeholder: plug in your list of global-representation Data objects here
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

# Each path below is a dataset *root*. PyTorch Geometric creates its own
# 'raw' and 'processed' subfolders inside the root, so the two
# representations end up in two separate processed directories.
local_root = 'processed/local'
global_root = 'processed/global'

os.makedirs(local_root, exist_ok=True)
os.makedirs(global_root, exist_ok=True)

local_dataset = LocalRepresentationDataset(local_root)
global_dataset = GlobalRepresentationDataset(global_root)
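
One thing that may be worth double-checking is where the files actually land. As far as I know, PyTorch Geometric treats the path given to the constructor as the dataset root and resolves the processed directory as a 'processed' folder inside that root. The sketch below only mirrors that path logic with the standard library so you can see the resulting layout without building a dataset. The helper name expected_processed_paths is my own and is not part of PyTorch Geometric.

```python
import os.path as osp


def expected_processed_paths(root, processed_file_names):
    """Hypothetical helper: mirror the <root>/processed/<file> layout.

    This is an assumption-based illustration of how the dataset root and
    processed_file_names combine; it does not import or call PyTorch Geometric.
    """
    processed_dir = osp.join(root, "processed")
    return [osp.join(processed_dir, name) for name in processed_file_names]


# Same roots and file names as in the two dataset classes above.
local_paths = expected_processed_paths(osp.join("processed", "local"), ["local_data.pt"])
global_paths = expected_processed_paths(osp.join("processed", "global"), ["global_data.pt"])

print(local_paths[0])
print(global_paths[0])
```

So with the roots used above, the local file should resolve to processed/local/processed/local_data.pt on a POSIX system, and the global one to processed/global/processed/global_data.pt. The two representations never share a processed directory, which is what you asked for. If the nested 'processed' folder bothers you, pick more neutral root names.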