I get the following error with flax 0.7.5. Could you help me?
File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\__init__.py:5
from jVMC.nets.rnn import *
File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\rnn.py:92
class RNN(nn.Module):
File ~\AppData\Roaming\Python\Python311\site-packages\jVMC\nets\rnn.py:138 in RNN
@partial(nn.transforms.scan,
AttributeError: module 'flax.linen' has no attribute 'transforms'
Many thanks!
I also tried flax 0.7.4, after first solving an earlier AttributeError: module 'flax' has no attribute 'nn'.
From your traceback, it looks like you are using vmc_jax (imported as jVMC). Looking at this package, it appears to require flax v0.6.4 through v0.6.11, so I suspect the error you're seeing is because your flax version is too new for this package.
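To confirm whether the installed flax falls inside that range before importing jVMC, a small check like the one below may help. It is only a sketch. The bounds 0.6.4 and 0.6.11 (treated as inclusive) are taken from the range quoted above, not read from jVMC's own metadata. The helper names parse_release, flax_version_supported and check_installed_flax are hypothetical. Pre-release suffixes such as "rc1" are simply ignored.

```python
from importlib.metadata import PackageNotFoundError, version

# Assumed supported range for vmc_jax/jVMC, inclusive on both ends.
# Check the package's own requirements for the authoritative bounds.
MIN_FLAX = (0, 6, 4)
MAX_FLAX = (0, 6, 11)


def parse_release(version_str):
    """Turn '0.7.5' into (0, 7, 5); trailing suffixes like 'rc1' or '.dev0' are dropped."""
    parts = []
    for piece in version_str.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break  # stop at the first non-digit, e.g. the 'rc' in '11rc1'
            digits += ch
        parts.append(int(digits) if digits else 0)
    while len(parts) < 3:
        parts.append(0)  # treat '0.6' as '0.6.0'
    return tuple(parts)


def flax_version_supported(version_str):
    """Return True if the version lies within the assumed [MIN_FLAX, MAX_FLAX] range."""
    return MIN_FLAX <= parse_release(version_str) <= MAX_FLAX


def check_installed_flax():
    """Describe the installed flax version and whether it is in the assumed range."""
    try:
        installed = version("flax")
    except PackageNotFoundError:
        return "flax is not installed"
    status = "OK" if flax_version_supported(installed) else "outside the assumed range"
    return f"flax {installed}: {status}"


if __name__ == "__main__":
    print(check_installed_flax())
```

On the setup described in the question, both 0.7.5 and 0.7.4 would be reported as outside the range, which is consistent with the AttributeError in the traceback.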
I would suggest installing flax 0.6.11 and also using the jax versions the package requires, as listed there: