I'm trying to use the JAX library with Haiku on Python 3.6 in a conda environment, and I've run into the error below and am stuck. I have tried updating my JAX version, but nothing changed. How can I fix it?
Traceback (most recent call last):
  File "train.py", line 14, in <module>
    import haiku as hk
  File "/home/young/.local/lib/python3.6/site-packages/haiku/__init__.py", line 19, in <module>
    from haiku import data_structures
  File "/home/young/.local/lib/python3.6/site-packages/haiku/data_structures.py", line 18, in <module>
    from haiku._src.data_structures import to_haiku_dict
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 176, in <module>
    class FlatComponents(NamedTuple):
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 178, in FlatComponents
    structure: jax.tree_util.PyTreeDef
AttributeError: module 'jax.tree_util' has no attribute 'PyTreeDef'
jax.tree_util.PyTreeDef didn't exist prior to JAX version 0.2.22, which was released in October 2021. If you're getting this error, you probably need to update your JAX installation to a newer version.

That said, you mentioned Python 3.6: JAX version 0.2.18 and newer requires Python 3.7 or later. So if you must use Python 3.6, you cannot use JAX 0.2.22; you'll have to stay on JAX 0.2.17 or earlier and install an older version of Haiku that is compatible with it. Haiku 0.0.4 looks like the last release that was compatible with Python 3.6, and it dates from roughly the same era as those JAX releases.
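If you want to confirm which situation you're in before reinstalling anything, a small check along these lines may help. This is only a sketch, not part of JAX or Haiku: the helper names parse_version and diagnose are hypothetical, and the version thresholds are simply the ones stated above (PyTreeDef from JAX 0.2.22, Python 3.7 required from JAX 0.2.18). It also prints jax.__file__, which shows which installation the interpreter actually imports; that is worth checking if an upgrade seemed to change nothing.

```python
import re
import sys
from typing import Optional, Sequence, Tuple

# Thresholds taken from the answer text (assumptions, not queried from JAX itself).
PYTREEDEF_MIN_JAX = (0, 2, 22)   # first JAX release said to have jax.tree_util.PyTreeDef
MIN_PYTHON_FOR_NEW_JAX = (3, 7)  # JAX 0.2.18 and newer require Python >= 3.7


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a version string like '0.2.17' or '0.4.13.dev2023' into (0, 2, 17)."""
    parts = []
    for piece in version.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first non-numeric component
        parts.append(int(match.group()))
    return tuple(parts)


def diagnose(python_version: Sequence[int], jax_version: Optional[str]) -> str:
    """Hypothetical helper: explain what this version combination means for `import haiku`."""
    if tuple(python_version[:2]) < MIN_PYTHON_FOR_NEW_JAX:
        return ("python<3.7: JAX 0.2.22+ is not available; stay on JAX <= 0.2.17 "
                "with an older Haiku such as 0.0.4, or upgrade Python")
    if jax_version is None:
        return "jax is not importable in this interpreter"
    if parse_version(jax_version) < PYTREEDEF_MIN_JAX:
        return "upgrade JAX to 0.2.22 or newer"
    return "ok: this JAX should provide jax.tree_util.PyTreeDef"


if __name__ == "__main__":
    try:
        import jax
        import jax.tree_util
        jax_version = jax.__version__
        # Which installation is actually being imported (e.g. ~/.local vs. the conda env).
        print("jax", jax_version, "from", jax.__file__)
        print("has PyTreeDef:", hasattr(jax.tree_util, "PyTreeDef"))
    except ImportError:
        jax_version = None
    print("python", sys.version.split()[0])
    print(diagnose(sys.version_info, jax_version))
```

Run it with the same interpreter you use for train.py, for example as python check_env.py inside the activated conda environment (the filename is just an example).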
Overall, though, you'll have a much better experience if you can update your Python installation: Python 3.6 reached its end of life in December 2021, and most packages no longer release versions that support it.