Running your Colab code, I found that the Flax version is wrong. Your requirements specify `flax>=0.4.1`, which installs Flax 0.6.0. However, Flax 0.6.0 no longer provides `flax.optim`, so Flax 0.4.1 or 0.5.1 may be more suitable.
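A quick way to check whether an environment is affected is to compare the installed Flax version against 0.6.0, the release reported here as dropping `flax.optim`. This is a minimal sketch using only the standard library. `parse_major_minor`, `flax_optim_available` and `check_installed_flax` are hypothetical helper names, and the code assumes plain numeric `X.Y.Z` version strings.

```python
from importlib.metadata import PackageNotFoundError, version


def parse_major_minor(ver: str) -> tuple:
    """Return (major, minor) as ints from a version string like "0.5.1".

    Assumption: the first two dot-separated components are plain integers;
    anything after the minor component (patch, rc tags) is ignored.
    """
    parts = ver.split(".")
    return int(parts[0]), int(parts[1])


def flax_optim_available(ver: str) -> bool:
    """True if this Flax version predates 0.6.0, i.e. still ships flax.optim."""
    return parse_major_minor(ver) < (0, 6)


def check_installed_flax() -> None:
    """Print whether the installed Flax (if any) still provides flax.optim."""
    try:
        installed = version("flax")
    except PackageNotFoundError:
        print("flax is not installed")
        return
    if flax_optim_available(installed):
        print(f"flax {installed}: flax.optim should be available")
    else:
        print(f"flax {installed}: flax.optim is missing; consider pinning flax<0.6.0")


if __name__ == "__main__":
    check_installed_flax()
```

If the check reports that `flax.optim` is missing, pinning `flax>=0.4.1,<0.6.0` in the requirements keeps the older API available until the code is ported to a newer Flax.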
Thanks for the heads-up.
I really need to fix #232 to make the repo work with the latest version.
Will update shortly.
Updated the repo to the latest Flax version in #245.