How to calculate the median of a masked tensor along an axis?


I have a tensor X of floats with dimensions n x m and a tensor Y of booleans with the same dimensions. I want to compute statistics such as the mean, median and max of X along one of the axes, but only over the values of X where Y is true -- something like X[Y].mean(dim=1). That doesn't work, because X[Y] is always a 1D tensor.
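For reference, a minimal snippet (using the example tensors from further down) showing why the naive indexing fails: boolean-mask indexing can select a different number of elements per row, so PyTorch flattens the result.

```python
import torch

X = torch.tensor([[1., 1., 1.],
                  [2., 2., 4.]])
Y = torch.tensor([[0, 1, 0],
                  [1, 0, 1]]).bool()

# Row 0 contributes 1 element, row 1 contributes 2, so the
# 2-D shape cannot be preserved -- the result is flattened.
print(X[Y])        # tensor([1., 2., 4.])
print(X[Y].dim())  # 1
```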

Edit:

For the mean, I was able to do it with:

masked_X = X * Y
masked_X_mean = masked_X.sum(dim=1) / Y.sum(dim=1)

For the max:

masked_X = X.clone()
masked_X[~Y] = float('-inf')
masked_X_max = masked_X.max(dim=1).values

But for the median, I was not able to be as creative. Any suggestions?

e.g.

X = torch.tensor([[1, 1, 1],
                  [2, 2, 4]]).type(torch.float32)
Y = torch.tensor([[0, 1, 0],
                  [1, 0, 1]]).type(torch.bool)

Expected Output

mean = [1., 3.]
median = [1., 2.]
var = [0., 1.]

There are 3 answers

Janosh (0 votes)

You can piggyback on torch.nanmedian:

import torch


def masked_median(x, mask, dim=0):
    """Compute the median of tensor x along dim, ignoring values where mask is False.
    x and mask need to be broadcastable.

    Args:
        x (Tensor): Tensor to compute median of.
        mask (BoolTensor): Same shape as x with True where x is valid and False
            where x should be masked. Mask should not be all False in any column of
            dimension dim to avoid NaNs from zero division.
        dim (int, optional): Dimension to take median of. Defaults to 0.

    Returns:
        Tensor: Same shape as x, except dimension dim reduced.
    """
    # uncomment this assert for safety but might impact performance
    # assert (
    #     mask.sum(dim=dim).ne(0).all()
    # ), "mask should not be all False in any column, causes zero division"
    x_nan = x.float().masked_fill(~mask, float("nan"))
    x_median, _ = x_nan.nanmedian(dim=dim)
    return x_median


X = torch.tensor([[1, 1, 1], [2, 2, 4]])
Y = torch.tensor([[0, 1, 0], [1, 0, 1]]).bool()

masked_median(X, Y, dim=1)
>>> tensor([1., 2.])

Signature with type hints:

def masked_median(x: torch.Tensor, mask: torch.BoolTensor, dim: int = 0) -> torch.Tensor:
    ...

Btw, same approach works for masked_var():

def masked_var(x, mask, dim=0):
    x_nan = x.float().masked_fill(~mask, float("nan"))

    nan_mean = x_nan.nanmean(dim=dim, keepdim=True)
    squared_diff = (x_nan - nan_mean) ** 2
    var = squared_diff.nanmean(dim=dim)
    return var

masked_var(X, Y, dim=1)
>>> tensor([0., 1.])
Mohsin hasan (0 votes)

This is the best I have so far:

outs = []
for x, y in zip(X, Y):  # X, Y could be permuted to loop over the desired axis
    out = torch.median(torch.masked_select(x, y))
    outs.append(out)
torch.stack(outs)

Would really appreciate it if someone has a better solution.

Mughees (1 vote)

Max and median:

As one of the tensors is boolean, you can do elementwise multiplication of the original and the mask and then compute the max/median, like this. Be aware that the masked-out entries become zeros that are still counted: the max is only correct if all values are non-negative, and the extra zeros can shift the median.

array = torch.randint(10, (4, 4))
mask = torch.randint(2, (4, 4))  # generates only 0/1 values
sol_max = torch.max(array * mask)
sol_median = torch.median(array * mask)
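On the question's own example, the zeros introduced by the multiplication change the result: the per-row medians come out as [0., 2.] rather than the expected [1., 2.].

```python
import torch

X = torch.tensor([[1., 1., 1.],
                  [2., 2., 4.]])
Y = torch.tensor([[0, 1, 0],
                  [1, 0, 1]])

# Masked-out entries become 0 and are still counted by median:
# row 0 is [0., 1., 0.] -> median 0., row 1 is [2., 0., 4.] -> median 2.
print(torch.median(X * Y, dim=1).values)  # tensor([0., 2.])
```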