PyG dataset showing more than 1 graph


I am new to PyG and am attempting to build a PyG dataset from a small JSON file (5 records: 5 nodes, 8 edges). After building the dataset, when I print the properties of the graph, I see that the number of graphs is 3 and the number of nodes is 20. I expect only 5 nodes and only 1 graph. The number of edges is correct (8). Perhaps this is because, even though there are 5 nodes, there are 4 node types (org, event, player and rated).

I am not sure where I am making a mistake in creating this dataset. Please note that at this point I am trying to learn to create the correct PyG dataset for a given input. I have not thought about any node-classification, link-prediction or anomaly-detection scenario yet.

The 'rated' field is assumed to be the label (y).
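As a minimal sketch of that assumption, the yes/no values of 'rated' can be turned into integer class labels with pandas. The mapping {'no': 0, 'yes': 1} is my own illustrative choice, and the resulting NumPy array could then be wrapped with torch.from_numpy to obtain y.

```python
import pandas as pd

# Values copied from the 'rated' column of the input df shown further below.
df = pd.DataFrame({"rated": ["no", "yes", "no", "yes", "no"]})

# Illustrative encoding (an assumption): 'no' -> 0, 'yes' -> 1.
rated_to_label = {"no": 0, "yes": 1}
labels = df["rated"].map(rated_to_label).to_numpy(dtype="int64")

print(labels)
```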

Graph Dataset:
HeteroData(
  org={ x=[5, 2] },
  player={ x=[5, 3] },
  event={ x=[5, 1] },
  rated={ x=[5, 1] },
  (event, is_related_to, event)={ edge_index=[2, 8] },
  (player, is_rated, rated)={ y=[5] }
)
Number of graphs: 3
Number of nodes: 20
Number of edges: 8
Number of node-features: {'org': 2, 'player': 3, 'event': 1, 'rated': 1}
Number of edge-features: {('event', 'is_related_to', 'event'): 0, ('player', 'is_rated', 'rated'): 0}
Edges are directed: True
Graph has isolated nodes: True
Graph has loops: False
Node Types: ['org', 'player', 'event', 'rated']
Edge Attributes: 20

The code for building the dataset looks like this:

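from torch_geometric.data import HeteroData


def build_dataset(self, edge_index, org_X, player_X, event_X, rated_X, labels_y):
    data = HeteroData()
    # One node feature matrix per node type (5 rows each in my data).
    data['org'].x = org_X
    data['player'].x = player_X
    data['event'].x = event_X
    data['rated'].x = rated_X
    # Event-to-event edges in COO format, shape [2, num_edges].
    data['event', 'is_related_to', 'event'].edge_index = edge_index
    # Labels attached to the player -> rated edge type (no edge_index is set for it).
    data['player', 'is_rated', 'rated'].y = labels_y
    return data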
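To check the numbers in the printout against this layout, here is a plain-Python mock (not PyG code) of the stores that build_dataset creates, with shapes copied from the printed HeteroData. If I read the PyG 2.x source correctly, HeteroData.num_nodes sums the node counts of every node type, and len() on a Data/HeteroData object returns the number of distinct attribute names rather than a number of graphs; both behaviours are assumptions here. If "Number of graphs" came from len(data), that would account for the 3.

```python
# Plain-Python mock of the printed HeteroData (not the PyG API).
# Each store maps attribute name -> tensor shape, copied from the printout.
node_stores = {
    "org": {"x": (5, 2)},
    "player": {"x": (5, 3)},
    "event": {"x": (5, 1)},
    "rated": {"x": (5, 1)},
}
edge_stores = {
    ("event", "is_related_to", "event"): {"edge_index": (2, 8)},
    ("player", "is_rated", "rated"): {"y": (5,)},
}

# Assumed PyG behaviour: num_nodes sums over all node types.
total_nodes = sum(store["x"][0] for store in node_stores.values())

# Assumed PyG behaviour: len(data) counts distinct attribute names.
attribute_names = set()
for store in list(node_stores.values()) + list(edge_stores.values()):
    attribute_names.update(store)

# Only edge types that carry an edge_index contribute edges.
total_edges = sum(
    store["edge_index"][1] for store in edge_stores.values() if "edge_index" in store
)

print(total_nodes, len(attribute_names), total_edges)  # 20 3 8
```

Under these assumptions the three distinct attributes are x, edge_index and y, and the 20 nodes are 4 node types × 5 rows.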

This is how I convert a node to its features (for the player node):

import pandas as pd
import torch


def extract_player_node_features(self, df):
    # Compact player_id -> row-index mapping (computed, but not used below yet).
    sorted_player_df = df.sort_values(by='player_id').reset_index(drop=True)
    player_id_mapping = sorted_player_df['player_id']
    #print(f'\nPlayer ID mapping:\n{player_id_mapping}')

    # select player node features
    player_node_features_df = df[['player_name', 'age', 'school']]

    # Build one frame per feature (one column per element if the values are lists)
    # and pass it to pd.get_dummies.
    player_name_features_df = pd.DataFrame(player_node_features_df.player_name.values.tolist(), player_node_features_df.index).add_prefix('player_name_')
    player_name_features_ohe = pd.get_dummies(player_name_features_df)

    player_age_features_df = pd.DataFrame(player_node_features_df.age.values.tolist(), player_node_features_df.index).add_prefix('age_')
    player_age_features_ohe = pd.get_dummies(player_age_features_df)

    player_school_features_df = pd.DataFrame(player_node_features_df.school.values.tolist(), player_node_features_df.index).add_prefix('school_')
    player_school_features_ohe = pd.get_dummies(player_school_features_df)

    # Concatenate the encoded blocks column-wise (name, age, school).
    player_node_features = pd.concat([player_name_features_ohe, player_age_features_ohe, player_school_features_ohe], axis=1)

    # Convert to an int32 tensor of shape [num_players, num_features].
    player_node_X = torch.from_numpy(player_node_features.to_numpy(dtype='int32'))
    return player_node_X
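The printout shows player x=[5, 3], i.e. one column per feature rather than a one-hot block. One possible cause, assuming the integer-coded frame ("Sorted input df", shown further below) is what reaches this function: pd.get_dummies only encodes object, string or category columns by default and passes numeric columns through unchanged. The sketch below, with values copied from that sorted df, compares the default call with one that lists the columns explicitly via columns=.

```python
import pandas as pd

# Integer-coded player features, copied from the sorted input df
# (rows in the order 1, 2, 3, 0, 4 as printed there).
coded = pd.DataFrame({
    "player_name": [1, 4, 3, 0, 2],
    "age": [0, 1, 1, 0, 2],
    "school": [1, 0, 1, 0, 2],
})

# Default call: no object/string/category columns, so nothing is encoded.
default_ohe = pd.get_dummies(coded)

# Listing the columns forces one-hot encoding of the integer codes.
explicit_ohe = pd.get_dummies(coded, columns=["player_name", "age", "school"], dtype=int)

print(default_ohe.shape)   # (5, 3)
print(explicit_ohe.shape)  # (5, 11)
```

With explicit columns each feature contributes one column per distinct value (5 names, 3 ages, 3 schools), so the feature matrix grows to 11 columns.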

Finally, here are the input df (created from the JSON input file) and the converted df (mapped to numeric codes starting from 0 to keep it compact):

Input df:

  event_id                   event_type  org_id                        org_name org_location    player_id player_name  age school related_event_id rated
0    1-ab3                   qualifiers     305               milan tennis club        Milan  1-b7a3-52d2        Alex   20    BCE   [4-ab3, 3-ab3]    no
1    2-ab3              under 18 finals      76            Nadal tennis academy       madrid  2-b7a3-52d2         Bob   20   BCMS   [5-ab3, 1-ab3]   yes
2    3-ab3     womens tennis qualifiers     185                Griz tennis club     budapest  3-b7a3-52d2        Mary   21    BCE          [4-ab3]    no
3    4-ab3  US professional tennis club     285  Nick Bolletieri Tennis Academy        tampa  4-b7a3-52d2         Joe   21   BCMS   [1-ab3, 3-ab3]   yes
4    5-ab3        womens tennis circuit     305               milan tennis club        Milan  5-b7a3-52d2        Bolt   22   LTHS          [4-ab3]    no

Sorted input df:

  related_event_id  org_id  org_name  org_location  player_id  player_name  event_id  event_type  age  school  rated
1           [4, 0]       0         1             2          1            1         1           2    0       1      1
2              [3]       1         0             1          2            4         2           4    1       0      0
3           [0, 2]       2         2             3          3            3         3           0    1       1      1
0           [3, 2]       3         3             0          0            0         0           1    0       0      0
4              [3]       3         3             0          4            2         4           3    2       2      0
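For reference, the 8 event-to-event edges can be reproduced from related_event_id with a short sketch. The edge direction (event_id as source, each related event as target) and the compact index taken from the sorted event_id values are assumptions; with them, the codes agree with the sorted df above (e.g. 1-ab3 → [3, 2]). The resulting (2, num_edges) int64 array could be passed to torch.from_numpy to get an edge_index.

```python
import numpy as np

# event_id and related_event_id columns, copied from the input df.
event_ids = ["1-ab3", "2-ab3", "3-ab3", "4-ab3", "5-ab3"]
related = [
    ["4-ab3", "3-ab3"],
    ["5-ab3", "1-ab3"],
    ["4-ab3"],
    ["1-ab3", "3-ab3"],
    ["4-ab3"],
]

# Compact index: sorted event_id -> 0..n-1.
event_index = {eid: i for i, eid in enumerate(sorted(event_ids))}

# One (source, target) pair per related event (direction is an assumption).
pairs = [
    (event_index[src], event_index[dst])
    for src, targets in zip(event_ids, related)
    for dst in targets
]

# Transpose into the COO layout [2, num_edges] that edge_index uses.
edge_index = np.array(pairs, dtype=np.int64).T

print(edge_index.shape)  # (2, 8)
print(edge_index)
```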

When I leave 'rated' out of the node types, validate() complains with the error below, so I half-heartedly added 'rated' as a node type.

ValueError: The node types {'rated'} are referenced in edge types but do not exist as node types
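The error message states the rule being enforced: every node type named in an edge type must also exist as a node type. A plain-Python illustration of that check (my own sketch, not PyG's implementation) shows why leaving 'rated' out fails while ('player', 'is_rated', 'rated') is still declared as an edge type in build_dataset:

```python
# Edge types declared in build_dataset.
edge_types = [
    ("event", "is_related_to", "event"),
    ("player", "is_rated", "rated"),
]


def missing_node_types(node_types, edge_types):
    """Return node types referenced by edge types but not declared as node types."""
    referenced = {src for src, _, _ in edge_types} | {dst for _, _, dst in edge_types}
    return referenced - set(node_types)


print(missing_node_types(["org", "player", "event"], edge_types))           # {'rated'}
print(missing_node_types(["org", "player", "event", "rated"], edge_types))  # set()
```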

Any help or suggestions as to why I am seeing these numbers? I am happy to provide the code and input file so you can reproduce it.
