"tf.Reshape" does not support the dynamic shape reshaping in TF dialect


mofheka

Mar 15, 2021, 8:13:08 AM
to MLIR
For example:

%222 = "tf.Const"() {device = "", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%223 = "tf.Reshape"(%arg0, %222) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>

The "tf.Reshape" in above dialect would be generated a %223 with shape<1xi32>, but not <?xi32>。
In the raw TF graph def, the node is described as below. the op: "Const" has "int_val: -1" attribution, and the op: "Reshape" receive op: "Const" as the input of attribution key: "Tshape" for dynamic shape tensor:

node {
  name: "model/dense_features_43/177_bucketized_embedding/Reshape/shape"
  op: "Const"
  attr {
    key: "dtype"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "value"
    value {
      tensor {
        dtype: DT_INT32
        tensor_shape {
          dim {
            size: 1
          }
        }
        int_val: -1
      }
    }
  }
}

node {
  name: "model/dense_features_43/177_bucketized_embedding/Reshape"
  op: "Reshape"
  input: "model/dense_features_43/177_bucketized_embedding/Tile"
  input: "model/dense_features_43/177_bucketized_embedding/Reshape/shape"
  attr {
    key: "T"
    value {
      type: DT_INT32
    }
  }
  attr {
    key: "Tshape"
    value {
      type: DT_INT32
    }
  }
}

Jacques Pienaar

Mar 15, 2021, 9:44:33 AM
to mofheka, MLIR
Hey,

Is the shape of arg0 known here? E.g., it is shown as an arg of the function, but is the actual value inferable? (Single call site where the shape can be refined.) The reshape determines the value for the -1 index such that the number of elements is unchanged. It does not mean dynamic size, but "fill in this value" (so multiple -1's are not valid in the shape input), so depending on the input this may be the correct final result. You can run the shape inference individually using tf-opt on the module with the --debug option (enabled only in debug builds or when -UNDEBUG is used), and that will report the steps and process by which the shape was computed.
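
For instance, a minimal sketch with made-up shapes: given a statically known input, the -1 is filled in so that the element count is preserved.

%shape = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
// Reshaping a 2x3 tensor with shape [-1] fills in the -1 as 6:
%r = "tf.Reshape"(%x, %shape) : (tensor<2x3xi32>, tensor<1xi32>) -> tensor<6xi32>

To watch the inference itself, something like "tf-opt -tf-shape-inference --debug module.mlir" should print the steps (flag spelling from memory).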

Let me know if you verify it is producing the fixed shape even when the arg's shape cannot be known.

-- Jacques


mofheka

Mar 15, 2021, 10:13:30 AM
to MLIR, jpie...@google.com, MLIR, mofheka
First, the shape of %argX is known here; this is my small handmade test.

“%223 = "tf.Reshape"(%arg0, %222) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>” is already the result of shape inference, and the shape of the "tf.Reshape" output is correctly shown as tensor<?xi32>. The problem happens in the next stage, when the pass processes “%239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>”. tf-opt throws these errors:

tensor<0xi64>./test.tf_dialect.mlir:13:10: error: 'shape.cstr_broadcastable' op operand #1 must be shape or extent tensor, but got 'tensor<1xi32>'
  %239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
         ^
./test.tf_dialect.mlir:13:10: note: see current operation: %8 = "shape.cstr_broadcastable"(%7, %1) : (tensor<?xindex>, tensor<1xi32>) -> !shape.witness

"tf.Mul" got wrong tensor from %223. 

By the way, when I change %223 to some other tensor<?xi32> input, or change the ? in the tensor to a specific number, no errors occur.

Jacques Pienaar

Mar 15, 2021, 10:38:22 AM
to mofheka, MLIR
The error occurs when legalizing to MHLO, it would seem (you don't report which pass/passes are being run here). This has nothing to do with the TF dialect. And yes, it only affects dynamic cases, as that is the only case where verification needs to be done explicitly before we go to MHLO, to decouple checking from the arithmetic execution. The issue is probably due to a canonicalization pattern that is missing a tensor cast: i32 is not a valid type for constraints; the index element type is required there.
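
For reference, a sketch of the legal form (mirroring what the lowering itself emits): the constraint ops expect extent tensors, i.e. index element type.

%s0 = shape.shape_of %lhs : tensor<?xi32> -> tensor<?xindex>
%s1 = shape.shape_of %rhs : tensor<?xi32> -> tensor<?xindex>
%w = shape.cstr_broadcastable %s0, %s1 : tensor<?xindex>, tensor<?xindex>

Feeding a tensor<1xi32> shape value in directly, as in your error, fails the op's verifier.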

If you could file a GitHub issue with a reproducer, that would be best.

-- Jacques

mofheka

Mar 15, 2021, 10:38:00 PM
to MLIR, jpie...@google.com, MLIR, mofheka
I filed a GitHub issue many days ago, but nobody has replied...
https://github.com/tensorflow/tensorflow/issues/47516#event-4434491814

mofheka

Mar 15, 2021, 10:53:58 PM
to MLIR, mofheka, jpie...@google.com, MLIR
But if I run another dynamic shape test, for example:

func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %238 = "tf.Mul"(%arg0, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  %239 = "tf.Mul"(%238, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %239: tensor<?xi32>
}

OR:

func @main(%arg0: tensor<800xi32>, %arg1: tensor<800xi32>) -> tensor<800xi32> {
  %222 = "tf.Const"() {device = "", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
  %223 = "tf.Reshape"(%arg0, %222) {device = ""} : (tensor<800xi32>, tensor<1xi32>) -> tensor<800xi32>
  %239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<800xi32>, tensor<800xi32>) -> tensor<800xi32>
return %239: tensor<800xi32>
}

These examples above also go through fine (the first one is a dynamic-shape func too). So what happens to the result of "tf.Reshape" with a dynamic shape? I mean, it goes wrong when converting the example below:

func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %222 = "tf.Const"() {device = "", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
  %223 = "tf.Reshape"(%arg0, %222) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
  %239 = "tf.Div"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %239: tensor<?xi32>
}

Which part of the source code should I check to fix this bug?

mofheka

Mar 15, 2021, 11:00:25 PM
to MLIR, jpie...@google.com, MLIR, mofheka
Sorry, I missed reporting which pass I used: it's "tf-opt -tf-to-hlo-pipeline" (TF dialect to CHLO dialect, and then CHLO dialect to MHLO dialect). I thought the problem happened in "DenseIntElementsAttr getBroadcastDimensionsAttr", which is called by "class DirectBinaryPat<Op FromOp, Op ToOp>" in "legalize_tf_patterns.td", but it seems nothing is wrong there.
On Monday, March 15, 2021 at 10:38:22 PM UTC+8, jpie...@google.com wrote:

Jacques Pienaar

Mar 16, 2021, 12:26:39 AM
to mofheka, MLIR
So the best way to start is to build with debug support (-UNDEBUG keeps assertions and LLVM_DEBUG output enabled in an otherwise optimized build):

bazel build -c opt tensorflow/compiler/mlir:tf-opt --copt=-UNDEBUG

then run with debugging and IR printing on:

tf-opt --tf-to-hlo-pipeline --debug --print-ir-after-all file.mlir

and then it seems like MHLO's dynamic_broadcast_in_dim or the shape dialect's BroadcastOp folder is acting up. I'll check again in the morning, but I'm assuming it is a canonicalization pattern that is too loose.
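
If you want to bisect further, you can also run just the TF-to-MHLO legalization instead of the whole pipeline, e.g. (flag spelling from memory, it may vary by version):

tf-opt --xla-legalize-tf --debug file.mlir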

-- Jacques

mofheka

Mar 16, 2021, 2:23:11 AM
to MLIR, jpie...@google.com, MLIR, mofheka
Args: ./tf-opt -tf-to-hlo-pipeline --debug --print-ir-after-all ./test.tf_dialect.mlir -o ./test.mhlo.mlir
Load new dialect in Context
Load new dialect in Context acc
Load new dialect in Context affine
Load new dialect in Context arm_neon
Load new dialect in Context arm_sve
Load new dialect in Context async
Load new dialect in Context avx512
Load new dialect in Context chlo
Load new dialect in Context complex
Load new dialect in Context gpu
Load new dialect in Context linalg
Load new dialect in Context std
Load new dialect in Context tensor
Load new dialect in Context llvm
Load new dialect in Context llvm_arm_sve
Load new dialect in Context lmhlo
Load new dialect in Context lmhlo_gpu
Load new dialect in Context math
Load new dialect in Context mhlo
Load new dialect in Context nvvm
Load new dialect in Context omp
Load new dialect in Context pdl
Load new dialect in Context pdl_interp
Load new dialect in Context quant
Load new dialect in Context rocdl
Load new dialect in Context scf
Load new dialect in Context sdbm
Load new dialect in Context shape
Load new dialect in Context spv
Load new dialect in Context tf
Load new dialect in Context tf_device
Load new dialect in Context tf_executor
Load new dialect in Context tf_framework
Load new dialect in Context tf_saved_model
Load new dialect in Context tfl
Load new dialect in Context tosa
Load new dialect in Context vector
// *** IR Dump After FunctionalControlFlowToRegionsPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {device = "", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


2021-03-16 14:22:25.753208: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
// *** IR Dump After Canonicalizer ***
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}

// *** IR Dump After Inliner ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::TF::{anonymous}::DropWhileShapeInvariantPass ***
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}

// *** IR Dump After Canonicalizer ***
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}

// *** IR Dump After SCCP ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::TF::{anonymous}::GuaranteeAllFuncsOneUse ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


Skipping inference; Internal: Missing 'tf.versions' attribute on the module, abort.
// *** IR Dump After TensorFlowShapeInferencePass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After SCCP ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::{anonymous}::TensorListOpsDecompositionPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::{anonymous}::StackOpsDecompositionPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::{anonymous}::TensorArrayOpsDecompositionPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::TFDevice::{anonymous}::DecomposeResourceOps ***
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}

// *** IR Dump After mlir::TF::{anonymous}::PromoteResourcesToArgsPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After SymbolDCE ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


Skipping inference; Internal: Missing 'tf.versions' attribute on the module, abort.
// *** IR Dump After TensorFlowShapeInferencePass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After RegionControlFlowToFunctionalPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::mhlo::{anonymous}::LegalizeTFControlFlow ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}



//===-------------------------------------------===//
Legalizing operation : 'module'(0x557512fd83d0) {
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'func'(0x557512fd79e0) {
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Const'(0x5575130b28f0) {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Reshape'(0x55751305f860) {
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Mul'(0x557513060e30) {
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'std.return'(0x557513062460) {
"std.return"(%2) : (tensor<?xi32>) -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'module_terminator'(0x557513050e10) {
"module_terminator"() : () -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
// *** IR Dump After LegalizeTfTypesPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}



//===-------------------------------------------===//
Legalizing operation : 'func'(0x557512fd79e0) {
* Fold {
} -> FAILURE : unable to fold

* Pattern : 'func -> ()' {
} -> FAILURE : pattern failed to match
} -> FAILURE : no matched legalization pattern
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Const'(0x5575130b28f0) {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>

* Fold {
} -> FAILURE : unable to fold

* Pattern : 'tf.Const -> (mhlo.constant, tensor.cast)' {
** Insert : 'mhlo.constant'(0x5575130d0ac0)
** Insert : 'tensor.cast'(0x5575130b0b30)
** Replace : 'tf.Const'(0x5575130b28f0)

//===-------------------------------------------===//
Legalizing operation : 'mhlo.constant'(0x5575130d0ac0) {
%0 = "mhlo.constant"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tensor.cast'(0x5575130b0b30) {
%1 = "tensor.cast"(%0) : (tensor<1xi32>) -> tensor<1xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Reshape'(0x55751305f860) {
%3 = "tf.Reshape"(%arg0, %2) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>

* Fold {
} -> FAILURE : unable to fold

* Pattern : 'tf.Reshape -> ()' {
** Insert : 'mhlo.dynamic_reshape'(0x5575130d08c0)
** Replace : 'tf.Reshape'(0x55751305f860)

//===-------------------------------------------===//
Legalizing operation : 'mhlo.dynamic_reshape'(0x5575130d08c0) {
%3 = "mhlo.dynamic_reshape"(%arg0, %2) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tf.Mul'(0x557513060e30) {
%5 = "tf.Mul"(%4, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

* Fold {
} -> FAILURE : unable to fold

* Pattern : 'tf.Mul -> (chlo.broadcast_multiply)' {
** Insert : 'chlo.broadcast_multiply'(0x5575130d09a0)
** Replace : 'tf.Mul'(0x557513060e30)

//===-------------------------------------------===//
Legalizing operation : 'chlo.broadcast_multiply'(0x5575130d09a0) {
%5 = "chlo.broadcast_multiply"(%4, %arg1) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

* Fold {
} -> FAILURE : unable to fold

* Pattern : 'chlo.broadcast_multiply -> ()' {
} -> FAILURE : pattern failed to match

* Pattern : 'chlo.broadcast_multiply -> ()' {
** Insert : 'shape.shape_of'(0x5575130d67c0)
** Insert : 'shape.shape_of'(0x5575130d68a0)
** Insert : 'shape.cstr_broadcastable'(0x5575130b24e0)
** Insert : 'shape.assuming'(0x55751306a960)
** Insert : 'shape.shape_of'(0x5575130b25e0)
** Insert : 'shape.shape_of'(0x5575130b2700)
** Insert : 'shape.broadcast'(0x5575130b2790)
** Insert : 'tensor.cast'(0x5575130b2840)
** Insert : 'mhlo.dynamic_broadcast_in_dim'(0x5575130d1e50)
** Insert : 'mhlo.dynamic_broadcast_in_dim'(0x5575130d1f00)
** Insert : 'mhlo.multiply'(0x5575130d58c0)
** Insert : 'shape.assuming_yield'(0x557513092440)
** Replace : 'chlo.broadcast_multiply'(0x5575130d09a0)

//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0x5575130d67c0) {
%5 = "shape.shape_of"(%3) : (tensor<?xi32>) -> tensor<?xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0x5575130d68a0) {
%6 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.cstr_broadcastable'(0x5575130b24e0) {
%7 = "shape.cstr_broadcastable"(%5, %6) : (tensor<?xindex>, tensor<?xindex>) -> !shape.witness

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.assuming'(0x55751306a960) {
} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0x5575130b25e0) {
%11 = "shape.shape_of"(%3) : (tensor<?xi32>) -> tensor<?xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.shape_of'(0x5575130b2700) {
%12 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.broadcast'(0x5575130b2790) {
%13 = "shape.broadcast"(%11, %12) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'tensor.cast'(0x5575130b2840) {
%14 = "tensor.cast"(%13) : (tensor<?xindex>) -> tensor<1xindex>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'mhlo.dynamic_broadcast_in_dim'(0x5575130d1e50) {
%15 = "mhlo.dynamic_broadcast_in_dim"(%3, %14) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'mhlo.dynamic_broadcast_in_dim'(0x5575130d1f00) {
%16 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %14) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'mhlo.multiply'(0x5575130d58c0) {
%17 = "mhlo.multiply"(%15, %16) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'shape.assuming_yield'(0x557513092440) {
"shape.assuming_yield"(%17) : (tensor<?xi32>) -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//
} -> SUCCESS : pattern applied successfully
} -> SUCCESS
//===-------------------------------------------===//

//===-------------------------------------------===//
Legalizing operation : 'std.return'(0x557513062460) {
"std.return"(%10) : (tensor<?xi32>) -> ()

} -> SUCCESS : operation marked legal by the target
//===-------------------------------------------===//
// *** IR Dump After mlir::mhlo::{anonymous}::LegalizeTF ***
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = mhlo.constant dense<-1> : tensor<1xi32>
%1 = tensor.cast %0 : tensor<1xi32> to tensor<1xi32>
%2 = "mhlo.dynamic_reshape"(%arg0, %1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%3 = shape.shape_of %2 : tensor<?xi32> -> tensor<?xindex>
%4 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%5 = shape.cstr_broadcastable %3, %4 : tensor<?xindex>, tensor<?xindex>
%6 = shape.assuming %5 -> (tensor<?xi32>) {
%7 = shape.shape_of %2 : tensor<?xi32> -> tensor<?xindex>
%8 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%9 = shape.broadcast %7, %8 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
%10 = tensor.cast %9 : tensor<?xindex> to tensor<1xindex>
%11 = "mhlo.dynamic_broadcast_in_dim"(%2, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%13 = mhlo.multiply %11, %12 : tensor<?xi32>
shape.assuming_yield %13 : tensor<?xi32>
}
return %6 : tensor<?xi32>
}

// *** IR Dump After mlir::mhlo::{anonymous}::LegalizeTFCommunication ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = mhlo.constant dense<-1> : tensor<1xi32>
%1 = tensor.cast %0 : tensor<1xi32> to tensor<1xi32>
%2 = "mhlo.dynamic_reshape"(%arg0, %1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%3 = shape.shape_of %2 : tensor<?xi32> -> tensor<?xindex>
%4 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%5 = shape.cstr_broadcastable %3, %4 : tensor<?xindex>, tensor<?xindex>
%6 = shape.assuming %5 -> (tensor<?xi32>) {
%7 = shape.shape_of %2 : tensor<?xi32> -> tensor<?xindex>
%8 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%9 = shape.broadcast %7, %8 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
%10 = tensor.cast %9 : tensor<?xindex> to tensor<1xindex>
%11 = "mhlo.dynamic_broadcast_in_dim"(%2, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%12 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %10) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%13 = mhlo.multiply %11, %12 : tensor<?xi32>
shape.assuming_yield %13 : tensor<?xi32>
}
return %6 : tensor<?xi32>
}
}


tf-opt: external/llvm-project/mlir/include/mlir/IR/BuiltinAttributes.h:581: llvm::iterator_range<mlir::DenseElementsAttr::ElementIterator<T> > mlir::DenseElementsAttr::getValues() const [with T = long int; <template-parameter-1-2> = void]: Assertion `isValidIntOrFloat(sizeof(T), std::numeric_limits<T>::is_integer, std::numeric_limits<T>::is_signed)' failed.
TensorFlow crashed, please file a bug on https://github.com/tensorflow/tensorflow/issues with the trace below.
Stack dump:
0. Program arguments: ./tf-opt -tf-to-hlo-pipeline --debug --print-ir-after-all ./test.tf_dialect.mlir -o ./test.mhlo.mlir
1. Program arguments: ./tf-opt -tf-to-hlo-pipeline --debug --print-ir-after-all ./test.tf_dialect.mlir -o ./test.mhlo.mlir
2. Stack dump without symbol names (ensure you have llvm-symbolizer in your PATH or set the environment var `LLVM_SYMBOLIZER_PATH` to point to it):
./tf-opt(+0xaedf1e3)[0x55750b0ff1e3]
./tf-opt(+0xaedd32d)[0x55750b0fd32d]
./tf-opt(+0xaedd4bc)[0x55750b0fd4bc]
/lib/x86_64-linux-gnu/libpthread.so.0(+0x12980)[0x7f9b597cf980]
/lib/x86_64-linux-gnu/libc.so.6(gsignal+0xc7)[0x7f9b591f2fb7]
/lib/x86_64-linux-gnu/libc.so.6(abort+0x141)[0x7f9b591f4921]
/lib/x86_64-linux-gnu/libc.so.6(+0x3048a)[0x7f9b591e448a]
/lib/x86_64-linux-gnu/libc.so.6(+0x30502)[0x7f9b591e4502]
./tf-opt(+0x8287d03)[0x5575084a7d03]
./tf-opt(+0x82a8c02)[0x5575084c8c02]
./tf-opt(+0xa9c3e3a)[0x55750abe3e3a]
./tf-opt(+0xa9c7ac5)[0x55750abe7ac5]
./tf-opt(+0xadec1f9)[0x55750b00c1f9]
./tf-opt(+0xaa859ec)[0x55750aca59ec]
./tf-opt(+0xaa865fd)[0x55750aca65fd]
./tf-opt(+0xaa80a3a)[0x55750aca0a3a]
./tf-opt(+0xad46fa3)[0x55750af66fa3]
./tf-opt(+0xad47662)[0x55750af67662]
./tf-opt(+0xad47fc8)[0x55750af67fc8]
./tf-opt(+0xad464b1)[0x55750af664b1]
./tf-opt(+0xad4739f)[0x55750af6739f]
./tf-opt(+0xad47662)[0x55750af67662]
./tf-opt(+0xad49c57)[0x55750af69c57]
./tf-opt(+0x850cc53)[0x55750872cc53]
./tf-opt(+0x850cfab)[0x55750872cfab]
./tf-opt(+0x850d1b3)[0x55750872d1b3]
./tf-opt(+0x850ddfd)[0x55750872ddfd]
./tf-opt(+0xbf1504)[0x557500e11504]
/lib/x86_64-linux-gnu/libc.so.6(__libc_start_main+0xe7)[0x7f9b591d5bf7]
./tf-opt(+0xd581aa)[0x557500f781aa]

mofheka

Mar 16, 2021, 4:22:53 AM
to MLIR, mofheka, jpie...@google.com, MLIR
for "--debug" would cause a core dump, so there is a printed output without "--debug" parameter:

// *** IR Dump After FunctionalControlFlowToRegionsPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {device = "", value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


2021-03-16 16:18:48.229533: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
// *** IR Dump After TensorFlowShapeInferencePass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After RegionControlFlowToFunctionalPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After mlir::mhlo::{anonymous}::LegalizeTFControlFlow ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After LegalizeTfTypesPass ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Reshape"(%arg0, %0) {device = ""} : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%2 = "tf.Mul"(%1, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %2 : tensor<?xi32>
}
}


// *** IR Dump After Canonicalizer Failed ***
"func"() ( {
^bb0(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>): // no predecessors
%0 = "mhlo.constant"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>
%2 = "shape.cstr_broadcastable"(%1, %0) : (tensor<?xindex>, tensor<1xi32>) -> !shape.witness
%3 = "shape.assuming"(%2) ( {
%4 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>
%5 = "shape.broadcast"(%4, %0) : (tensor<?xindex>, tensor<1xi32>) -> tensor<?xindex>
%6 = "tensor.cast"(%5) : (tensor<?xindex>) -> tensor<1xindex>
%7 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %6) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%8 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %6) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%9 = "mhlo.multiply"(%7, %8) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
"shape.assuming_yield"(%9) : (tensor<?xi32>) -> ()
}) : (!shape.witness) -> tensor<?xi32>
"std.return"(%3) : (tensor<?xi32>) -> ()
}) {sym_name = "main", type = (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>} : () -> ()

./test.tf_dialect.mlir:5:10: error: 'shape.cstr_broadcastable' op operand #1 must be shape or extent tensor, but got 'tensor<1xi32>'
%239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
^
./test.tf_dialect.mlir:5:10: note: see current operation: %2 = "shape.cstr_broadcastable"(%1, %0) : (tensor<?xindex>, tensor<1xi32>) -> !shape.witness

Jacques Pienaar

Mar 16, 2021, 7:48:12 AM
to mofheka, MLIR
And there you have the cause as suspected in canonicalization. The shape.shape_of of the reshape result gets folded to the reshape's tensor<1xi32> shape operand, which the shape constraint op then rejects:

%0 = mhlo.constant dense<-1> : tensor<1xi32>
%1 = tensor.cast %0 : tensor<1xi32> to tensor<1xi32>
%2 = "mhlo.dynamic_reshape"(%arg0, %1) : (tensor<?xi32>, tensor<1xi32>) -> tensor<?xi32>
%3 = shape.shape_of %2 : tensor<?xi32> -> tensor<?xindex>


Jacques Pienaar

Mar 17, 2021, 8:15:51 AM
to mofheka, MLIR
I submitted a change yesterday to avoid this canonicalization when it changes types; that should have resolved this too. Could you sync to head and try?
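
Roughly, as a sketch rather than the exact patch: the canonicalization folds shape.shape_of(mhlo.dynamic_reshape(%x, %s)) to %s, and it should only fire when %s already has the shape_of result type.

// OK to fold %sh away to %s: the types match.
%r = "mhlo.dynamic_reshape"(%x, %s) : (tensor<?xi32>, tensor<?xindex>) -> tensor<?xi32>
%sh = shape.shape_of %r : tensor<?xi32> -> tensor<?xindex>
// It must bail out (or insert a cast) when %s is e.g. tensor<1xi32>,
// since folding would then change the type seen by the shape constraints.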

-- Jacques 

On Wed, Mar 17, 2021, 2:35 AM mofheka <mofh...@gmail.com> wrote:
I deleted the ShapeOfDynamicReshape pattern in hlo_ops.cc and hlo_patterns.td directly, and it works... What's a better way to fix this problem?

On Wednesday, March 17, 2021 at 3:21:40 PM UTC+8, mofheka wrote:
It seems the problem occurs not only with "Reshape" but also elsewhere. When I transformed a func like this:

func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
  %223 = "tf.Reshape"(%arg0, %arg0) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
  %239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
return %239: tensor<?xi32>
}

The IR dumps are printed below. In the last dump, %arg0 no longer goes through a "shape.shape_of" op, while %0 (the result of the reshape?) seemed to get the right shape inference. Presumably it is the same fold again: the shape.shape_of of the dynamic_reshape was replaced by its shape operand, which here is %arg0, a tensor<?xi32> rather than an extent tensor.

// *** IR Dump After mlir::mhlo::{anonymous}::LegalizeTFCommunication ***
module {
func @main(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>) -> tensor<?xi32> {
%0 = "mhlo.dynamic_reshape"(%arg0, %arg0) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
%1 = shape.shape_of %0 : tensor<?xi32> -> tensor<?xindex>
%2 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%3 = shape.cstr_broadcastable %1, %2 : tensor<?xindex>, tensor<?xindex>
%4 = shape.assuming %3 -> (tensor<?xi32>) {
%5 = shape.shape_of %0 : tensor<?xi32> -> tensor<?xindex>
%6 = shape.shape_of %arg1 : tensor<?xi32> -> tensor<?xindex>
%7 = shape.broadcast %5, %6 : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
%8 = tensor.cast %7 : tensor<?xindex> to tensor<1xindex>
%9 = "mhlo.dynamic_broadcast_in_dim"(%0, %8) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%10 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %8) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%11 = mhlo.multiply %9, %10 : tensor<?xi32>
shape.assuming_yield %11 : tensor<?xi32>
}
return %4 : tensor<?xi32>
}
}


// *** IR Dump After Canonicalizer Failed ***
"func"() ( {
^bb0(%arg0: tensor<?xi32>, %arg1: tensor<?xi32>): // no predecessors
%0 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>
%1 = "shape.cstr_broadcastable"(%arg0, %0) : (tensor<?xi32>, tensor<?xindex>) -> !shape.witness
%2 = "shape.assuming"(%1) ( {
%3 = "shape.shape_of"(%arg1) : (tensor<?xi32>) -> tensor<?xindex>
%4 = "shape.broadcast"(%arg0, %3) : (tensor<?xi32>, tensor<?xindex>) -> tensor<?xindex>
%5 = "tensor.cast"(%4) : (tensor<?xindex>) -> tensor<1xindex>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %5) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%7 = "mhlo.dynamic_broadcast_in_dim"(%arg1, %5) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xi32>, tensor<1xindex>) -> tensor<?xi32>
%8 = "mhlo.multiply"(%6, %7) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
"shape.assuming_yield"(%8) : (tensor<?xi32>) -> ()
}) : (!shape.witness) -> tensor<?xi32>
"std.return"(%2) : (tensor<?xi32>) -> ()
}) {sym_name = "main", type = (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>} : () -> ()

./test.tf_dialect.mlir:5:10: error: 'shape.cstr_broadcastable' op operand #0 must be shape or extent tensor, but got 'tensor<?xi32>'
%239 = "tf.Mul"(%223, %arg1) {device = ""} : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>
^
./test.tf_dialect.mlir:5:10: note: see current operation: %1 = "shape.cstr_broadcastable"(%arg0, %0) : (tensor<?xi32>, tensor<?xindex>) -> !shape.witness

mofheka

Mar 18, 2021, 4:04:11 AM
to MLIR, jpie...@google.com, MLIR, mofheka

Thanks a lot! It works!