Skip to content

Commit 6c7b8ae

Browse files
tpopptensorflower-gardener
authored andcommitted
Fold away shape.shape_of(mhlo.dynamic_reshape(inp, shape))
This specific pattern can be replaced with the shape passed to dynamic_reshape. This is implemented as a canonicalization on mhlo.dynamic_reshape to fit in the infrastructure of canonicalization. PiperOrigin-RevId: 342009365 Change-Id: I252334c3e6225b984571cd0d1cddef06f1e55dd1
1 parent 407127e commit 6c7b8ae

File tree

3 files changed

+15
-1
lines changed

3 files changed

+15
-1
lines changed

tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1268,7 +1268,8 @@ class DynamicReshapeOpNotActuallyDynamic
12681268

12691269
void DynamicReshapeOp::getCanonicalizationPatterns(
12701270
OwningRewritePatternList& results, MLIRContext* context) {
1271-
results.insert<DynamicReshapeOpNotActuallyDynamic>(context);
1271+
results.insert<DynamicReshapeOpNotActuallyDynamic, ShapeOfDynamicReshape>(
1272+
context);
12721273
}
12731274

12741275
//===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_patterns.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,6 @@ def DynamicBroadcastToOwnShape_2 : Pat<
2828
(HLO_DynamicBroadcastInDimOp:$op $x, (Shape_ShapeOfOp $x), $attr),
2929
(replaceWithValue $x)>;
3030

31+
def ShapeOfDynamicReshape : Pat<
32+
(Shape_ShapeOfOp (HLO_DynamicReshapeOp $x, $shape)),
33+
(replaceWithValue $shape)>;

tensorflow/compiler/mlir/hlo/tests/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,16 @@ func @dynamic_reshape_not_actually_dynamic(%arg0: tensor<4xf32>, %shape: tensor<
575575
return %0 : tensor<4x1xf32>
576576
}
577577

578+
// CHECK-LABEL: func @shape_of_dynamic_reshape
579+
// CHECK-SAME: [[ARG0:%[a-zA-Z0-9]+]]
580+
// CHECK-SAME: [[ARG1:%[a-zA-Z0-9]+]]
581+
func @shape_of_dynamic_reshape(%arg0: tensor<*xf32>, %shape: tensor<2xindex>) -> tensor<2xindex> {
582+
// CHECK: return [[ARG1]]
583+
%0 = "mhlo.dynamic_reshape"(%arg0, %shape) : (tensor<*xf32>, tensor<2xindex>) -> tensor<?x?xf32>
584+
%1 = shape.shape_of %0 : tensor<?x?xf32> -> tensor<2xindex>
585+
return %1 : tensor<2xindex>
586+
}
587+
578588
// CHECK-LABEL: do_not_dce_while_with_outfeed
579589
func @do_not_dce_while_with_outfeed(%arg0: tensor<i64>) -> tensor<i64> {
580590
// CHECK: mhlo.while

0 commit comments

Comments
 (0)