XLA support for FPGA backend

Shubham Nema

Mar 21, 2019, 5:29:28 AM
to xla...@googlegroups.com
Hello XLA developers team,

I am working on creating a backend for an FPGA. As I understand it, when we call session.run(), xla::cpu::Executable or xla::gpu::Executable is invoked, and it executes the HLO modules that contain the graph.
After the instructions execute, the result is stored in a result_buffer. Is it possible to access or mutate this result_buffer? The class xla::ScopedShapedBuffer does not provide any function to access the actual data. The reason I would like to access and mutate the data is that once the FPGA computes a result, say of a convolution, it should be stored back in CPU buffer memory so that the last function, softmax(), can be applied, its result compared against the various classes, and the final output displayed on screen.
Further, I would have to provide the convolution inputs and weights (constant literals) to the FPGA. I found that xla::TransferManager can do this. My question, however, is how those inputs and weights get mapped to their corresponding convolution layer.
I would be grateful for help with the two queries above.

Thanking,

Best Regards,
Shubham

Sanjoy Das

Mar 21, 2019, 11:19:19 AM
to Shubham Nema, XLA development
On Thu, Mar 21, 2019 at 2:29 AM Shubham Nema <shubha...@gmail.com> wrote:
> I am working on creating a backend for an FPGA. As I understand it, when we call session.run(), xla::cpu::Executable or xla::gpu::Executable is invoked, and it executes the HLO modules that contain the graph.
> After the instructions execute, the result is stored in a result_buffer. Is it possible to access or mutate this result_buffer? The class xla::ScopedShapedBuffer does not provide any function to access the actual data.

Maybe I'm missing some nuance, but can't you use
ShapedBuffer::buffer(index)? xla::ScopedShapedBuffer inherits from
ShapedBuffer.

> Further, I would have to provide the convolution inputs and weights (constant literals) to the FPGA. I found that xla::TransferManager can do this. My question, however, is how those inputs and weights get mapped to their corresponding convolution layer.
> I would be grateful for help with the two queries above.

Constants can be embedded directly in the XLA program that you
generate. Non-constant inputs can be parameters (XLA parameters are
basically function arguments) which can be passed in via the
`arguments` argument to xla::Executable::ExecuteOnStream etc.

It might be easier for us to help you if you described how you're
using XLA. Are you using it as part of TensorFlow, or in some other
manner?

-- Sanjoy

Shubham Nema

Mar 28, 2019, 8:27:48 AM
to XLA development
Hello Sanjoy,

Apologies for the delayed response. I found that ShapedBuffer::buffer(index) provides the buffer at that index, but it does not give access to the actual data placed within the buffer. Also, if I use a (void *)ptr to point to that buffer address and try to dereference it, the Bazel build throws an error. Could you tell me how to resolve that?

> Constants can be embedded directly in the XLA program that you generate.
You mean that the weights will be treated as constants and stored in the generated XLA graph (HLO module) for each convolution layer, which is essentially a cluster. Is that correct?

> Are you using it as part of TensorFlow, or in some other manner?
Yes, I am using XLA as part of TensorFlow itself.

Justin Lebar

Mar 28, 2019, 1:27:38 PM
to Shubham Nema, XLA development
ShapedBuffer::buffer() returns a stream_executor::DeviceMemoryBase.  This object represents memory that lives on a "device", that is, not on the CPU.

If you want to read the contents of the buffer on the CPU, you need to transfer it from the device to the CPU using e.g. Stream::ThenMemcpy.

--
You received this message because you are subscribed to the Google Groups "XLA development" group.
To unsubscribe from this group and stop receiving emails from it, send an email to xla-dev+u...@googlegroups.com.
To post to this group, send email to xla...@googlegroups.com.
To view this discussion on the web visit https://groups.google.com/d/msgid/xla-dev/72c2688d-f5e4-43b4-af47-8e2068dc7dd1%40googlegroups.com.
For more options, visit https://groups.google.com/d/optout.

Shubham Nema

Apr 3, 2019, 6:49:44 AM
to XLA development
Thanks!!
I tried it, but it gives a segmentation fault when it tries to copy at runtime. Also, in my case the "device" is the CPU itself, since I am using the XLA_CPU option for now. Would it make sense to copy the data stored in the CPU buffer to a (void *)ptr using ThenMemcpy?

Justin Lebar

Apr 11, 2019, 3:28:06 PM
to Shubham Nema, XLA development
I don't have enough information to say what's going on, sorry.

If this is CPU-only, tools like GDB or AddressSanitizer should be able to give you full insight into what's happening.  There should be no magic.

Shubham Nema

Jun 19, 2019, 3:08:53 AM
to XLA development
Hello,

While working on the backend, I need pointers to the input values of the topmost HLO instruction, kParameter, where the first input is fed during execution of the HLO graph. I am able to get the shape, size, layout, offset, etc. for the buffer assigned to the kParameter instruction; however, I am unable to get the actual pointer to the input values stored in that buffer.
For a kConstant HloInstruction, I understand that the values are stored as literals within the HLO module itself, which can probably be accessed using the .ConsumeValueOrDie() API (though pointers to the literals would also be useful).
I also considered the TuplePointsToAnalysis class, but it ultimately gives me back a pointer to either a LogicalBuffer or the HloInstruction itself, not the actual pointer to the values stored in the buffer.
I would be grateful for guidance on the above queries.

