Is there an efficient way of implementing sparsemax in pytorch-geometric?

My implementation of sparsemax in pytorch-geometric is running into CUDA memory problems and is too slow compared to the softmax implementation in PyG.

This is my code:

from typing import Optional
import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.typing import OptTensor


def sparsemax(a: Tensor) -> Tensor:
    # Sparsemax of a single group: project `a` onto the probability simplex
    # (Martins & Astudillo, 2016), so small entries become exactly 0.
    zs = torch.sort(a, descending=True, dim=0).values
    size = zs.size(0)
    # 1-based ranks of the sorted values, shaped like `zs`.
    indices = torch.arange(start=1,
                           end=size + 1,
                           step=1,
                           dtype=torch.long,
                           device=a.device).reshape(size, 1)
    # Support condition: 1 + k * z_(k) >= cumsum_k.
    bound = 1 + indices * zs
    cum_sum_zs = torch.cumsum(zs, dim=0)
    is_ge = torch.ge(bound, cum_sum_zs)
    # k = size of the support (largest rank that satisfies the condition).
    k = torch.max(is_ge * indices)
    # Threshold tau makes the surviving entries sum to 1.
    tau = (cum_sum_zs[k - 1] - 1) / k
    return torch.relu(a - tau)
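
# Quick check on a single group (the index-0 pair from the example input
# further down): sparsemax(torch.as_tensor([[1.], [5.]])) gives
# tensor([[0.], [1.]]), i.e. all the mass goes to the largest entry.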


def sparsemax_pyg(src, index, ptr, size_i) -> Tensor:
    # Apply sparsemax independently to every group of rows in `src` that
    # share the same value in `index`. `ptr` and `size_i` are unused; they
    # only mirror the signature of PyG's softmax.
    unique_indices = torch.unique(index)
    # print("index was used")
    result = torch.zeros_like(src)

    # One Python-level iteration and one boolean mask per group.
    for i in unique_indices:
        mask = index == i
        result[mask] = sparsemax(src[mask])
    return result
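
I keep the ptr and size_i arguments only so the function has the same signature as PyG's softmax and can be dropped in where an attention layer would normally call it, roughly like this (the surrounding layer is hypothetical, only the replaced call matters):

# Hypothetical message() of a GAT-style MessagePassing layer; the point is
# only that sparsemax_pyg replaces the usual softmax(alpha, index, ptr, size_i).
def message(self, x_j: Tensor, alpha: Tensor, index: Tensor,
            ptr: OptTensor, size_i: Optional[int]) -> Tensor:
    alpha = sparsemax_pyg(alpha, index, ptr, size_i)
    return x_j * alpha  # alpha has shape [E, 1] and broadcasts over features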

This is PyG's softmax implementation:

import torch
from torch_scatter import scatter_max, scatter_add

def softmax(src, index, num_nodes=None):
    num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)
    return out

Example input:

src = torch.as_tensor([[1],[2],[3],[4],[5],[6]], dtype=torch.float64)
index = torch.as_tensor([0,2,1,1,0,2])

I want sparsemax (instead of softmax) to be applied to tensor([[1],[5]]) (index 0), tensor([[3],[4]]) (index 1) and tensor([[2],[6]]) (index 2), and then to return:

tensor([[0.],
        [0.],
        [0.],
        [1.],
        [1.],
        [1.]])

instead of

tensor([[0.0180],
        [0.0180],
        [0.2689],
        [0.7311],
        [0.9820],
        [0.9820]])

The algorithm computes the correct values, but for my task it makes training much slower, and I have to reduce the batch size because it runs into CUDA out-of-memory errors.
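
What I would like is something fully vectorized like the scatter-based softmax above. A rough sketch of the direction I am considering: sort within each group, take group-wise cumulative sums, and get the per-group threshold with scatter_add. The function name and the offset bookkeeping below are my own and untested beyond the toy example, and I assume every id up to num_nodes - 1 appears in index (which holds when num_nodes is left as None):

import torch
from torch_scatter import scatter_add


def sparsemax_scatter(src, index, num_nodes=None):
    # Loop-free grouped sparsemax sketch; assumes one score per edge,
    # i.e. src has shape [E] or [E, 1].
    num_nodes = int(index.max()) + 1 if num_nodes is None else num_nodes
    shape = src.shape
    src = src.view(-1)

    # Sort descending globally, then stable-sort by group index so that
    # inside each group the values stay in descending order.
    src_sorted, perm = torch.sort(src, descending=True)
    index_sorted, perm2 = torch.sort(index[perm], stable=True)
    src_sorted = src_sorted[perm2]

    # Per-group sizes and the position where each group starts.
    count = scatter_add(torch.ones_like(index_sorted), index_sorted,
                        dim=0, dim_size=num_nodes)
    start = torch.cumsum(count, dim=0) - count

    # 1-based rank of every element inside its own group.
    rank = torch.arange(src.numel(), device=src.device) - start[index_sorted] + 1

    # Group-wise cumulative sum of the sorted values.
    cumsum = torch.cumsum(src_sorted, dim=0)
    offset = cumsum[start] - src_sorted[start]   # cumulative sum just before each group
    cumsum_group = cumsum - offset[index_sorted]

    # Support size k per group: how many ranks satisfy 1 + k * z_(k) > cumsum_k.
    support = 1.0 + rank.to(src.dtype) * src_sorted > cumsum_group
    k = scatter_add(support.long(), index_sorted, dim=0, dim_size=num_nodes)
    k = k.clamp(min=1)

    # Per-group threshold: tau = (sum of the k largest values - 1) / k.
    tau = (cumsum_group[start + k - 1] - 1.0) / k.to(src.dtype)

    return torch.relu(src - tau[index]).view(shape)

On the example input above this should reproduce the expected tensor([[0.], [0.], [0.], [1.], [1.], [1.]]), but I have not profiled it on real batches, so I do not know whether it actually solves the speed and memory problem.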
