Hi-- Sorry for the install troubles!
It's hard for me to debug an environment like this from here, but a few things:
1. Are you running in an isolated environment (virtualenv, conda, docker)? (The answer here should be "yes"; otherwise you might be getting imports from a bunch of wild places!)
2. Are you using both jax and tensorflow in the same environment? If not, do things work with just `!pip install tfp-nightly[jax] jax jaxlib`? That should be a faster install, as well.
3. I ended up looking at this to figure out the right numpy version to install. I seem to recall an error similar to yours coming up when I was troubleshooting. Can you import numpy and, like, multiply matrices from the environment you made? That would tell us whether the problem is a numpy version mismatch.
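For point 3, something like this is what I have in mind. It's a minimal sketch, assuming you run it in the same interpreter or notebook kernel where the failing import happens. Printing `np.__file__` also shows *where* numpy is being imported from, which helps with point 1:

```python
# Sanity check: run this inside the environment you created
# (same virtualenv / conda env / notebook kernel as the failing import).
import sys

import numpy as np

print("python:", sys.version)
print("numpy:", np.__version__, "imported from", np.__file__)

# Multiply two small matrices. If the import or the multiply fails,
# the problem is likely with numpy itself rather than with TFP or JAX.
a = np.arange(6, dtype=np.float64).reshape(2, 3)
b = np.arange(6, dtype=np.float64).reshape(3, 2)
c = a @ b
print(c)
```

If that runs cleanly, numpy is probably fine and the next suspect is the jax/jaxlib versions.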
Feel free to share more detail about how you're installing things, the output of `pip list`, and the full stack trace!
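Along with `pip list`, a quick way to report just the packages that seem relevant here, from inside the same Python process, is below. It assumes Python 3.8 or newer (for `importlib.metadata`), and the package list is only my guess at what matters; adjust it to whatever you actually installed:

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as installed via pip (a guess at the relevant ones).
PACKAGES = ["numpy", "jax", "jaxlib", "tfp-nightly", "tensorflow"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for name, ver in installed_versions(PACKAGES).items():
    print(f"{name}: {ver or 'not installed'}")
```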
--C