For the Introduction to Graph Neural Nets with JAX/jraph notebook in Colab, I get the error AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'.
I didn't get this error before, so I'm guessing the versions installed by pip have changed. Would it be possible to pin the library versions in cell 1 so that the notebook still works even a year from now?
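
In case it helps, here is a minimal sketch of how the pins could be generated from an environment where the notebook still runs. The helper names `installed_versions` and `pin_lines` are hypothetical, and the package list (jax, jaxlib, jraph) is my assumption about what cell 1 installs; it may need other entries.

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: version string, or None if not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def pin_lines(versions):
    """Turn a {package: version} mapping into 'name==version' pin strings."""
    return [f"{name}=={ver}" for name, ver in versions.items() if ver is not None]


if __name__ == "__main__":
    # Assumed package list; adjust to whatever cell 1 actually installs.
    for line in pin_lines(installed_versions(["jax", "jaxlib", "jraph"])):
        print(line)
```

Running this in a working environment prints one `name==version` line per installed package, which could then be pasted into the pip install command in cell 1.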
