Where can I find docs for PJRT plugins?

Joel Berkeley

Apr 23, 2024, 7:40:44 AM
to OpenXLA Discuss
Hi all,

Thank you for the work you're all doing to make ML compilers more accessible and composable. It means I can add support for all sorts of devices and compilers to my project with minimal effort.

I'm currently adding CUDA support via this PJRT plugin target. Are there docs anywhere that explain what setup is required for CUDA devices? For example, I'm uncertain what to use for `create_options`. The tests indicate `visible_devices: {0}`, but with that I'm seeing

> "no supported devices found for platform CUDA"

Without that argument, I see

> "Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR"

I'd like to figure out as much as possible myself, so if there are docs, please do send them my way. BTW, my setup is working for the corresponding CPU target.

Thanks,
Joel

Jieying Luo

Apr 23, 2024, 12:12:11 PM
to Joel Berkeley, OpenXLA Discuss, Skye Wanderman-Milne
Hi Joel,

Glad that you are trying out PJRT plugins! For the CUDA PJRT plugin, you can refer to this initialize method (which is required by JAX, but some of the setup may be useful for you); in particular, the options are created by this method.
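For context on `create_options`: as we read `pjrt_c_api.h`, `PJRT_Client_Create_Args` takes `create_options` as an array of `PJRT_NamedValue` (with a separate `num_options` count), so the tests' `visible_devices: {0}` would be a single named value of type int64 list holding `0`. The sketch below mirrors that struct with Python's `ctypes` purely to make the layout of one option concrete. The field order, enum values, and `value_size` semantics are our reading of the header as of early 2024 and should be checked against the copy you build with; the helper `int64_list_option` is hypothetical, and the sketch does not load the plugin or call `PJRT_Client_Create`.

```python
import ctypes
from typing import Sequence

# Enum values mirrored from PJRT_NamedValue_Type in pjrt_c_api.h
# (assumption: our reading of the header; verify against your copy).
PJRT_NamedValue_kString = 0
PJRT_NamedValue_kInt64 = 1
PJRT_NamedValue_kInt64List = 2
PJRT_NamedValue_kFloat = 3
PJRT_NamedValue_kBool = 4


class _NamedValueUnion(ctypes.Union):
    # Anonymous union holding the option's payload; `type` says which member is live.
    _fields_ = [
        ("string_value", ctypes.c_char_p),
        ("int64_value", ctypes.c_int64),
        ("int64_array_value", ctypes.POINTER(ctypes.c_int64)),
        ("float_value", ctypes.c_float),
        ("bool_value", ctypes.c_bool),
    ]


class PJRT_NamedValue(ctypes.Structure):
    _anonymous_ = ("value",)  # expose union members directly, as in C
    _fields_ = [
        ("struct_size", ctypes.c_size_t),
        ("extension_start", ctypes.c_void_p),
        ("name", ctypes.c_char_p),
        ("name_size", ctypes.c_size_t),
        ("type", ctypes.c_int),  # one of the PJRT_NamedValue_k* values above
        ("value", _NamedValueUnion),
        ("value_size", ctypes.c_size_t),  # element count for lists/strings, 1 for scalars
    ]


# Same idea as PJRT_NamedValue_STRUCT_SIZE: offset of the last field plus its size.
NAMED_VALUE_STRUCT_SIZE = PJRT_NamedValue.value_size.offset + ctypes.sizeof(ctypes.c_size_t)


def int64_list_option(name: str, values: Sequence[int]):
    """Hypothetical helper: build one int64-list option, e.g. visible_devices: {0}.

    Returns (struct, keepalive). Keep `keepalive` referenced for as long as the
    struct is in use, because the struct only stores pointers into it.
    """
    name_bytes = name.encode("utf-8")
    array = (ctypes.c_int64 * len(values))(*values)
    option = PJRT_NamedValue()
    option.struct_size = NAMED_VALUE_STRUCT_SIZE
    option.extension_start = None
    option.name = name_bytes
    option.name_size = len(name_bytes)
    option.type = PJRT_NamedValue_kInt64List
    option.int64_array_value = ctypes.cast(array, ctypes.POINTER(ctypes.c_int64))
    option.value_size = len(values)
    return option, (name_bytes, array)


if __name__ == "__main__":
    option, keepalive = int64_list_option("visible_devices", [0])
    print(option.name, option.value_size, option.int64_array_value[0])
```

In a non-Python host the same struct would be filled in directly against `pjrt_c_api.h`; the point here is only that `visible_devices: {0}` is a named list of int64 device ordinals, not a string.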

For the error you encountered, it is possible that these packages are not installed. You can also run `pip install jax[cuda12]` to install them.
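Since those packages are ordinary pip wheels, one way to see exactly which shared libraries they provide (and so which system CUDA/cuDNN/NCCL installs they correspond to) is simply to list them. This is a minimal sketch assuming the wheels unpack into an `nvidia` package with one `<component>/lib/` directory per library (for example `nvidia/cudnn/lib/`); that layout is an assumption, not something stated in this thread, and `nvidia_pip_libraries` is a hypothetical helper name.

```python
import importlib.util
from pathlib import Path


def nvidia_pip_libraries(roots=None):
    """Map each nvidia/<component> directory to the shared libraries it ships.

    Assumes the NVIDIA wheels pulled in by `pip install jax[cuda12]` unpack into
    an `nvidia` package with one `<component>/lib/` directory per library.
    Pass `roots` explicitly to scan other directories instead.
    """
    if roots is None:
        spec = importlib.util.find_spec("nvidia")
        if spec is None or spec.submodule_search_locations is None:
            return {}  # no nvidia.* wheels in this Python environment
        roots = list(spec.submodule_search_locations)
    found = {}
    for root in roots:
        for lib_dir in sorted(Path(root).glob("*/lib")):
            libs = sorted(p.name for p in lib_dir.glob("*.so*"))
            if libs:
                found.setdefault(lib_dir.parent.name, []).extend(libs)
    return found


if __name__ == "__main__":
    for component, libs in nvidia_pip_libraries().items():
        print(f"nvidia/{component}:")
        for name in libs:
            print(f"  {name}")
```

The sonames printed (e.g. a `.so.9` vs `.so.8` suffix) give the major versions to look for when installing the same libraries without Python.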

Best,
Jieying

Joel Berkeley

Apr 23, 2024, 5:39:19 PM
to OpenXLA Discuss, Jieying Luo, Skye Wanderman-Milne, Joel Berkeley
Thanks, Jieying. I've tried out those changes, but unfortunately I'm still seeing the same errors. I'm conscious I may need to reinstall Ubuntu in case my CUDA setup is broken.

Note I'm not using Python, so I'm not sure how installing JAX will interact with my setup. I did try running my code within a Python venv with JAX installed, but it didn't appear to help. I do wonder if there's some way to tell which CUDA toolkit/cuDNN versions those packages correspond to. I'll need to work out how to do this without Python packages anyway, so that I don't need to ask my library users to install them.

Joel Berkeley

Apr 23, 2024, 5:41:21 PM
to OpenXLA Discuss, Joel Berkeley, Jieying Luo, Skye Wanderman-Milne
You may have answered my primary question, whether there are docs for this stuff, with "no". That's fine, though unfortunate; these things take time. Perhaps I'll have the time to contribute those docs once I figure this out.

Joel Berkeley

Apr 23, 2024, 6:15:51 PM
to OpenXLA Discuss, Joel Berkeley, Skye Wanderman-Milne
I have gone through every one of those libs in the JAX package and found the corresponding NVIDIA install. I was missing NCCL, but after letting NVIDIA in on my favourite breakfast cereal, installing it didn't help anyway.

Joel Berkeley

Apr 26, 2024, 6:59:33 PM
to OpenXLA Discuss, Joel Berkeley, Skye Wanderman-Milne
I ran it in an NVIDIA Docker container and got the same error, which more strongly suggests it's not caused by a faulty CUDA installation.

Joel Berkeley

May 3, 2024, 8:47:56 AM
to OpenXLA Discuss, Joel Berkeley, Skye Wanderman-Milne
OK, it works with the tensorflow/build container, so I can continue to investigate from there. My biggest blind spot atm is CUDA/cuDNN versions and other package requirements; I'm mostly stumbling in the dark there.