I am getting an import error while running `vit_jax.ipynb` on Colab, which was working fine two days ago: `monitoring` cannot be imported from `jax` in `checkpoint.py` and `train.py`. The jax version I'm using is 0.3.25. Please let me know how to fix this.
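To check whether the installed jax provides the `monitoring` module that the failing import expects, a quick check along these lines can be run in a Colab cell. This is a minimal sketch. It assumes `monitoring` is shipped as a `jax.monitoring` submodule in the jax releases that have it. The helper name `jax_has_monitoring` is hypothetical.

```python
import importlib.util


def jax_has_monitoring() -> bool:
    """Return True if the installed jax provides a `jax.monitoring` submodule.

    Hypothetical helper. Assumes `monitoring` is a submodule of `jax`
    in releases that include it.
    """
    try:
        # find_spec imports the parent package `jax` before looking up the submodule.
        return importlib.util.find_spec("jax.monitoring") is not None
    except ImportError:
        # jax itself is missing or fails to import.
        return False


if __name__ == "__main__":
    print("jax.monitoring available:", jax_has_monitoring())
```

If this prints `False` while the installed Flax tries `from jax import monitoring`, the two packages are out of sync.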
Same here. It's ironic that both are developed internally, yet there are still inconsistencies between the teams. The code was working before the jax 0.4.2 release.
A quick hotfix that works for me is to use an older version of Flax, one that doesn't use the newer jax imports.
Since vision-transformer doesn't pin a Flax version in its setup file, the latest Flax release gets installed, which results in the error.
Run the cell below before installing vision-transformer:
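To confirm which jax, jaxlib and Flax versions are actually installed (for example, after pinning an older Flax), a small check like this can be run in Colab. This is a sketch using only the standard library's `importlib.metadata`. The helper name `installed_versions` and the default package list are illustrative assumptions.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib", "flax")):
    """Map each distribution name to its installed version, or None if absent.

    Hypothetical helper for inspecting the environment; it does not change anything.
    """
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            # The distribution is not installed in this environment.
            versions[name] = None
    return versions


if __name__ == "__main__":
    print(installed_versions())
```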
@anuragithub Thank you for sharing the workaround. It works like a charm.
This is now fixed with the new Flax release and #256.
Sorry for the delay on this issue.
(I saw it a bit late, and then it took some time to get Flax updated and make the updated dependencies work in Colab.)