My ICNN doesn't seem to work for any n_hidden

16 views Asked by At

`

class convexLinear(nn.Module):
    def __init__(self, size_in, size_out):
        super().__init__()
        self.size_in, self.size_out = size_in, size_out
        weights = torch.Tensor(size_out, size_in)
        self.weights = nn.Parameter(weights)
        nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))

    def forward(self, x):
        w_times_x= torch.mm(x, F.softplus(self.weights.t()))
        return w_times_x
class ICNN(nn.Module):
    def __init__(self, n_input, n_hidden, n_output):
        super(ICNN, self).__init__()
        self.layers = nn.ModuleDict()
        self.depth = len(n_hidden)
        self.layers[str(0)] = nn.Linear(n_input, n_hidden[0]).float()
        nn.init.xavier_uniform_(self.layers[str(0)].weight)
        # Create create NN with number of elements in n_hidden as depth
        for i in range(1, self.depth):
            self.layers[str(i)] = convexLinear(n_hidden[i-1], n_hidden[i]).float()

        self.layers[str(self.depth)] = convexLinear(n_hidden[self.depth-1], n_output).float()
        
    def forward(self, x):
        # First layer
        x = x.view(-1, 3, 3)
        det = torch.det(x)
        det = det.view(-1, 1)
        x_t = x.transpose(1, 2)
        mult = torch.bmm(x_t, x)
        trace = torch.diagonal(mult, dim1=1, dim2=2).sum(1)
        trace = trace.view(-1, 1)
        x = torch.cat((trace, det), 1)
        z = x.clone()
        z = self.layers[str(0)](z)

        for layer in range(1, self.depth):
            z = self.layers[str(layer)](z)
            z = F.softplus(z)
            z = torch.square(z)
        y = self.layers[str(self.depth)](z)
        return y
n_input = 2
n_output = 1
n_hidden = [64, 64]
icnn = ICNN(n_input, n_hidden, n_output)

x is a 600, 9 data set which gets converted to a 600, 2 set in def forward: Output is supposed to be 600, 1

But the code gives very bad results for any other combination of n_hidden.

In case of any doubts - I need a convex, non-decreasing activation with non-negative weights for the linear layers

0

There are 0 answers