Multiplying conv layer weights (N,C,H,W) with logits (H,W) in PyTorch


I am implementing the paper "Deep multiscale convolutional feature learning for weakly supervised localization of chest pathologies in X-ray images". According to my understanding, the layer relevance weights belong to the last layer of each dense block.

I tried implementing the weight constraints as shown below:

    def weight_constraints(self):
        # relevance weights taken from the last conv layer of each dense block
        weights = {'feat1': self.model.features.denseblock2.denselayer12.conv2.weight.data,
                   'feat2': self.model.features.denseblock3.denselayer24.conv2.weight.data,
                   'feat3': self.model.features.denseblock4.denselayer16.conv2.weight.data}

        sum(weights.values()) == 1  # intended to constrain the weights to sum to 1

        for i in weights.keys():
            weights[i] = weights[i].clamp(min=0)  # keep weights non-negative
        return weights


    weights = self.weight_constraints()
    for i in weights.keys():
        w = weights[i]
        l = logits[i]
        p = torch.matmul(w, l[0])  # this line raises the shape error
        sum = sum + p

where logits is a dictionary containing the output of the FC layer from each block, as shown in the diagram:

logits = {'feat1': [tensor([[-0.0630]], ...ackward0>)], 'feat2': [tensor([[-0.0323]], ...ackward0>)], 'feat3': [tensor([[-8.2897e-06...ackward0>)]}

I get the following error:

mat1 and mat2 shapes cannot be multiplied (12288x3 and 1x1)

Is this the right approach?

1 Answer

Answered by Bob:

The paper states:

The logit response from all the layers have same dimension (equal to the number of category for classification) and now can be combined using class specific convex combination to obtain the probability score for the class pc.

The matmul function you used performs matrix multiplication; it requires mat1.shape[-1] == mat2.shape[-2].
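A minimal sketch of that requirement, using toy tensors with the same shapes as in your error message (the names a, b, c are illustrative):

```python
import torch

a = torch.randn(12288, 3)
b = torch.randn(1, 1)
try:
    torch.matmul(a, b)          # inner dims disagree: 3 != 1
except RuntimeError as e:
    print("error:", e)          # "mat1 and mat2 shapes cannot be multiplied ..."

c = torch.randn(3, 1)           # inner dims now agree: 3 == 3
out = torch.matmul(a, c)
print(out.shape)                # torch.Size([12288, 1])
```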

If you assume sum(w) == 1 and torch.all(w > 0), you can compute the convex combination of l as (w * l).sum(-1), that is, multiply w and l element-wise, broadcasting over the batch dimensions of l; this requires w.shape[-1] == l.shape[-1] (presumably 3).
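A minimal sketch of that element-wise convex combination, with made-up weights (a softmax guarantees they are positive and sum to 1):

```python
import torch

# hypothetical relevance weights over the three feature levels; softmax
# enforces w > 0 and w.sum() == 1, i.e. a valid convex combination
w = torch.softmax(torch.randn(3), dim=-1)

# scalar class logits from the three blocks, as in the question
l = torch.tensor([-0.0630, -0.0323, -8.2897e-06])

p = (w * l).sum(-1)  # element-wise product, then sum over the last dim
```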

If you want to stick with matmul, you can add one dimension to w and l and perform the vector product as a matrix multiplication: torch.matmul(w[..., None, :], l[..., :, None]).
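A quick check that the two formulations agree, again with made-up tensors:

```python
import torch

w = torch.softmax(torch.randn(3), dim=-1)  # made-up convex weights
l = torch.randn(3)                         # made-up logits

# (1, 3) @ (3, 1) -> (1, 1): the dot product as a matrix multiplication
p_mat = torch.matmul(w[..., None, :], l[..., :, None])
p_sum = (w * l).sum(-1)
```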