Is the NN Just Bad at Solving This Simple Linear Problem, or Is It Because of Bad Training?

I was trying to train a very straightforward (I thought) NN model with PyTorch and skorch, but the bad performance really baffles me, so it would be great if you have any insight into this.

The problem is something like this: there are five objects, A, B, C, D, E (labeled by their fingerprints, e.g. (0, 0) is A, (0.2, 0.5) is B, etc.), and each corresponds to a number; the task is to find which number each object corresponds to. The training data is a list of "collections" and the corresponding sums. For example: [A, A, A, B, B] == [(0,0), (0,0), (0,0), (0.2,0.5), (0.2,0.5)] --> 15, [B, C, D, E] == [(0.2,0.5), (0.5,0.8), (0.3,0.9), (1,1)] --> 30, and so on. Note that the number of objects in a collection is not constant.

There is no noise or anything, so it's just a linear system that can be solved directly, and I thought this would be very easy for a NN to figure out. I'm actually using this example as a sanity check for a more complicated problem, but was surprised that the NN couldn't even solve it.
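To make the "can be solved directly" point concrete, here is a minimal sketch with made-up numbers (the values and collections below are illustrative stand-ins, not the real dataset; objects are indexed 0=A, ..., 4=E) showing that an ordinary least-squares solve on the count matrix recovers the per-object values exactly:

import torch

# Hypothetical per-object values -- the unknowns the NN is being asked to recover.
true_values = torch.tensor([2.0, 3.0, 5.0, 7.0, 11.0])

# Each collection is just a multiset of object indices.
collections = [
    [0, 0, 0, 1, 1],
    [1, 2, 3, 4],
    [0, 4],
    [2, 2, 3],
    [0, 1, 2, 3, 4],
    [3, 4, 4],
]

# Count matrix: one row per collection, one column per object.
counts = torch.zeros(len(collections), 5)
for row, coll in enumerate(collections):
    for obj in coll:
        counts[row, obj] += 1

totals = counts @ true_values  # noise-free "sum" targets

# Ordinary least squares recovers the per-object values exactly
# (the count matrix here has full column rank).
solution = torch.linalg.lstsq(counts, totals.unsqueeze(1)).solution.squeeze(1)
print(solution)  # ~ tensor([ 2.,  3.,  5.,  7., 11.])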

Now I'm just trying to pinpoint exactly where it went wrong. The model definition seems to be right and the data input is right, so is the bad performance due to bad training, or are NNs just bad at these things?

Here is the model definition:

import torch
import torch.nn as nn
from torch.nn import Tanh
# Not shown in the original snippet: scatter presumably comes from torch_scatter
# (or torch_geometric.utils); MLP is defined further down.
from torch_scatter import scatter


class NN(nn.Module):
    def __init__(
        self,
        input_dim,
        num_nodes,
        num_layers,
        batchnorm=False,
        activation=Tanh,
    ):
        super(SingleNN, self).__init__()
        self.get_forces = get_forces  # note: get_forces is not defined anywhere in this snippet
        self.activation_fn = activation

        self.model = MLP(
            n_input_nodes=input_dim,
            n_layers=num_layers,
            n_hidden_size=num_nodes,
            activation=activation,
            batchnorm=batchnorm,
        )

    def forward(self, batch):
        if isinstance(batch, list):
            batch = batch[0]
        with torch.enable_grad():
            fingerprints = batch.fingerprint.float()
            fingerprints.requires_grad = True
            #index of the current "collection" in the training list
            idx = batch.idx
            sorted_idx = torch.unique_consecutive(idx)
            o = self.model(fingerprints)
            total = scatter(o, idx, dim=0)[sorted_idx]

            return total

    @property
    def num_params(self):
        return sum(p.numel() for p in self.parameters())

class MLP(nn.Module):
    def __init__(
        self,
        n_input_nodes,
        n_layers,
        n_hidden_size,
        activation,
        batchnorm,
        n_output_nodes=1,
    ):
        super(MLP, self).__init__()
        if isinstance(n_hidden_size, int):
            n_hidden_size = [n_hidden_size] * (n_layers)
        self.n_neurons = [n_input_nodes] + n_hidden_size + [n_output_nodes]
        self.activation = activation
        layers = []
        for _ in range(n_layers - 1):
            layers.append(nn.Linear(self.n_neurons[_], self.n_neurons[_ + 1]))
            layers.append(activation())
            if batchnorm:
                layers.append(nn.BatchNorm1d(self.n_neurons[_ + 1]))
        layers.append(nn.Linear(self.n_neurons[-2], self.n_neurons[-1]))
        self.model_net = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.model_net(inputs)
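The scatter call in forward is what does the per-collection summation: assuming it is torch_scatter's scatter (or torch_geometric's equivalent) with its default sum reduction, it adds up the per-object outputs o that share the same collection index in idx. A pure-PyTorch sketch of that reduction, just to illustrate the shape bookkeeping:

import torch

# Per-object model outputs for a batch holding two collections:
# rows 0-2 belong to collection 0, rows 3-4 to collection 1.
o = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]])
idx = torch.tensor([0, 0, 0, 1, 1])

# Equivalent of scatter(o, idx, dim=0) with a sum reduction:
totals = torch.zeros(2, 1).index_add_(0, idx, o)
print(totals)  # tensor([[6.], [9.]])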

and the skorch part is straightforward:

model = NN(2, 100, 2)
net = NeuralNetRegressor(
        module=model,
        ...
    )
net.fit(train_dataset, None)
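Since the actual NeuralNetRegressor arguments are elided above, for reference a generic skorch setup for this kind of regression might look roughly like the following. All hyperparameters here are illustrative placeholders, not the configuration used in the question:

import torch
from skorch import NeuralNetRegressor

# Illustrative settings only -- criterion, optimizer, lr, epoch count and batch
# size are placeholders, not the author's actual configuration.
net = NeuralNetRegressor(
    module=model,
    criterion=torch.nn.MSELoss,
    optimizer=torch.optim.Adam,
    lr=1e-3,
    max_epochs=300,
    batch_size=4,
    train_split=None,  # skip the internal validation split on such a tiny dataset
)
net.fit(train_dataset, None)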

For a test run, the dataset looks like the following (16 collections in total):

[[0.7484336 0.5656401]
 [0.        0.       ]
 [0.        0.       ]
 [0.        0.       ]]
[[1. 1.]
 [0. 0.]
 [0. 0.]]
[[0.51311415 0.67012525]
 [0.51311415 0.67012525]
 [0.         0.        ]
 [0.         0.        ]]
[[0.51311415 0.67012525]
 [0.7484336  0.5656401 ]
 [0.         0.        ]]
[[0.51311415 0.67012525]
 [1.         1.        ]
 [0.         0.        ]
 [0.         0.        ]]
[[0.51311415 0.67012525]
 [0.51311415 0.67012525]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]]
[[0.51311415 0.67012525]
 [1.         1.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]
 [0.         0.        ]]
....

with the corresponding totals: [10, 11, 14, 14, 17, 18, ...]

It's easy to tell which objects, and how many of them, are in each collection just by eyeballing it, and the training process looks like this:

 epoch    train_energy_mae    train_loss    cp     dur
-------  ------------------  ------------  ----  ------
      1              4.9852        0.5425     +  0.1486
      2             16.3659        4.2273        0.0382
      3              6.6945        0.7403        0.0025
      4              7.9199        1.2694        0.0024
      5             12.0389        2.4982        0.0024
      6              9.9942        1.8391        0.0024
      7              5.6733        0.7528        0.0024
      8              5.7007        0.5166        0.0024
      9              7.8929        1.0641        0.0024
     10              9.2560        1.4663        0.0024
     11              8.5545        1.2562        0.0024
     12              6.7690        0.7589        0.0024
     13              5.3769        0.4806        0.0024
     14              5.1117        0.6009        0.0024
     15              6.2685        0.8831        0.0024
....
    290              5.1899        0.4750        0.0024
    291              5.1899        0.4750        0.0024
    292              5.1899        0.4750        0.0024
    293              5.1899        0.4750        0.0024
    294              5.1899        0.4750        0.0025
    295              5.1899        0.4750        0.0025
    296              5.1899        0.4750        0.0025
    297              5.1899        0.4750        0.0025
    298              5.1899        0.4750        0.0025
    299              5.1899        0.4750        0.0025
    300              5.1899        0.4750        0.0025
    301              5.1899        0.4750        0.0024
    302              5.1899        0.4750        0.0025
    303              5.1899        0.4750        0.0024
    304              5.1899        0.4750        0.0024
    305              5.1899        0.4750        0.0025
    306              5.1899        0.4750        0.0024
    307              5.1899        0.4750        0.0025

You can see that it just stops improving after a while. I can confirm that the NN does give different results for different fingerprints, but somehow the final predicted values are just never good enough.
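One way to pinpoint where it goes wrong is to compare the network's per-object outputs against the values obtained from a direct linear solve of the same data. This is just a sketch: it assumes the fitted estimator is the net from above, and it uses the example fingerprints from the problem description (substitute the actual fingerprints from the dataset):

import torch

# Example fingerprints from the problem statement, one object per row.
fingerprints = torch.tensor([
    [0.0, 0.0],  # A
    [0.2, 0.5],  # B
    [0.5, 0.8],  # C
    [0.3, 0.9],  # D
    [1.0, 1.0],  # E
])

# Query the inner MLP directly for per-object values,
# bypassing the scatter/summation in NN.forward.
with torch.no_grad():
    per_object = net.module_.model(fingerprints).squeeze(1)

print(per_object)  # compare against the values from a direct least-squares solve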

I have tried different NN sizes, learning rates, batch sizes, and activation functions (tanh, relu, etc.), and none of them seem to help. Do you have any insight into this? Is there anything I did wrong or could try, or are NNs just bad at this kind of task?

1 answer

Answer from abe:

First thing I've noticed: super(SingleNN, self).__init__() should be super(NN, self).__init__() instead. Change that and let me know if you still get any errors.
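In other words, the constructor should start like this (or, on Python 3, simply super().__init__()):

class NN(nn.Module):
    def __init__(self, input_dim, num_nodes, num_layers, batchnorm=False, activation=Tanh):
        super(NN, self).__init__()  # was: super(SingleNN, self).__init__()
        ...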