interacting with xla_client via tensorflow


James Connolly

May 19, 2020, 12:21:04 AM
to XLA development
Hey Folks,

I have a fork of tensorflow v2.1.0 I'm iterating on, and I'd like to access the xla_client Python library (tensorflow/compiler/xla/python/xla_client.py) for some Python tests.

By default, xla_client isn't a dependency of the pip package target. To try to remedy this, I've added "//tensorflow/compiler/xla/python:xla_client" as a dependency of "//tensorflow/python/compiler/xla:compiler_py".
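For reference, that dependency edit would look something like the following in tensorflow/python/compiler/xla/BUILD (a sketch only; the rule's other attributes are omitted and may differ by TF version):

```starlark
# tensorflow/python/compiler/xla/BUILD (sketch; existing attributes unchanged)
py_library(
    name = "compiler_py",
    # ...existing srcs and deps...
    deps = [
        # ...existing deps...
        "//tensorflow/compiler/xla/python:xla_client",  # newly added dependency
    ],
)
```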

The problem is that when I try to import xla_client (from tensorflow.compiler.xla.python import xla_client), I get a missing symbol error.

This poses two questions:

  1. Is there a reason xla_client is not included in TensorFlow releases? From searching around, the recommended path seems to be importing xla_client from jaxlib.
  2. Does anyone have a flow for including xla_client in their tensorflow release?


Peter Hawkins

May 19, 2020, 9:16:36 AM
to James Connolly, XLA development

Hi...

The XLA Python bindings aren't included in TensorFlow; JAX is their main user, and they are packaged on PyPI as jaxlib, as you have observed. They are really only in the TensorFlow source tree because XLA itself is. To my knowledge, there hasn't been any discussion about including them in TensorFlow itself. I also caution you that they aren't a completely stable API (just as XLA's C++ API isn't stable).

If the goal is just local testing via "bazel test" inside the TensorFlow tree, I believe this can work if you build with --config=monolithic; at some point in the past, xla_client_test_cpu in the same directory worked via "bazel test". It's not a configuration we routinely exercise, though: the build rules may need updating because our CI builds don't cover them, and some details of linking differ between open source and Google-internal builds, where we do run these tests.
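Concretely, that in-tree flow would be invoked something like this (a sketch; the target name comes from the directory mentioned above and its build rules may have bit-rotted since):

```shell
# Run the XLA Python client test in-tree. --config=monolithic links
# TensorFlow into a single shared object, which avoids the split-linking
# issues that can otherwise cause missing-symbol errors in the bindings.
bazel test --config=monolithic //tensorflow/compiler/xla/python:xla_client_test_cpu
```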

It is possible to build a custom jaxlib that includes your modifications to XLA; see https://jax.readthedocs.io/en/latest/developer.html#building-jaxlib-from-source, and note that you can change the JAX WORKSPACE file to point to a TensorFlow tree of your choice instead of a fixed GitHub hash. You may need to rebase your fork onto a more recent TensorFlow checkout for this to work, though. You can then access the Python bindings as jax.lib.xla_client.
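Once such a custom jaxlib is installed, the bindings can be reached as described. A minimal sketch (assuming jax/jaxlib are importable; the try/except guard is only there so the snippet degrades gracefully when they are not):

```python
def load_xla_client():
    """Return the xla_client module from an installed jaxlib, or None."""
    try:
        # jax re-exports the XLA Python bindings from jaxlib under jax.lib.
        from jax.lib import xla_client
        return xla_client
    except ImportError:
        # jax/jaxlib not installed in this environment.
        return None

client = load_xla_client()
print("xla_client available:", client is not None)
```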

We've also at points in the past talked about perhaps splitting the XLA Python bindings out of jaxlib into a pypi package of their own.

Hope that helps,
Peter


--
You received this message because you are subscribed to the Google Groups "XLA development" group.
To unsubscribe from this group and stop receiving emails from it, send an email to xla-dev+u...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/xla-dev/5b52f805-5280-4812-b359-3d319e669f1a%40googlegroups.com.

James Connolly

May 26, 2020, 3:57:07 PM
to XLA development
Hey Peter,

Thanks so much for this insight. Building our own jaxlib seems like the best path forward. Rebasing onto a newer version of TF would be ideal, but v2.2.0 has problems building with debug symbols; that's an issue for another forum, though :)

Thanks,
James 