Import error in JAX

I am getting an import error when running vit_jax.ipynb on Colab. The notebook was working fine two days ago. monitoring cannot be imported from jax in checkpoint.py and train.py. The JAX version I'm using is 0.3.25. Please let me know how to fix this.
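One quick way to confirm this symptom from inside the notebook is to print the installed versions and check whether `jax.monitoring` can be imported at all. This is a minimal diagnostic sketch using only the standard library (`importlib`, `importlib.metadata`). The helper names are hypothetical, and the snippet only reports the environment; it does not fix anything.

```python
import importlib
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


def has_jax_monitoring() -> bool:
    """True if `from jax import monitoring` would succeed in this environment."""
    try:
        importlib.import_module("jax.monitoring")
    except ImportError:
        # Covers both "jax is not installed" and "this jax has no monitoring module".
        return False
    return True


if __name__ == "__main__":
    # Print the versions relevant to this thread, then the import check itself.
    for name in ("jax", "jaxlib", "flax"):
        print(f"{name}: {installed_version(name)}")
    print("jax.monitoring importable:", has_jax_monitoring())
```

If the import check is False while the installed Flax is a recent release, that matches the mismatch described in this thread.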


  1. Same here. It's ironic that both are developed internally, yet there are still inconsistencies between the teams. The code was working before the JAX 0.4.2 release.

  2. A quick hotfix that works for me is to use an older version of Flax, one that doesn't rely on the newer JAX imports.
    Since vision-transformer doesn't pin a Flax version in its setup file, the latest Flax gets installed, which causes the error.
    Run the cell below before installing vision-transformer:

    [screenshot of the notebook cell]

  3. This is now fixed by the new Flax release and #256.

    Sorry for the delay on this issue
    (I saw it a bit late, and then it took some time to get Flax updated & make the updated dependencies work in Colab)
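The hotfix in comment 2 amounts to installing an older Flax before vision-transformer, so that pip treats the unpinned flax requirement as already satisfied instead of fetching the latest release. The sketch below only builds the pip command for the running interpreter. The exact Flax version from the original screenshot is not recoverable, so `OLDER_FLAX_SPEC` is a hypothetical placeholder you would have to fill in yourself. In Colab, a `!pip install ...` cell does the same thing.

```python
import shlex
import subprocess
import sys
from typing import List

# Hypothetical placeholder: the Flax version used in the original hotfix is unknown.
# Replace it with a release that predates the new jax.monitoring import.
OLDER_FLAX_SPEC = "flax==<older-version>"


def pip_install_command(*requirements: str) -> List[str]:
    """Build a pip command that installs into the currently running interpreter."""
    return [sys.executable, "-m", "pip", "install", *requirements]


def install(*requirements: str) -> None:
    """Run pip; raises subprocess.CalledProcessError if the install fails."""
    subprocess.run(pip_install_command(*requirements), check=True)


if __name__ == "__main__":
    # Pin Flax first; a later install of vision-transformer then sees an
    # already-installed flax and, by default, does not upgrade it.
    print(shlex.join(pip_install_command(OLDER_FLAX_SPEC)))
    # install(OLDER_FLAX_SPEC)  # uncomment after filling in a real version
```

Once the fixed Flax release mentioned in comment 3 is available, the pin is no longer needed.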