Complex Vector Multiplication in Pallas (JAX)

I am having trouble with a simple complex vector multiplication kernel in Pallas.

from functools import partial

import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np


def cdot_vectors_kernel(x_real_ref, x_imag_ref, y_real_ref, y_imag_ref, o_ref):
    x_real = pl.load(x_real_ref, (slice(None),))
    x_imag = pl.load(x_imag_ref, (slice(None),))
    y_real = pl.load(y_real_ref, (slice(None),))
    y_imag = pl.load(y_imag_ref, (slice(None),))
    
    # (a + bi) * (c + di) = (ac - bd) + (ad + bc)i
    o_real = x_real * y_real - x_imag * y_imag
    o_imag = x_real * y_imag + x_imag * y_real

    o = jnp.array(o_real + 1j * o_imag)   
    pl.store(o_ref, (slice(None), ), o)


@jax.jit
def cdot_vectors(x_real: jax.Array, x_imag: jax.Array, y_real: jax.Array, y_imag: jax.Array) -> jax.Array:
    return pl.pallas_call(
        cdot_vectors_kernel,
        out_shape=jax.ShapeDtypeStruct(x_real.shape, jnp.complex64)
    )(x_real, x_imag, y_real, y_imag)

array1 = jnp.array([2+3j, 1-1j, 5+1j])
array2 = jnp.array([1+2j, 1-2j, 2+1j])

cdot_vectors(array1.real, array1.imag, array2.real, array2.imag)

The error I am getting is quite long, but this is the part that stood out to me:

NotImplementedError: cannot cast Value(%33 = "arith.addf"(%31, %32) <{fastmath = #arith.fastmath<none>}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>) to tensor<3xcomplex<f32>>

I think it would work if I could have an o_real_ref and an o_imag_ref as outputs, but I have not seen any example of a Pallas kernel with two outputs. I know how to do this in Triton with a separate function, but here I am not sure how.
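
In case it helps explain what I mean, here is roughly what I had in mind. This is only a sketch: I am assuming that out_shape can take a tuple of jax.ShapeDtypeStructs so that the kernel receives one extra output ref per entry, and I have not been able to confirm this is the intended way.

def cdot_vectors_split_kernel(x_real_ref, x_imag_ref, y_real_ref, y_imag_ref,
                              o_real_ref, o_imag_ref):
    x_real = pl.load(x_real_ref, (slice(None),))
    x_imag = pl.load(x_imag_ref, (slice(None),))
    y_real = pl.load(y_real_ref, (slice(None),))
    y_imag = pl.load(y_imag_ref, (slice(None),))

    # Store the real and imaginary parts into separate float outputs,
    # so no complex dtype ever appears inside the kernel.
    pl.store(o_real_ref, (slice(None),), x_real * y_real - x_imag * y_imag)
    pl.store(o_imag_ref, (slice(None),), x_real * y_imag + x_imag * y_real)


@jax.jit
def cdot_vectors_split(x_real, x_imag, y_real, y_imag):
    o_real, o_imag = pl.pallas_call(
        cdot_vectors_split_kernel,
        out_shape=(
            jax.ShapeDtypeStruct(x_real.shape, x_real.dtype),
            jax.ShapeDtypeStruct(x_real.shape, x_real.dtype),
        ),
    )(x_real, x_imag, y_real, y_imag)
    # Recombine into a complex array outside the kernel.
    return jax.lax.complex(o_real, o_imag)

Is that the right way to get two outputs out of pallas_call, or is there a supported way to store directly into a complex output ref?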

Any ideas?
