Skip to content

Commit 919dfc3

Browse files
andylytensorflower-gardener
authored andcommitted
Fix segfault caused by passing in a block argument into DefinedByConv2D constraint check.
PiperOrigin-RevId: 278115186 Change-Id: I8b40a2804cf7e84aacfd993855e751c75462276b
1 parent b99ee9b commit 919dfc3

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// RUN: tf-opt -tf-optimize %s | FileCheck %s --dump-input=fail
2+
3+
// Check passing an argument into DefinedByConv2D constraint does not crash.
4+
5+
// CHECK-LABEL: func @main
6+
func @main(%arg0: tensor<1xf32>) -> tensor<1xf32>
7+
attributes {tf.entry_function = {inputs = "input", outputs = "output_node"}} {
8+
%0 = constant dense<2.000000e+00> : tensor<f32>
9+
%1 = constant dense<1.000000e+00> : tensor<f32>
10+
%2 = "tf.AddV2"(%arg0, %1) {T = "tfdtype$DT_FLOAT", device = "", name = "StatefulPartitionedCall/add"} : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
11+
%3 = "tf.Mul"(%2, %0) {T = "tfdtype$DT_FLOAT", device = "", name = "output_node"} : (tensor<1xf32>, tensor<f32>) -> tensor<1xf32>
12+
return %3 : tensor<1xf32>
13+
}

tensorflow/compiler/mlir/tensorflow/transforms/optimize.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def BroadcastableElements :
2121
Constraint<CPred<"TFL::IsBroadcastableElementsAttrs($0, $1)">>;
2222
def F32ElementsAttr : ElementsAttrBase<
2323
CPred<"$_self.cast<ElementsAttr>().getType().getElementType().isF32()">, "float constant tensor">;
24-
def DefinedByConv2D : Constraint<CPred<"llvm::isa<mlir::TF::Conv2DOp>($0->getDefiningOp())">>;
24+
def DefinedByConv2D : Constraint<CPred<"llvm::isa_and_nonnull<mlir::TF::Conv2DOp>($0->getDefiningOp())">>;
2525

2626
// If we see a Conv2D op followed by Mul, then multiply the filter
2727
// with the value in Mul.

0 commit comments

Comments
 (0)