Why does a "tf.Cast" appear after tensor.cast once the IR is processed by the Canonicalizer pass?

mofheka

unread,
Mar 18, 2021, 4:11:38 AM3/18/21
to MLIR
Before(IR Dump After TensorFlowShapeInferencePass):

%417 = "mhlo.slice"(%226) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%418 = "mhlo.reshape"(%417) : (tensor<1xi32>) -> tensor<i32>
%419 = "tf.Tile"(%96, %417) {device = ""} : (tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
%420 = shape.shape_of %419 : tensor<?xi32> -> tensor<?xindex>
%421 = tensor.cast %420 : tensor<?xindex> to tensor<1xindex>
%422 = "mhlo.dynamic_broadcast_in_dim"(%7, %421) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>
%423 = mhlo.multiply %419, %422 : tensor<?xi32>

After(IR Dump After TensorFlowShapeInferencePass Failed):

%417 = "mhlo.slice"(%226) {limit_indices = dense<1> : tensor<1xi64>, start_indices = dense<0> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} : (tensor<2xi32>) -> tensor<1xi32>
%418 = "mhlo.reshape"(%417) : (tensor<1xi32>) -> tensor<i32>
%419 = "tf.Tile"(%96, %417) {device = ""} : (tensor<1xi32>, tensor<1xi32>) -> tensor<?xi32>
%420 = "shape.shape_of"(%419) : (tensor<?xi32>) -> tensor<?xindex>
%421 = "tensor.cast"(%420) : (tensor<?xindex>) -> tensor<?xindex>
%422 = "tf.Cast"(%421) {Truncate = false} : (tensor<?xindex>) -> tensor<1xindex>
%423 = "mhlo.dynamic_broadcast_in_dim"(%7, %422) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>
%424 = "mhlo.multiply"(%419, %423) : (tensor<?xi32>, tensor<?xi32>) -> tensor<?xi32>

Unfortunately, I can't reproduce this behavior in a small standalone func, even using the same ops, so I had to cut these snippets out of the printed IR of a large model.
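For reference, what I expected after canonicalization was for the tensor.cast to be folded into the shape.shape_of result type rather than a new cast being introduced. A rough sketch of that expected outcome (my reconstruction, reusing the SSA names and types from the dump above):

```mlir
// shape.shape_of of a ranked operand can yield a static-rank shape
// type directly (rank 1 here), making the tensor.cast redundant:
%420 = shape.shape_of %419 : tensor<?xi32> -> tensor<1xindex>
%422 = "mhlo.dynamic_broadcast_in_dim"(%7, %420) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<i32>, tensor<1xindex>) -> tensor<?xi32>
```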

Jacques Pienaar

unread,
Mar 18, 2021, 9:19:44 AM3/18/21
to mofheka, MLIR
This will be canonicalized away; that tf.Cast is invalid, and a change to fix it is pending.

-- Jacques

mofheka

unread,
Mar 18, 2021, 10:01:47 PM3/18/21
to MLIR, jpie...@google.com, MLIR, mofheka
When it meets this "tf.Cast", which appears out of thin air, the pass just stops and throws an error.

./tem_graphdef/tem_new_graph_def.tf_dialect.mlir:241:12: error: 'tf.Cast' op operand #0 must be tensor of tf.dtype values, but got 'tensor<?xindex>'
%238 = "tf.Mul"(%226, %237) {device = ""} : (tensor<i32>, tensor<?xi32>) -> tensor<?xi32>
^
./tem_graphdef/tem_new_graph_def.tf_dialect.mlir:241:12: note: see current operation: %422 = "tf.Cast"(%421) {Truncate = false} : (tensor<?xindex>) -> tensor<1xindex>

Jacques Pienaar

unread,
Mar 19, 2021, 4:10:37 PM3/19/21
to mofheka, MLIR
What happened here is that the op producing the index tensor was marked as a pass-through op and fed into an op whose type cannot be refined. Since the pass-through op's type changed, a cast has to be inserted so that the input type of the operation the pass cannot refine stays unchanged. But index is not a supported dtype in TensorFlow, so the tf.Cast operator that is currently (patch pending) blindly inserted causes the error.
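To illustrate (a sketch, not the actual pending patch): because index is not a tf.dtype, the cast materialized for such types would need to be a builtin tensor.cast rather than a tf.Cast:

```mlir
// Valid: tensor.cast accepts index element types.
%good = tensor.cast %421 : tensor<?xindex> to tensor<1xindex>

// Invalid: tf.Cast requires operands with tf.dtype element types,
// so inserting it on a tensor<?xindex> value fails verification.
%bad = "tf.Cast"(%421) {Truncate = false} : (tensor<?xindex>) -> tensor<1xindex>
```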

-- Jacques