AttributeError: module 'jax.tree_util' has no attribute 'PyTreeDef' when importing haiku


I'm trying to use the JAX library with Haiku on Python 3.6 in a conda environment, and I've hit the error below and am stuck. I have tried updating my JAX version, but nothing changed. How can I fix it?

Traceback (most recent call last):
  File "train.py", line 14, in <module>
    import haiku as hk
  File "/home/young/.local/lib/python3.6/site-packages/haiku/__init__.py", line 19, in <module>
    from haiku import data_structures
  File "/home/young/.local/lib/python3.6/site-packages/haiku/data_structures.py", line 18, in <module>
    from haiku._src.data_structures import to_haiku_dict
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 176, in <module>
    class FlatComponents(NamedTuple):
  File "/home/young/.local/lib/python3.6/site-packages/haiku/_src/data_structures.py", line 178, in FlatComponents
    structure: jax.tree_util.PyTreeDef
AttributeError: module 'jax.tree_util' has no attribute 'PyTreeDef'


Answer by jakevdp

jax.tree_util.PyTreeDef didn't exist prior to JAX version 0.2.22, which was released in October 2021. If you're getting this error, then you probably need to update your JAX installation to a newer version.
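As a quick diagnostic, here is a minimal sketch (not part of the original answer) that reports which JAX the interpreter actually imports, where it is loaded from, and whether `PyTreeDef` is present. The traceback shows Haiku being loaded from `~/.local/lib/python3.6/site-packages`, so printing `jax.__file__` can help confirm which installation is really in use. The helper names `parse_version` and `check_jax` are hypothetical, the version parsing is a simplification rather than a full PEP 440 parser, and the code sticks to syntax that runs on Python 3.6.

```python
import importlib.util

# Per the answer above, jax.tree_util.PyTreeDef first appeared in JAX 0.2.22.
MIN_JAX_WITH_PYTREEDEF = (0, 2, 22)


def parse_version(text):
    """Turn a version string such as '0.2.22' into (0, 2, 22).

    Only the leading digits of the first three components are kept, so
    '0.4.0rc1' becomes (0, 4, 0). Missing components are padded with 0.
    """
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def check_jax():
    """Return a one-line report on the JAX visible to this interpreter."""
    if importlib.util.find_spec("jax") is None:
        return "JAX is not installed in this interpreter"
    try:
        import jax
        import jax.tree_util
    except Exception as exc:  # report any import failure, e.g. missing jaxlib
        return "JAX is installed but failed to import: {}".format(exc)
    has_treedef = hasattr(jax.tree_util, "PyTreeDef")
    new_enough = parse_version(jax.__version__) >= MIN_JAX_WITH_PYTREEDEF
    return "jax {} from {} | PyTreeDef present: {} | >= 0.2.22: {}".format(
        jax.__version__, jax.__file__, has_treedef, new_enough)


if __name__ == "__main__":
    print(check_jax())
```

Running it with the same `python` that runs `train.py` matters: if the reported path points at `~/.local` rather than the conda environment, upgrading JAX inside the environment may not change what gets imported.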

That said, you mentioned Python 3.6: JAX version 0.2.18 and newer require Python 3.7 or later. So if you must stay on Python 3.6, you cannot use JAX 0.2.22; you'll have to keep JAX below 0.2.18 and install an older version of Haiku that is compatible with it. Haiku 0.0.4 looks like the last release that was compatible with Python 3.6, and it dates from roughly the same era as JAX 0.2.18.
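The decision above can be sketched as a small helper that picks illustrative pip requirements for the running interpreter. `suggested_pins` and `pip_command` are hypothetical names, and the pins are assumptions derived from the version numbers in this answer rather than a tested combination; `dm-haiku` is the name Haiku is published under on PyPI.

```python
import shlex
import sys


def suggested_pins(version_info=None):
    """Return illustrative pip requirement strings for a Python version.

    Cut-offs follow the answer: PyTreeDef needs JAX >= 0.2.22, and
    JAX 0.2.18+ needs Python >= 3.7. A matching jaxlib is also needed at
    runtime; it is not pinned here.
    """
    if version_info is None:
        version_info = sys.version_info
    if tuple(version_info[:2]) >= (3, 7):
        # New enough for a JAX that has PyTreeDef; let pip choose a Haiku.
        return ["jax>=0.2.22", "dm-haiku"]
    # Python 3.6: stay below JAX 0.2.18 and use the last Haiku release the
    # answer identifies as Python 3.6 compatible.
    return ["jax<0.2.18", "dm-haiku==0.0.4"]


def pip_command(pins):
    """Build a shell-safe pip command; '<' and '>' must be quoted for the shell."""
    return "python -m pip install " + " ".join(shlex.quote(p) for p in pins)


if __name__ == "__main__":
    print(pip_command(suggested_pins()))
```

On Python 3.6 this prints `python -m pip install 'jax<0.2.18' dm-haiku==0.0.4`; the quoting keeps the shell from treating `<` as a redirection.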

Overall, though, you'll have a much better experience if you can update your Python installation; Python 3.6 reached its end of life in December 2021, and many packages no longer support it.