Roadmap for adding dynamic shape support in the MLIR HLO dialect


Jun Yang

Jan 2, 2020, 10:44:11 AM
to MLIR
Hi,

In this doc, it is mentioned that there is a plan to add dynamic shape support in XLA/HLO for the TF/XLA bridge based on MLIR. Is there any concrete plan or roadmap that can be shared? We are currently also working on a PoC to support dynamic shapes for code generation and would like to stay in sync with the MLIR community.

At present we think there are two possible options for adding dynamic shape support to XLA/HLO within MLIR:
1. Add another dialect (maybe named DHLO?) and embed the dynamic shape information in the newly introduced dialect, then leverage the Stencil/Linalg dialect for the fusion & codegen workflow (unlike with the HLO dialect, we could not directly reuse the existing XLA fusion & codegen infrastructure).
The pro of this approach is that we believe the best performance can only be obtained with fixed shapes, so by keeping the original static-shaped HLO dialect we leave the high-performance fixed-shape codegen untouched. During JIT compilation, if the dynamic shape issue is too severe, we could switch the IR graph from the HLO dialect to the DHLO dialect (this should not be difficult to achieve) and fall back to DHLO for dynamic-shape codegen.

2. Directly extend the HLO dialect to support dynamic shapes. For the fusion & codegen part, we could either leverage the Stencil/Linalg dialect or follow the existing XLA fusion & codegen implementation to add the corresponding support.
The pro of this approach is that we avoid the cost of introducing another dialect; the con is that we need to take care not to lose the performance benefits of static-shaped codegen.

Personally I prefer option 1, so in the following I only describe further thoughts on option 1.

Another thing that deserves attention: for dynamic and static shapes, I think some optimizations should be shared, and we may need some extra work to provide such support (for example, optimization passes supporting both the DHLO and HLO dialects; maybe we could introduce another dialect called XHLOOpt and place the corresponding shape-agnostic optimization passes there).

In summary, there are several potential execution flows after adding dynamic shape support (a rough sketch of the representational difference follows the list):

a) DHLO --> XHLOOpt --> Linalg/Stencil --> LLVM
b) HLO --> XHLOOpt --> (HLO dialect optimization) --> Linalg/Stencil --> LLVM
c) HLO --> (round trip to the XLA world and directly use XLA's existing optimizations)
d) HLO --> (if dynamic shapes are too severe, fall back) --> DHLO --> XHLOOpt --> Linalg/Stencil --> LLVM
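To make the representational difference concrete, here is a minimal sketch in MLIR's type system (the dhlo.add op name is purely hypothetical, following the DHLO naming above; the assembly form is illustrative, not authoritative):

// Static-shaped form: every dimension is known at compile time.
%0 = "xla_hlo.add"(%a, %b) : (tensor<4x16xf32>, tensor<4x16xf32>) -> tensor<4x16xf32>

// Hypothetical dynamic-shaped form: the leading (batch) dimension is only known
// at runtime and is spelled `?`; tensor<*xf32> would additionally leave the rank unknown.
%1 = "dhlo.add"(%c, %d) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>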

Any comments and suggestions are highly welcome.

Thanks



Mehdi AMINI

Jan 2, 2020, 10:52:10 AM
to Jun Yang, MLIR
Hi Jun,


There are already two dialects prototyped: HLO and LHLO (late-HLO). The latter uses memref instead of tensor to model explicitly allocated buffers.
Both of them support dynamic shapes, and the lowering from the TensorFlow dialect to the HLO dialect is written to support dynamic shapes (see the work going on here: https://github.com/tensorflow/tensorflow/commits/master/tensorflow/compiler/mlir/xla/transforms )
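To illustrate the tensor vs. memref distinction (a rough sketch; the op spellings are illustrative rather than authoritative), here is an elementwise add in the value-based HLO form versus the buffer-based LHLO form:

// xla_hlo: operates on tensor values; a `?` marks a dimension only known at runtime.
%sum = "xla_hlo.add"(%lhs, %rhs) : (tensor<?x16xf32>, tensor<?x16xf32>) -> tensor<?x16xf32>

// xla_lhlo: operates on pre-allocated memref buffers; the result buffer is passed
// in as an operand instead of being returned.
"xla_lhlo.add"(%lhs_buf, %rhs_buf, %out_buf) : (memref<?x16xf32>, memref<?x16xf32>, memref<?x16xf32>) -> ()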

We're missing the HLO->LHLO transforms, but we already have some GPU codegen using LHLO->Linalg (see here https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/xla/transforms/lhlo_legalize_to_gpu.cc ).

Best,

-- 
Mehdi



Nicolas Vasilache

Jan 2, 2020, 11:03:05 AM
to Mehdi AMINI, Jun Yang, MLIR
Hello Jun and Mehdi,

To complement what Mehdi wrote, I'd like to draw your attention to the following point.

One thing to note is that Linalg is adding first class support for tensors in https://reviews.llvm.org/D72022

This has a number of implications for some of the transformations that are traditionally done at the level of the HLO dialect.
In particular, everything related to trivial fusion of pointwise operators can be done immediately using regions.
This avoids the need for the current, more cumbersome and phase-ordered flow that does:
1. mark fusion with XLA fusion nodes,
2. allocate buffers for everything
3. convert to Linalg
4. apply fusion in Linalg
5. perform an analysis and remove temporary buffers that have been fused.

Note that step 4 may not necessarily do what one intended at step 1, since we are talking about different systems that are not really designed to talk to each other.

Instead, this can be replaced by:
1. apply fusion of ops using regions

Temporary buffers never get materialized at all.
This becomes especially handy when implicit or explicit broadcast semantics are involved: some things are trivial to fuse at the level of Linalg on tensors, and the unnecessary intermediate memory is never allocated.
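As a rough sketch (the exact linalg.generic spelling has shifted across MLIR revisions, so treat the syntax as illustrative), two pointwise ops fused into a single region on tensors, with dynamic sizes and no intermediate buffer:

#id = affine_map<(d0) -> (d0)>

// add followed by tanh, expressed as one linalg.generic on tensors: the
// "fusion" is simply the region body, and no temporary tensor/buffer exists.
func @add_tanh(%a: tensor<?xf32>, %b: tensor<?xf32>,
               %init: tensor<?xf32>) -> tensor<?xf32> {
  %0 = linalg.generic
         {indexing_maps = [#id, #id, #id], iterator_types = ["parallel"]}
         ins(%a, %b : tensor<?xf32>, tensor<?xf32>)
         outs(%init : tensor<?xf32>) {
  ^bb0(%x: f32, %y: f32, %out: f32):
    %sum = addf %x, %y : f32
    %t = tanh %sum : f32
    linalg.yield %t : f32
  } -> tensor<?xf32>
  return %0 : tensor<?xf32>
}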

There are many other implications on the type of transforms that become available at this level (hint: look at the TASO compiler) but I only listed the most obvious one.

In my mind the codegen path where things are the most natural is:

User
-> Language / Framework 
-> HLO + Linalg on tensors 
-> LHLO + Linalg on buffers 
(note that buffer allocation in Linalg on tensors -> Linalg on buffers can be very progressive intermixing ops with both tensor and buffers arbitrarily)
-> Affine/StructuredControlFlow (still named Loops atm ..)
-> backends

Different transformations apply at each level. 



--
N

Jun Yang

Jan 2, 2020, 3:42:05 PM
to MLIR
Hi Mehdi,

Thanks for the prompt reply.

Regarding what you said, "Both of them support dynamic shapes, and the lowering from the TensorFlow dialect to the HLO dialect is written to support dynamic shapes":
I am a little curious what "support dynamic shapes" means here, since in my observation the current TF-to-HLO dialect conversion behaves the same as the original HLO path. Let me give one example from the code:

// talk is cheap, show the code----begins
PatternMatchResult matchAndRewrite(TF::Conv2DBackpropInputOp op,
                                     PatternRewriter &rewriter) const override {
    // Unpack all of the attributes.
    tensorflow::TensorFormat data_format;
    if (!FormatFromString(op.data_format().str(), &data_format)) {
      return matchFailure();
    }
    tensorflow::Padding padding;
    if (!GetPaddingFromString(op.padding().str(), &padding).ok())
      return Pattern::matchFailure();

    auto out_backprop_ty =
        op.out_backprop()->getType().dyn_cast<RankedTensorType>();
    // NOTE: the lowering bails out unless the out_backprop shape is fully static.
    if (!out_backprop_ty || !out_backprop_ty.hasStaticShape())
      return matchFailure();
    ArrayRef<int64_t> out_backprop_shape = out_backprop_ty.getShape();
    // NOTE: the same static-shape requirement applies to the filter.
    auto filter_ty = op.filter()->getType().dyn_cast<RankedTensorType>();
    if (!filter_ty || !filter_ty.hasStaticShape()) return matchFailure();
    ArrayRef<int64_t> filter_shape = filter_ty.getShape();
    int num_spatial_dims = 2;
    Location loc = op.getLoc();

    int num_dims = num_spatial_dims + 2;
    int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format);
    int feature_dim =
        tensorflow::GetTensorFeatureDimIndex(num_dims, data_format);

    DenseIntElementsAttr input_shape_attr;
    if (!matchPattern(op.input_sizes(), m_Constant(&input_shape_attr)) ||
        input_shape_attr.getType().getRank() != 1) {
      return matchFailure();
// talk is cheap, show the code----ends

From the above code snippet, it can be seen that when converting from the TF dialect to the HLO dialect there is still an inherent static-shape constraint (see the hasStaticShape() checks marked above), at least for some (or many) TF operations.
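For instance (the operand order and attribute spelling of the TF dialect's textual form are approximate here), an op like the following would simply be rejected by the pattern above, because its shapes are not fully static:

// out_backprop has a dynamic batch dimension, so hasStaticShape() fails and
// the pattern returns matchFailure().
%grad = "tf.Conv2DBackpropInput"(%input_sizes, %filter, %out_backprop)
          {strides = [1, 1, 1, 1], padding = "SAME", data_format = "NHWC"}
        : (tensor<4xi32>, tensor<3x3x16x32xf32>, tensor<?x28x28x32xf32>)
        -> tensor<?x28x28x16xf32>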

However, from the dialect representation perspective, I also think the HLO dialect might be capable of representing its inputs/outputs with dynamic shapes rather than just static ones, since they are declared as HLO_Tensor-like types:
// talk is cheap, show the code----begins
// Any integer tensor types
def HLO_IntTensor : TensorOf<[HLO_Int]>;

// Any floating-point tensor types
def HLO_FpTensor : TensorOf<[AnyFloat]>;

def HLO_PredTensor : TensorOf<[HLO_Pred]>;

def HLO_Tensor : TensorOf<[AnyFloat, AnyInteger, AnyComplex]>;

def HLO_ComplexTensor : TensorOf<[AnyComplex]>;

def HLO_Tuple : NestedTupleOf<[HLO_Tensor, HLO_Token]>;

def HLO_TensorOrTuple : AnyTypeOf<[HLO_Tensor, HLO_Tuple]>;
// talk is cheap, show the code----ends
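As far as I can tell, these TensorOf<...> constraints already admit dynamically shaped tensors. For example, all of the following types should satisfy HLO_FpTensor / HLO_Tensor (a small illustrative snippet):

func @shapes(%static   : tensor<4x8xf32>,   // fully static
             %dynamic  : tensor<?x8xf32>,   // ranked, with a dynamic leading dim
             %unranked : tensor<*xf32>) {   // unranked
  return
}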
So from my point of view, there is an inconsistency here. 
Could you please help elaborate a little bit more?

Thanks


Jun Yang

Jan 2, 2020, 3:57:02 PM
to MLIR
Hi Nicolas,

Nice to see the discussion in the LLVM thread.

It is interesting that tensor support is being added to the Linalg dialect to ease the fusion work.

One more thing I think deserves further discussion.

As to the codegen path you mentioned:

-> Language / Framework 
-> HLO + Linalg on tensors 
-> LHLO + Linalg on buffers 
(note that buffer allocation in Linalg on tensors -> Linalg on buffers can be very progressive intermixing ops with both tensor and buffers arbitrarily)
-> Affine/StructuredControlFlow (still named Loops atm ..)
-> backends

There is an intermix of the HLO + Linalg and LHLO + Linalg dialects during the conversion process.

I think one possible reason we need this intermix is that, due to current limitations of Linalg, it may not support all the fusion-related functionality yet, so for some sub-graphs we could directly leverage Linalg, while for the remaining sub-graphs we have to resort to HLO/LHLO's
own optimization support. What is your point of view?

Thanks 
Jun


Nicolas Vasilache

Jan 2, 2020, 4:04:40 PM
to Jun Yang, MLIR
On Thu, Jan 2, 2020 at 3:57 PM Jun Yang <yangj...@gmail.com> wrote:
I think one possible reason we need this intermix is that, due to current limitations of Linalg, it may not support all the fusion-related functionality yet, so for some sub-graphs we could directly leverage Linalg, while for the remaining sub-graphs we have to resort to HLO/LHLO's own optimization support. What is your point of view?

Indeed, this is conservative because HLO is already well established, expressive and has had a lot of engineering support.
On the other hand, it does not have true custom ops or dynamic shape support, and it breaks everything into small pieces.
As things progress and we have more data it will become clearer what type of representation / algorithms will be the most useful. 

 


--
N

Jun Yang

Jan 2, 2020, 5:23:48 PM
to Nicolas Vasilache, MLIR
Got it, and this makes sense to me.

Actually we share the same design objective. In our mind, we would like to add dynamic shape support only for those scenarios that genuinely require it; for other scenarios, we will rely on
the existing HLO static-shape support as much as possible. I also think this is one of the major benefits brought by MLIR, since it lets newly introduced pieces and legacy pieces cooperate in a more coherent way.

Jun

--

Jun Yang

Jan 3, 2020, 10:08:48 PM
to Mehdi AMINI, MLIR
Another question regarding "the HLO dialect already supports dynamic shapes".

In my understanding, even from the representation perspective, the HLO dialect still doesn't support full dynamic shape semantics.

Let me give one concrete example.
Here is the TableGen definition of the HLO slice op:
def HLO_SliceOp: HLO_Op<
      "slice",
      [NoSideEffect, SameOperandsAndResultElementType,
       AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
  let arguments = (ins
    HLO_Tensor:$operand,
    I64ElementsAttr:$start_indices,
    I64ElementsAttr:$limit_indices,
    I64ElementsAttr:$strides
  );

  let results = (outs HLO_Tensor);

  let builders = [OpBuilder<
    "Builder *builder, OperationState &result, Value operand, "
    "DenseIntElementsAttr start_indices, DenseIntElementsAttr limit_indices, "
    "DenseIntElementsAttr strides"
  >];

  let extraClassDeclaration = [{
    // Infers output type for given operand and attributes. Result type is
    // unranked if any of the attributes is illegal.
    static Type InferOutputTypes(Builder *builder, Value operand,
                                 DenseIntElementsAttr start_indices,
                                 DenseIntElementsAttr limit_indices,
                                 DenseIntElementsAttr strides);
  }];
}


And the corresponding TensorFlow Slice definition is as follows:
op {
  name: "Slice"
  input_arg {
    name: "input"
    type_attr: "T"
  }
  input_arg {
    name: "begin"
    type_attr: "Index"
  }
  input_arg {
    name: "size"
    type_attr: "Index"
  }
  output_arg {
    name: "output"
    type_attr: "T"
  }
  attr {
    name: "T"
    type: "type"
  }
  attr {
    name: "Index"
    type: "type"
    allowed_values {
      list {
        type: DT_INT32
        type: DT_INT64
      }
    }
  }
}


From the above code snippets, it can be seen that for the Slice op the HLO dialect has limited representation capability compared with the TF operation semantics.

Since TF specifies begin/size as tensors, while the HLO dialect specifies them as concrete scalar values, dynamic shape representation capability is lost
to a certain extent.

I haven't gone through all the op definitions in the HLO dialect, but I suspect there may be other operations with the same issue.

Correct me if my understanding is wrong.
--

Jack John

Jan 7, 2020, 3:36:57 AM
to MLIR
Since TF specifies begin/size as tensors, while the HLO dialect specifies them as concrete scalar values, dynamic shape representation capability is lost
to a certain extent.

TF defines the begin and size arguments as int32/int64, while HLO_SliceOp in MLIR defines start_indices and limit_indices as I64ElementsAttr; an I64ElementsAttr in MLIR represents a vector or a tensor value.
I don't quite understand where the so-called "concrete scalar value" comes from; maybe a misunderstanding?


Jun Yang

Jan 8, 2020, 4:40:42 AM
to Jack John, MLIR
Hi Jack,

Yep, I think you are right and I just missed the details.

Thanks



--

Kai Zhu

Jan 9, 2020, 1:00:19 AM
to MLIR
Hi Jack,

In my understanding, I64ElementsAttr is a kind of "Attribute", which is "for specifying constant data on Ops". This means the slice sizes must be compile-time constants.
A "dynamic shape representation" means the slice sizes can be computed by other ops at runtime, which is quite common in a TensorFlow graph.
After checking the rationale for "Attributes" in MLIR, I still believe the HLO dialect cannot support dynamic shapes here.
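To illustrate (schematically; I am not claiming this is the exact assembly syntax), the slice amounts live in attributes and are therefore fixed when the IR is built, whereas a runtime-computed slice size would have to arrive as an SSA operand, for which this op has no place:

// start/limit/strides are attributes, i.e. compile-time constants baked into the IR.
%0 = "xla_hlo.slice"(%arg) {
       start_indices = dense<0> : tensor<1xi64>,
       limit_indices = dense<4> : tensor<1xi64>,
       strides = dense<1> : tensor<1xi64>
     } : (tensor<8xf32>) -> tensor<4xf32>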




Lei Zhang

Jan 9, 2020, 9:07:08 PM
to Kai Zhu, MLIR
For slice specifically, there are both the xla_hlo.slice op and the xla_hlo.dynamic-slice op. The latter can support dynamic beginning indices. Generally, proper dynamic shape support requires changes to many components in the whole compilation pipeline.

For the HLO dialect, we make sure ops are no longer required to be statically shaped in verification wherever possible; this lays the foundation for dynamic shape support. You can actually find tests of ops taking in unranked tensors or ranked tensors with dynamic dimensions, e.g., for transpose: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/xla/tests/ops.mlir#L494. But yes, AFAIU op definitions and semantics are not radically adjusted specifically to cater to dynamic shapes; meaning that if the existing HLO op's definition takes a compile-time constant, the corresponding op definition in the dialect will respect that for now. This makes it easier to bring up the dialect and evaluate parity.

On the lowering side, we are also paying attention to dynamic shapes, and many patterns already support dynamically shaped TensorFlow ops instead of directly rejecting them. You can find all the tests with shapes containing * or ?, for example, unpack: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L2780. Dynamic shape support is certainly in the works, as are the HLO dialect and its lowering themselves. :)

Thanks,
Lei



Kai Zhu

Jan 10, 2020, 4:29:41 AM
to MLIR
Hi Lei,

Are you saying that you are already working on a fully dynamic shape solution?
We are also planning the same thing, since dynamic shapes have been a critical issue in our scenario for quite a long time.
We agree that this affects almost every stage of the whole flow.

Are there any existing RFCs or docs that describe your overall solution?



Lei Zhang

Jan 10, 2020, 10:56:03 AM
to Kai Zhu, MLIR
Sorry for the confusion! To be clear, no, I myself am not working on a full dynamic shape solution. My point was that the HLO dialect and its lowering from the TensorFlow dialect are already taking dynamic shapes into consideration; it will take time to have the dialect and its lowering fully developed.


Xiaoyong Liu

Jan 10, 2020, 4:16:03 PM
to Lei Zhang, Kai Zhu, MLIR
Just want to clarify "lowering fully developed". When we are talking about "the HLO dialect supports dynamic shapes for certain ops", does "HLO dialect" here mean exporting the HLO dialect? If yes, you would have to enhance the current XLA. If not, what is your proposal for HLO dialect lowering?

In my understanding, "HLO dialect" means the definition, the conversion, and the XLA implementation.

-Xiaoyong

Stella Laurenzo

Jan 10, 2020, 4:30:01 PM
to Xiaoyong Liu, Jacques Pienaar, Lei Zhang, Kai Zhu, MLIR
I am not aware of any existing public RFC for this, although there has been a fair amount of internal planning and laying of foundations for some time. I know this isn't helpful to the community, and we should probably try to rebase any assets and plans we have about this to be publicly available (this is one of those things where it was easy to fall into unhelpful communication patterns, given that the topic is neither wholly MLIR related nor wholly TensorFlow related).

Purely at the level of making it possible to express HLO-based computations with dynamic shapes, I expect that there are a handful of ops which need to be added. Following the pattern of Slice vs DynamicSlice, I would prefer to see "dynamic" versions of these added as needed (versus changing the semantics of existing static ops).

My team (IREE) will be pushing on this over the next few months as well, and I am definitely +1 on finding a way to flesh the design out publicly. It would be great if we could start with an RFC that outlines the state of where things are at now and what we can see that needs to change. I suspect that +Jacques Pienaar needs to be fairly involved in that, and I know that he has been very busy with the migration of MLIR->LLVM.

Xiaoyong Liu

Jan 10, 2020, 6:11:39 PM
to Stella Laurenzo, Jacques Pienaar, Lei Zhang, Kai Zhu, MLIR
I'd like to work with you and Jacques on this. The dynamic shape topic has been around for a while, first in XLA and now in MLIR, and the unclear roadmap and opaque definitions make the discussion difficult to bring to closure.
We'd better use a public RFC to clarify items such as:
1. the design principles
2. the scope: which ops and types are included, which are not, and why
3. the implementation method: in exported HLO, or lowering through MLIR dialects
4. the performance and functionality trade-offs for CPU, GPU, and maybe TPU



Stella Laurenzo

Jan 10, 2020, 6:15:42 PM
to Xiaoyong Liu, Jacques Pienaar, Lei Zhang, Kai Zhu, MLIR
On Fri, Jan 10, 2020 at 3:11 PM Xiaoyong Liu <xyli...@gmail.com> wrote:
We'd better use a public RFC to clarify the design principles, the scope (which ops and types are in and out, and why), the implementation method, and the performance/functionality trade-offs for CPU, GPU, and maybe TPU.


+1 - Let's see what Jacques has to say...

Geoffrey Martin-Noble

Jan 11, 2020, 1:14:53 AM
to Stella Laurenzo, Xiaoyong Liu, Jacques Pienaar, Lei Zhang, Kai Zhu, MLIR, Ace
I just wanted to close the loop, because there was a related discussion in another thread recently (+Ace from that thread). A lot of the existing MLIR HLO dialect op definitions were first written with the assumption that everything has static shapes, because that was the case for the HLO proto representation. So you may very well find existing evidence of things that assume static shapes. The type verifiers were relaxed, but there may still be some pieces left over from before that.

In some places that is a necessary part of the op definition, like the particular slice op mentioned, and I'm +1 on keeping that separate when it's a core part of the op definition and having a separate dynamic alternative. As Jun said before, we can probably achieve better optimization in the static-shape case where it's possible. But there may be some places (and I wrote some of them, sorry!) where a verification or something assumes that an operand has a static shape and doesn't really need to. I think it's fair to say that those are essentially "historical artifacts" at this point and should be fixed :-)

Xiaoyong Liu

Jan 12, 2020, 1:02:06 AM
to Geoffrey Martin-Noble, Stella Laurenzo, Jacques Pienaar, Lei Zhang, Kai Zhu, MLIR, Ace
Thank you Geoffrey. I'm not sure whether I'm understanding this correctly: are you saying that the HLO dialect is designed, and will be maintained, such that all shapes are static, or can be inferred to static values with a certain amount of re-inference in an application? Please kindly comment on this before we close the loop. A few things I think we agree on: 1. static shapes may give more optimization opportunities; 2. the related discussion you referred to only talks about reshape, with a few things left over.

The discussion behind the dynamic shape question is not all about performance; it's about the scope that HLO, or XLA, will cover. If dynamic shapes are not going to be covered well, I think that's fine as a design choice, since static shapes have their advantages.
But a fully dynamic-shape-friendly code generation path is really important. We may consider other ways to implement it, such as using MLIR to build a new dynamic-shape-friendly representation, perhaps not perfect for performance, and lowering directly to LLVM, maybe through a kind of late-HLO, Linalg, and Affine, then to LLVM. Let's make this clear; otherwise, I believe this is not the end of this kind of discussion.

The question here is how the HLO dialect will cover dynamic shape scenarios, and to what level. If fully dynamic shapes will be supported, what is the mechanism for the existing transformation passes, buffer allocation, runtime, etc.? I know this is very important for GPUs, TPUs, and accelerators of that kind.
With this answered, we can deliver an RFC to support dynamic shape compilation, either by enhancing the existing HLO dialect or by bringing in a new one to work alongside the HLO dialect.

Mehdi AMINI

Jan 12, 2020, 2:18:12 PM
to Xiaoyong Liu, Lei Zhang, Kai Zhu, MLIR
On Fri, Jan 10, 2020 at 1:16 PM Xiaoyong Liu <xyli...@gmail.com> wrote:
Just want to clarify "lowering fully developed". When we are talking about "the HLO dialect supports dynamic shapes for certain ops", does "HLO dialect" here mean exporting the HLO dialect? If yes, you would have to enhance the current XLA. If not, what is your proposal for HLO dialect lowering?

Lowering refers to the conversion from TF dialect operations to HLO dialect operations.
 

In my understanding, "HLO dialect" means the definition, the conversion, and the XLA implementation.

XLA is not part of the HLO dialect definition. "HLO dialect" refers purely to the MLIR side, independently of XLA; we don't plan to change XLA itself.
You can imagine a codegen path that is fully independent of XLA, using only MLIR components. The HLO dialect is a stepping stone that allows us to rely on the proven XLA techniques and experience to build an independent MLIR codegen path.

There is no detailed plan at the moment because our first milestones have been focused on reusing the XLA codegen path and reaching parity with the existing bridge. The first components we're replacing are 1) the GraphTransformation passes that implement the TensorFlow graph transformations to extract a cluster of computation to be compiled with XLA, and 2) the set of kernels that emit HLO for each of the TensorFlow ops.
As such, we haven't prioritized an end-to-end path that supports dynamic shapes at the moment, as none of the existing use cases using XLA requires it and we're limited by XLA anyway to reach our current milestones. However, some experiments have been conducted, using LHLO -> Linalg conversion for now, and many folks are actively playing with alternatives in this domain (mostly targeting CPUs and GPUs right now).

-- 
Mehdi

 

Stella Laurenzo

Jan 13, 2020, 1:49:21 PM
to Mehdi AMINI, Xiaoyong Liu, Lei Zhang, Kai Zhu, MLIR
On Sun, Jan 12, 2020 at 11:18 AM Mehdi AMINI <joke...@gmail.com> wrote:


However, some experiments have been conducted, using LHLO -> Linalg conversion for now, and many folks are actively playing with alternatives in this domain (mostly targeting CPUs and GPUs right now).

As you say in the last statement, I think it is important to be explicit that there are multiple "we's" here, even within Google. For our/IREE CPU and GPU work, we will want to take some measured steps soon to relax the static-shape assumptions in the xla_hlo dialect. This will be in combination with working out both the frontend issues (how to plumb dynamic dimensions from the call sites) and the backend issues (how to lower through Linalg to do codegen).

From my side, we're just starting to talk about next steps. I was thinking about starting with a basic CNN and sequence model (say ResNet and a simplified Transformer or LSTM), making the batch and/or sequence dimensions dynamic, and starting to thread that through at each level. I don't have a holistic list of the issues I anticipate encountering along the way, but I do have suspicions. Primarily, I would like to see the frontend/e2e case set up well so that we can see the rest. I expect that even taking this measured step is going to be quite a large amount of work, requiring us to take positions on runtime, allocation, etc., which may not generalize to all applications.
 

Mehdi AMINI

Jan 13, 2020, 2:43:54 PM
to Stella Laurenzo, Kai Zhu, Lei Zhang, MLIR, Xiaoyong Liu
On Mon, Jan 13, 2020 at 10:49 AM Stella Laurenzo <laur...@google.com> wrote:


As you say in the last statement, I think it is important to be explicit that there are multiple "we's" here, even within Google.

Right, since we are on the TensorFlow mailing list here, I am overfitting to the work happening under the TensorFlow umbrella :)
The work happening in satellite projects is relevant though; we should probably improve our roadmap sharing to stay in closer sync!