My implementation of sparsemax in PyTorch Geometric (PyG) has CUDA memory problems and is much slower than the softmax implementation in PyG.
This is my code:
from typing import Optional

import torch
from torch import Tensor
import torch.nn.functional as F
from torch_geometric.typing import OptTensor


def sparsemax(a: Tensor) -> Tensor:
    # Sort descending along dim 0, then find the support size k and the
    # threshold tau of the sparsemax projection.
    zs = torch.sort(a, descending=True, dim=0).values
    size = zs.size(0)
    indices = torch.arange(start=1, end=size + 1, step=1,
                           dtype=torch.long, device=a.device).reshape(size, 1)
    bound = 1 + indices * zs
    cum_sum_zs = torch.cumsum(zs, dim=0)
    is_ge = torch.ge(bound, cum_sum_zs)
    k = torch.max(is_ge * indices)
    tau = (cum_sum_zs[k - 1] - 1) / k
    return torch.relu(a - tau)


def sparsemax_pyg(src, index, ptr, size_i) -> Tensor:
    # Apply sparsemax separately to each group of rows that shares an index.
    # ptr and size_i are unused; they only match PyG's softmax-like signature.
    unique_indices = torch.unique(index)
    # print("index was used")
    result = torch.zeros_like(src)
    for i in unique_indices:
        mask = index == i
        result[mask] = sparsemax(src[mask])
    return result
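To illustrate the slowdown: the Python loop in sparsemax_pyg dispatches a separate masked kernel launch per unique index, so runtime grows with the number of groups, and each boolean mask plus src[mask] allocates fresh tensors on every iteration, which I suspect is behind the memory problems. A minimal timing sketch (the sizes here are made up; exact numbers will vary by GPU):

import time

import torch

src = torch.randn(100_000, 1, device='cuda')
index = torch.randint(0, 10_000, (100_000,), device='cuda')

torch.cuda.synchronize()
t0 = time.time()
out = sparsemax_pyg(src, index, None, None)
torch.cuda.synchronize()
print(f'sparsemax_pyg: {time.time() - t0:.3f}s for 10,000 groups')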
This is the PyG softmax implementation:
import torch
from torch_scatter import scatter_max, scatter_add


def softmax(src, index, num_nodes=None):
    num_nodes = index.max().item() + 1 if num_nodes is None else num_nodes
    out = src - scatter_max(src, index, dim=0, dim_size=num_nodes)[0][index]
    out = out.exp()
    out = out / (
        scatter_add(out, index, dim=0, dim_size=num_nodes)[index] + 1e-16)
    return out
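For what it's worth, the speed difference seems to come from scatter_max and scatter_add reducing over all groups in a single kernel each, with no Python loop. A tiny illustration of the scatter semantics, using the same toy values as the example below:

import torch
from torch_scatter import scatter_max, scatter_add

src = torch.as_tensor([[1.], [2.], [3.], [4.], [5.], [6.]])
index = torch.as_tensor([0, 2, 1, 1, 0, 2])

print(scatter_max(src, index, dim=0)[0])  # per-group maxima: [[5.], [4.], [6.]]
print(scatter_add(src, index, dim=0))     # per-group sums:   [[6.], [7.], [8.]]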
Example input:
src = torch.as_tensor([[1],[2],[3],[4],[5],[6]], dtype=torch.float64)
index = torch.as_tensor([0,2,1,1,0,2])
I want sparsemax (instead of softmax) to be applied to tensor([[1],[5]]) (index 0), tensor([[3],[4]]) (index 1), and tensor([[2],[6]]) (index 2), and then to return:
tensor([[0.],
[0.],
[0.],
[1.],
[1.],
[1.]])
instead of
tensor([[0.0180],
[0.0180],
[0.2689],
[0.7311],
[0.9820],
[0.9820]])
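To spell out one group: for index 0 the inputs are [1, 5]. Sorted descending, z = [5, 1]; the support condition 1 + k * z_k >= z_1 + ... + z_k holds for k = 1 (since 1 + 5 >= 5) but fails for k = 2 (since 1 + 2 < 6), so k = 1, tau = (5 - 1) / 1 = 4, and relu([1, 5] - 4) = [0, 1]. The other two groups work out the same way.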
The algorithm computes the correct values, but for my task it makes training much slower, and I have to reduce the batch size because it runs into CUDA out-of-memory errors.
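Is the right fix to vectorize the whole thing with scatter ops, the way the softmax above does? Below is my sketch of a loop-free version that mimics that pattern. I have only checked it against the toy example above; the name sparsemax_scatter is mine, and it assumes src has shape (E, 1) like in the example:

import torch
from torch_scatter import scatter_max

def sparsemax_scatter(src, index, num_nodes=None):
    # Sketch of a loop-free grouped sparsemax; assumes src has shape (E, 1).
    num_nodes = int(index.max()) + 1 if num_nodes is None else num_nodes
    x = src.squeeze(-1)
    # Sort descending by value, then stable-sort by group id, so that values
    # end up in descending order within each contiguous group.
    perm_val = torch.argsort(x, descending=True)
    perm = perm_val[torch.sort(index[perm_val], stable=True).indices]
    zs, grp = x[perm], index[perm]
    # Start offset of each group and 1-based rank of each element in its group.
    count = torch.bincount(grp, minlength=num_nodes)
    start = torch.cumsum(count, dim=0) - count
    rank = (torch.arange(zs.numel(), device=x.device) - start[grp] + 1).to(x.dtype)
    # Cumulative sum of zs restricted to each group (global cumsum minus the
    # cumsum just before the group start).
    gsum = torch.cumsum(zs, dim=0)
    seg_csum = gsum - gsum[start[grp]] + zs[start[grp]]
    # Support size k per group: largest rank with 1 + rank * z >= running sum,
    # matching the torch.ge test in my per-group sparsemax above.
    supported = torch.ge(1 + rank * zs, seg_csum).to(x.dtype)
    k = scatter_max(rank * supported, grp, dim=0, dim_size=num_nodes)[0].clamp(min=1)
    # Threshold tau per group, then the projection itself.
    pos = (start + k.long() - 1).clamp(max=zs.numel() - 1)
    tau = (seg_csum[pos] - 1) / k
    return torch.relu(x - tau[index]).unsqueeze(-1)

On the example input this returns the desired [[0.], [0.], [0.], [1.], [1.], [1.]], but I don't know if it is the right approach for speed and memory, or whether there is a standard grouped sparsemax for PyG that I am missing.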