It seems the code relies on jax.experimental.loops, which is no longer available in the newer JAX version installed here.
Python 3.10.4 (main, Mar 31 2022, 08:41:55) [GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import jax
>>> jax.__version__
'0.3.25'
>>> import quax
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/root/psi4conda/lib/python3.10/site-packages/quax-0.1.1-py3.10.egg/quax/__init__.py", line 1, in <module>
  File "/root/psi4conda/lib/python3.10/site-packages/quax-0.1.1-py3.10.egg/quax/integrals/__init__.py", line 2, in <module>
  File "/root/psi4conda/lib/python3.10/site-packages/quax-0.1.1-py3.10.egg/quax/integrals/integrals_utils.py", line 6, in <module>
ImportError: cannot import name 'loops' from 'jax.experimental' (/root/psi4conda/lib/python3.10/site-packages/jax/experimental/__init__.py)
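To confirm whether an installed JAX still ships the module that quax imports, one can check for it directly instead of importing quax. The following is a minimal sketch using only the standard library; the helper names installed_version and module_available are hypothetical and are not part of quax or JAX.

```python
import importlib.metadata
import importlib.util
from typing import Optional


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return importlib.metadata.version(dist)
    except importlib.metadata.PackageNotFoundError:
        return None


def module_available(name: str) -> bool:
    """Return True if the (possibly dotted) module name can be located."""
    try:
        # For dotted names, find_spec imports the parent package first.
        return importlib.util.find_spec(name) is not None
    except ModuleNotFoundError:
        # Raised when a parent package (e.g. jax itself) is not installed.
        return False


if __name__ == "__main__":
    print("jax:", installed_version("jax"))
    print("jaxlib:", installed_version("jaxlib"))
    print("jax.experimental.loops available:", module_available("jax.experimental.loops"))
```

Run in the same environment as the failing import above, this would be expected to report that jax.experimental.loops is not available, matching the ImportError.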
Question 1: Which version of JAX did you use during development?
The environment.yml file only lists jax and jaxlib; it does not pin specific versions.