I am having troubles with a simple complex vector multiplication code.
from functools import partial
import jax
from jax.experimental import pallas as pl
import jax.numpy as jnp
import numpy as np
def cdot_vectors_kernel(x_real_ref, x_imag_ref, y_real_ref, y_imag_ref, o_ref):
x_real = pl.load(x_real_ref, (slice(None),))
x_imag = pl.load(x_imag_ref, (slice(None),))
y_real = pl.load(y_real_ref, (slice(None),))
y_imag = pl.load(y_imag_ref, (slice(None),))
o_real = x_real * y_real - x_imag * y_imag
o_imag = x_real * y_imag + x_imag * y_real
o = jnp.array(o_real + 1j * o_imag)
pl.store(o_ref, (slice(None), ), o)
@jax.jit
def cdot_vectors(x_real: jax.Array, x_imag: jax.Array, y_real: jax.Array, y_imag: jax.Array) -> jax.Array:
return pl.pallas_call(
cdot_vectors_kernel,
out_shape=jax.ShapeDtypeStruct(x_real.shape, jnp.complex64)
)(x_real, x_imag, y_real, y_imag)
array1 = jnp.array([2+3j, 1-1j, 5+1j])
array2 = jnp.array([1+2j, 1-2j, 2+1j])
cdot_vectors(array1.real, array1.imag, array2.real, array2.imag)
The error I am getting is quite long but this stuck to me:
NotImplementedError: cannot cast Value(%33 = "arith.addf"(%31, %32) <{fastmath = #arith.fastmath<none>}> : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>) to tensor<3xcomplex<f32>>
I think if I were able to have a o_real_ref and o_imag_ref as outputs, that would work but I have seen nothing with a Pallas Kernel that has two outputs. I know how to do this in Triton with a separate function but here I'm not sure.
Any ideas?