SHAP DeepExplainer returns all-zero shap_values when explaining an AE model


I am trying to use SHAP to explain the importance of the input features towards a loss computed from two latent outputs. The original net takes two inputs (of shapes (num, 17000) and (num, 4100)) and acts as an AE, reconstructing each input from itself. During training, a loss term is used to make the difference between the two latents bigger, and that is the value I want to explain with SHAP. Because DeepExplainer won't accept two separate inputs, I wrote a copynet class that wraps the original net and splits a single concatenated input back into the two parts. The wrapper and the SHAP code are below:


import torch
import torch.nn as nn
import shap

# Pack the two inputs into one tensor so SHAP sees a single input
input_X = torch.cat((train_a, train_b), dim=1)  # shape = (num, 17000 + 4100)

class copynet(nn.Module):
    def __init__(self, net):
        super().__init__()
        self.net0 = net

    def forward(self, input_X):  # input_X.shape = (batch_size, 17000 + 4100)
        # Split the packed tensor back into the two original inputs
        x = input_X[:, :train_a.shape[1]]  # first 17000 columns
        y = input_X[:, train_a.shape[1]:]  # remaining 4100 columns
        # Re-create both parts as float32 leaves with gradients enabled
        x = x.clone().detach().to(torch.float32).requires_grad_(True)
        y = y.clone().detach().to(torch.float32).requires_grad_(True)

        # a1_l and b1_l are the two latent vectors, shape = (batch_size, 128);
        # they are intermediate activations, not the final reconstruction
        # (which I do not want to explain)
        _, _, _, _, _, _, _, _, a1_l, b1_l = self.net0(x, y)

        # The scalar I want to explain: the loss term used during training
        # to push the two latents apart
        return torch.mean(a1_l * b1_l).view(-1, 1)

copynet0 = copynet(net)

explainer = shap.Explainer(copynet0, input_X[5000:5100])  # input_X has already been shuffled
shap_values = explainer.shap_values(input_X[5000:5100])

The resulting shap_values are all zeros, and I cannot tell what the problem is.

I've checked the output of copynet, and it differs considerably between different inputs. I've also tried setting a single value in a sample to zero, and the returned loss changes as well.
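For reference, here is a minimal gradient check (only a sketch, reusing the copynet0 and input_X defined above) that should show whether gradients flow from the explained scalar back to the packed input, which gradient-based SHAP explainers depend on:

inp = input_X[5000:5100].clone().detach().requires_grad_(True)
out = copynet0(inp).sum()
out.backward()
print(inp.grad)  # None or an all-zero tensor would mean the graph is broken inside forward()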

Is there any problem in the code? One concern of mine is that the input data are sparse, with about 95% zeros. But what confuses me is that the outputs of copynet clearly differ from one another, with both positive and negative values, so how can the shap_values stay all zeros?
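To test the sparsity concern, one thing I can think of (not run yet, just a sketch) is to explain dense random inputs of the same width through the same wrapper, using shap.GradientExplainer as an alternative gradient-based explainer:

import numpy as np

dense = torch.rand(20, train_a.shape[1] + train_b.shape[1])  # fully dense fake data
ge = shap.GradientExplainer(copynet0, dense[:10])
sv = ge.shap_values(dense[10:])
print(np.abs(np.asarray(sv)).max())  # still ~0 would point at the wrapper, not at sparsity

If dense random data still produces all zeros, then the sparsity of my real data cannot be the cause.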


There are 0 answers