Thanking,

Best Regards,
Shubham

Justin Lebar

Jun 19, 2019, 3:20:35 AM
to Shubham Nema, XLA development
Have a look at the BufferAssignment class; its job is to give you these pointers.

Shubham Nema

Jun 29, 2019, 6:02:27 AM
to XLA development
Thanks. I found the opaque void pointer to the buffers in the se::DeviceMemoryBase class.

Another peculiar thing I noticed is that the XLA compiler seems to add an HLO Reshape instruction after a Parameter instruction even though both already have the same shape. Why is a Reshape instruction needed here?

Thanking,

Regards,
Shubham

[Attachment: HLO_Graph.png]


Justin Lebar

Jun 29, 2019, 11:51:03 AM
to Shubham Nema, XLA development
I don't think the XLA compiler is adding these reshapes.  Whatever is calling XLA initially is probably adding them.  The XLA compiler will then *remove* them as it optimizes the program.

Sanjoy Das

Jun 29, 2019, 3:24:40 PM
to Justin Lebar, Shubham Nema, XLA development
I'd expect the extra reshapes are from https://github.com/tensorflow/tensorflow/blob/e70323993b7c3652f0a0f7bd3fa0cb79340f0ac6/tensorflow/compiler/tf2xla/xla_compiler.cc#L938

I don't think we ever add non-trivial reshapes on parameters in open source, so I'd expect all of these reshapes to get optimized away.

-- Sanjoy

Shubham Nema

Aug 6, 2019, 7:58:48 AM
to Sanjoy Das, Justin Lebar, XLA development
Hello,

Could you please help me with the queries below?
  • In the GPU backend implementation, does Session.run() transfer the HLO parameters (kParameter) to the target device every time Session.run() is called, or only when the parameters are modified? If it happens only when they are modified, how does the backend recognize that the parameters were modified?
  • Is it possible to set or assign the value of a tf.constant(value, dtype, shape, name) at a later point, after it is declared?
Thanking,

Best Regards,
Shubham

Sanjoy Das

Aug 7, 2019, 10:41:00 AM
to Shubham Nema, Justin Lebar, XLA development
On Tue, Aug 6, 2019 at 4:58 AM Shubham Nema <shubha...@gmail.com> wrote:
> Hello,
>
> Could you please help me with the queries below?
>   • In the GPU backend implementation, does Session.run() transfer the HLO parameters (kParameter) to the target device every time Session.run() is called, or only when the parameters are modified? If it happens only when they are modified, how does the backend recognize that the parameters were modified?

session.run is a TF concept, so it does not have a straightforward correspondence to HLO parameters.  There may be multiple XLA clusters in a single TF graph, for instance.

Perhaps you can trace through a simple program to see when the host->device copies happen?
 
>   • Is it possible to set or assign the value of a tf.constant(value, dtype, shape, name) at a later point, after it is declared?

I don't think so.

-- Sanjoy

Shubham Nema

Aug 15, 2019, 9:34:30 PM
to Sanjoy Das, Justin Lebar, XLA development
Thanks a lot Sanjoy.

In the CPU backend, we have a result_buffer (in CpuExecutable::ExecuteComputeFunction) that points to the top-level output slice of the buffer allocation, and the result is stored in this buffer. Since it is a void pointer, I commented out the actual compute_function and copied my dummy output (an array of [2.48]) into that result_buffer using memcpy. If I print the result_buffer values within the backend through a float pointer, they print correctly. However, when the result propagates to the Python front end, some of the values appear corrupted.

Is this issue related to a mismatch between Python data types and TensorFlow data types? Do you have any other ideas?
 
[Attachment: image.png]

Thanking,

Best Regards,
Shubham

Sanjoy Das

Aug 16, 2019, 10:04:36 AM
to Shubham Nema, Justin Lebar, XLA development
It is hard to debug this over email.

As a first step, maybe try the various sanitizers (asan, ubsan, msan) to check that you don't have use-after-free or out-of-bounds accesses.

If the same address/offset gets corrupted in every case you can also try running the program under a debugger and setting a watchpoint.

-- Sanjoy

Shubham Nema

Aug 27, 2019, 7:30:46 AM
to Sanjoy Das, Justin Lebar, XLA development
Thanks, all!
The backend for the FPGA has been created.

Vincent Mirian

Jul 29, 2020, 1:15:21 AM
to XLA development
Hi,

I have a similar objective. I would like to define the TF ops that my device/backend supports, so I am looking for the TF-op-to-XLA-op mapping. My earlier post about this request is: https://groups.google.com/forum/#!topic/xla-dev/Mo_73_bJHTA.

Any help or suggestions would be appreciated.

Thank you,
Vincent Mirian
