I'm trying to run a Colab notebook for image generation with JAX and ran into the following error:
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-7-73b0723cc3af> in <cell line: 23>()
21 import jax.numpy as jnp
22 import jax.scipy as jsp
---> 23 import jaxtorch
24 from jaxtorch import PRNG, Context, Module, nn, init
25 from tqdm import tqdm
3 frames
/content/./jax-guided-diffusion/jaxtorch/monkeypatches.py in register(**kwargs)
16 print(f'Not monkeypatching DeviceArray and Tracer with `{attr}`, because that method is already implemented.', file=sys.stderr)
17 continue
---> 18 setattr(jaxlib.xla_extension.DeviceArrayBase, attr, fun)
19 setattr(jax.interpreters.xla.DeviceArray, attr, fun)
20 setattr(jax.core.Tracer, attr, fun)
AttributeError: module 'jaxlib.xla_extension' has no attribute 'DeviceArrayBase'
I tried to solve this by switching between different JAX versions and trying every GPU runtime Colab offers, but I couldn't find a solution. I'd really appreciate any help with this!
Link to the notebook: click
DeviceArray and related types were deprecated and removed in JAX v0.4.1 (see the JAX Changelog). It looks like the version of jaxtorch you are using is not compatible with more recent JAX versions. If there is no newer version of jaxtorch available, I would suggest trying to use it with JAX version 0.3.25 or older.
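Before pinning versions, it can help to confirm which JAX release the Colab runtime actually has installed. The sketch below is only an illustration and is not part of jaxtorch or JAX: the helper names parse_version and jaxtorch_compatible are hypothetical, and the 0.3.25 cutoff simply encodes the suggestion above. It reads the installed version with the standard library's importlib.metadata and compares the leading numeric components.

```python
"""Hypothetical helper: is the installed JAX old enough for this jaxtorch?

Assumption: DeviceArray was removed in JAX 0.4.1, so the old jaxtorch
monkeypatches only work with JAX <= 0.3.25.
"""
import re
from importlib import metadata
from typing import Tuple

# Newest JAX release suggested for this jaxtorch version (assumed cutoff).
LAST_COMPATIBLE = (0, 3, 25)


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '0.4.23' or '0.4.1rc1' into (0, 4, 23) / (0, 4, 1)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)  # keep only the leading digits
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def jaxtorch_compatible(jax_version: str) -> bool:
    """True if this JAX version should still ship DeviceArray (<= 0.3.25)."""
    return parse_version(jax_version) <= LAST_COMPATIBLE


if __name__ == "__main__":
    try:
        installed = metadata.version("jax")
    except metadata.PackageNotFoundError:
        print("jax is not installed")
    else:
        status = "OK" if jaxtorch_compatible(installed) else "too new"
        print(f"jax {installed}: {status} for this jaxtorch")
```

If it reports "too new", the matching jaxlib also has to be downgraded, and which jaxlib wheel works depends on the CUDA version of the Colab image, so that part may need some trial and error.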