Hi Joel,
I think maybe the key clarification here is that XLA's while is not the "inner loop" of an optimizer function like adam -- what George is referring to is the fact that you can put a "while" loop around something like a neural net training step, so that you can say things in XLA like "run this neural net step until the loss is < N" or "until eval accuracy is > N".
If you look at what is IMO a very simple/clean model builder like stax inside of JAX, you can see these lines:

  opt_state = opt_init(init_params)
  for i in range(num_steps):  # this could be an XLA while loop instead
    opt_state = update(i, opt_state, next(batches))
  trained_params = get_params(opt_state)
In this example you first ask the optimizer "hey, can I get your initialization state?", then you pump the neural net step function num_steps times as you shovel minibatches in. Each step takes your optimizer state to optimizer state', which corresponds to an update of the parameters according to the optimizer function, plus any aux data for your optimizer. The optimizer is usually a pretty simple set of HLOs that operate in their vector form on the parameters; e.g. you can see the multi-dimensional array ops in the case of adam here:
https://github.com/google/jax/blob/main/jax/example_libraries/optimizers.py#L414
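To make that concrete, here's a sketch (not the actual stax/optimizers code -- `adam_update` and its defaults are names I'm making up for illustration) of the kind of whole-array Adam update those HLOs implement:

```python
import jax.numpy as jnp

def adam_update(i, g, state, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step written as whole-array ops, which lower to a handful of HLOs."""
    x, m, v = state
    m = (1 - b1) * g + b1 * m               # first moment estimate
    v = (1 - b2) * jnp.square(g) + b2 * v   # second moment estimate
    mhat = m / (1 - b1 ** (i + 1))          # bias correction
    vhat = v / (1 - b2 ** (i + 1))
    x = x - lr * mhat / (jnp.sqrt(vhat) + eps)
    return x, m, v

# tiny smoke usage: one step from zero state with unit gradients
x, m, v = adam_update(0, jnp.ones(3), (jnp.zeros(3), jnp.zeros(3), jnp.zeros(3)))
```

Note there's no control flow here at all -- it's a pure state-to-state' function, which is exactly why it composes with either a host-side loop or a device-side while.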
Here's where XLA "while" can come in -- the Python-implemented `for` in the snippet above could be replaced with an XLA while loop using the kind of coarse-grained condition I mentioned. On GPU, the while loop implementation "just" pumps the set of kernels that make up the "body" of the while to execute on a stream.
https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/service/gpu/while_thunk.cc#L75 Though hypothetically exposing more computation to the XLA high level optimizer could lead to better decisions in optimization and buffer allocation, launching from the host side as we see in the above snippet is usually a good default.
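As a sketch of what pushing that host-side `for` into the device as a single XLA while looks like, here's a toy example using `jax.lax.while_loop` -- the quadratic `loss_fn`, the plain-SGD step, and the function names are all illustrative, not from stax:

```python
import jax
import jax.numpy as jnp

def train_until(threshold, max_steps, init_params, batch):
    """Run update steps on-device until loss < threshold (or max_steps is hit);
    the whole loop stages out as one XLA `while`."""

    def loss_fn(p):
        # toy quadratic "loss" standing in for a real model
        return jnp.sum((p - batch) ** 2)

    def cond(carry):
        step, _, loss = carry
        # coarse-grained condition of the kind mentioned above
        return jnp.logical_and(loss > threshold, step < max_steps)

    def body(carry):
        step, p, _ = carry
        g = jax.grad(loss_fn)(p)
        p = p - 0.1 * g                      # plain SGD step for simplicity
        return step + 1, p, loss_fn(p)

    carry = (0, init_params, loss_fn(init_params))
    return jax.lax.while_loop(cond, body, carry)

steps, params, final_loss = train_until(1e-3, 1000, jnp.zeros(2), jnp.ones(2))
```

The trip count isn't known at trace time -- the device just keeps evaluating `cond` and pumping `body` until the condition flips.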
On a device like TPU where the execution stays resident on the device (instead of needing the host to pump it via a data-parallel-kernel-launch like mechanism), there's the ability to consume/produce I/O from the host within the while loop; i.e. the device can pop the host-given "infeed" of a minibatch, run the XLA while body, and only terminate when the condition is met, without needing the host to hand-hold the device via kernel launches.
There's a class of models where this doesn't matter that much; e.g. trapping to the host to ask "should I keep going?", having it check a word in device memory and say "yes, please do" while shoveling the next minibatch in, could be a) much smaller latency than the update() execution time and b) more convenient. But there are at least some cases where that's not true, and XLA wanted to handle those -- it's effectively a way of turning the XLA device computation from a simple function into a coroutine that can run indefinitely, with the ability to do (pipelined) I/O to it from the outside world. (That being said, launching computations and doing I/O transfers can also be pipelined in schemes independent of a coroutine-and-queues style model, but this packages it all together in a way that the compiler can reason about the memory requirements.)
Here's an example in JAX of infeeding to a loop in a device function that feeds back out just as a proof of concept for what I'm talking about:
https://github.com/google/jax/blob/de464fcf22c8bf7a2931182f8095bc01530df9fe/tests/infeed_test.py#L105 Like in the Python snippet above, though, you could use host code to implement the while loop and launch XLA computations in a row from the host. I think that's de facto the preferred way when you're not sensitive to the latency we're talking about, since then you get the full capability of the host to do whatever you want in between your coarse-grained XLA computations.
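For contrast, the host-driven flavor looks something like this (toy model and names are mine): a jitted step function is launched repeatedly while the host checks the stopping condition in ordinary Python:

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(p):
    """One coarse-grained XLA computation: a toy gradient step plus its loss."""
    loss_fn = lambda q: jnp.sum(q ** 2)
    g = jax.grad(loss_fn)(p)
    return p - 0.1 * g, loss_fn(p)

p = jnp.ones(4)
loss = float('inf')
while loss > 1e-6:          # host-side condition check between device launches
    p, new_loss = step(p)
    loss = float(new_loss)  # forces a device->host transfer of the loss word
```

The `float(new_loss)` is exactly the "trap to the host and check a word" pattern: cheap when the step is big, but a round trip you pay every iteration.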
Random / deep "could be cool research" aside: I expect a persistent-kernel / infeed mechanism could be implemented on GPU devices as well and shave off some host round trips. But kernel launches are also an opportunity for the GPU to "reconfigure" itself for residency -- shared memory amounts, 2D register file access pattern (registers per thread), threads per block, blocks in the grid -- so it might need something fancy like a "control" kernel (which could e.g. wait for infeed data to be present) that did sub-launches via a "dynamic parallelism" style feature. That being said, it'd probably be hard with the number of libraries we use on GPUs, since host guidance of the kernel-launch process is a pervasive assumption in the ecosystem; and I don't think dynamic parallelism has lower kernel-launch latency than the host does -- it'd just shave off the round trips and any jank in the host observing kernel completion / dispatch. However, if people know of any existing refs to work that builds this kind of system and shares its insights, I'd be interested to read up!
Cheers,
- Leary