Skip to content

Commit bef3dd3

Browse files
andylytensorflower-gardener
authored andcommitted
Update verifier for tf.While/tf.WhileRegion to not check for operand result type compatibility if op is shape invariant.
Shape invariant tf.While/tf.WhileRegion can have different result shapes from operand shapes. PiperOrigin-RevId: 344929917 Change-Id: Ic160211075c0bb91548c71fec9f2d880ccd79723
1 parent 35bc546 commit bef3dd3

File tree

2 files changed

+68
-25
lines changed

2 files changed

+68
-25
lines changed

tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2534,21 +2534,20 @@ OpFoldResult VariableShapeOp::fold(ArrayRef<Attribute> operands) {
25342534

25352535
static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
25362536
TypeRange body_input,
2537-
TypeRange body_result) {
2538-
// Collect all the type lists for the op so that different pairs of type lists
2539-
// can be compared for the compatibility.
2540-
constexpr int kNumTypeLists = 5;
2541-
const std::array<TypeRangeWithDesc, kNumTypeLists> type_lists = {{
2542-
{op->getOperandTypes(), "input"},
2537+
TypeRange body_result,
2538+
bool shape_invariant) {
2539+
const TypeRangeWithDesc input_type = {op->getOperandTypes(), "input"};
2540+
const TypeRangeWithDesc result_type = {op->getResultTypes(), "result"};
2541+
constexpr int kNumRegionTypeLists = 3;
2542+
const std::array<TypeRangeWithDesc, kNumRegionTypeLists> region_types = {{
25432543
{body_result, "body result"},
2544-
{op->getResultTypes(), "result"},
25452544
{cond_input, "condition input"},
25462545
{body_input, "body input"},
25472546
}};
25482547

25492548
// A pair of type lists should be cast compatible with each other if one is
25502549
// converted to the another for a function call or assignment or there is a
2551-
// common source of inputs for both. Therefore, the While op requires the
2550+
// common source of inputs for both. Therefore, the While op requires the
25522551
// following pairs of type lists to be cast compatible for the tensor_cast
25532552
// operation:
25542553
//
@@ -2557,7 +2556,8 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
25572556
// * Operands and body inputs to call the body function for the first
25582557
// iteration if the cond functions returns True or equivalent result.
25592558
// * Operands and results to assign cond function arguments to op results if
2560-
// the cond function returns False or equivalent result.
2559+
// the cond function returns False or equivalent result. If the op is shape
2560+
// invariant, this does not hold as shapes can differ.
25612561
// * All three pairs using cond inputs, body inputs and results as operand is
25622562
// a common source for all three.
25632563
// * Body result and cond inputs to call the cond function for the subsequent
@@ -2566,17 +2566,28 @@ static LogicalResult VerifyWhileTypes(Operation *op, TypeRange cond_input,
25662566
//
25672567
// Note that the operands and body results need not be compatible as they are
25682568
// never converted from one to the another nor there is a common source
2569-
// tensors. Compatibility requirement is not transitive.
2570-
2571-
for (int i = 0; i < kNumTypeLists; ++i) {
2572-
// Skip the first pair as the While op operands and body function results
2573-
// does not need to be compatible with each other.
2574-
for (int j = std::max(2, i + 1); j < kNumTypeLists; ++j) {
2575-
auto &a = type_lists[i];
2576-
auto &b = type_lists[j];
2577-
if (failed(VerifyTypeRangesAreCompatible(op, a, b))) return failure();
2578-
}
2579-
}
2569+
// tensors. Compatibility requirement is not transitive.
2570+
2571+
if (!shape_invariant &&
2572+
failed(VerifyTypeRangesAreCompatible(op, input_type, result_type)))
2573+
return failure();
2574+
2575+
// Skip the first pair as the While op operands and body function results does
2576+
// not need to be compatible with each other.
2577+
for (int i = 1; i < kNumRegionTypeLists; ++i)
2578+
if (failed(VerifyTypeRangesAreCompatible(op, input_type, region_types[i])))
2579+
return failure();
2580+
2581+
for (int i = 0; i < kNumRegionTypeLists; ++i)
2582+
if (failed(VerifyTypeRangesAreCompatible(op, result_type, region_types[i])))
2583+
return failure();
2584+
2585+
for (int i = 0; i < kNumRegionTypeLists; ++i)
2586+
for (int j = i + 1; j < kNumRegionTypeLists; ++j)
2587+
if (failed(VerifyTypeRangesAreCompatible(op, region_types[i],
2588+
region_types[j])))
2589+
return failure();
2590+
25802591
return success();
25812592
}
25822593

@@ -2601,7 +2612,8 @@ static LogicalResult Verify(WhileOp op) {
26012612

26022613
if (failed(VerifyWhileTypes(op, /*cond_input=*/cond_fn_type.getInputs(),
26032614
/*body_input=*/body_fn_type.getInputs(),
2604-
/*body_result=*/body_fn_type.getResults())))
2615+
/*body_result=*/body_fn_type.getResults(),
2616+
op.shape_invariant())))
26052617
return failure();
26062618
return success();
26072619
}
@@ -2626,7 +2638,8 @@ static LogicalResult Verify(WhileRegionOp op) {
26262638
Operation *body_yield = op.body().front().getTerminator();
26272639
if (failed(VerifyWhileTypes(op, /*cond_input=*/op.cond().getArgumentTypes(),
26282640
/*body_input=*/op.body().getArgumentTypes(),
2629-
/*body_result=*/body_yield->getOperandTypes())))
2641+
/*body_result=*/body_yield->getOperandTypes(),
2642+
op.shape_invariant())))
26302643
return failure();
26312644
return success();
26322645
}

tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1849,7 +1849,7 @@ func private @testWhileBody(tensor<*xf32>) -> (tensor<*xi32>)
18491849
// Test invalid 'While' operation
18501850
func @testWhileResult(tensor<*xf32>) -> (tensor<*xf32>) {
18511851
^bb0(%arg0: tensor<*xf32>):
1852-
// expected-error @+1 {{'tf.While' op body result type tensor<*xi32> is incompatible with result type tensor<*xf32> at index 0}}
1852+
// expected-error @+1 {{'tf.While' op result type tensor<*xf32> is incompatible with body result type tensor<*xi32> at index 0}}
18531853
%1 = "tf.While"(%arg0) {
18541854
cond = @testWhileCond,
18551855
body = @testWhileBody,
@@ -1933,6 +1933,19 @@ func @testWhileResult(tensor<*x!tf.resource<tensor<32xf32>>>) -> (tensor<!tf.res
19331933
return %1 : tensor<!tf.resource>
19341934
}
19351935

1936+
// -----
1937+
1938+
func private @cond(tensor<1x?x3xf32>) -> tensor<i1>
1939+
func private @body(tensor<1x?x3xf32>) -> tensor<1x?x3xf32>
1940+
1941+
// Test shape invariant 'While' operation verifier with different operand and
1942+
// result shapes.
1943+
// CHECK-LABEL: func @testShapeInvariantWhile
1944+
func @testShapeInvariantWhile(%arg0: tensor<1x2x3xf32>) -> tensor<1x8x3xf32> {
1945+
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false, shape_invariant} : (tensor<1x2x3xf32>) -> tensor<1x8x3xf32>
1946+
return %0 : tensor<1x8x3xf32>
1947+
}
1948+
19361949
// -----
19371950
// WhileRegion tests
19381951

@@ -2084,7 +2097,7 @@ func @testInvalidWhileRegion_I_BI_TypeMismatch(%arg0 : tensor<i32>) -> (tensor<i
20842097
// -----
20852098

20862099
func @testInvalidWhileRegion_O_BO_CountMismatch(%arg0 : tensor<i32>) -> (tensor<i32>) {
2087-
// expected-error @+1 {{'tf.WhileRegion' op body results (size = 2) should have the same number of values as results (size = 1)}}
2100+
// expected-error @+1 {{'tf.WhileRegion' op results (size = 1) should have the same number of values as body results (size = 2)}}
20882101
%0 = "tf.WhileRegion"(%arg0) (
20892102
{
20902103
^bb0(%carg: tensor<i32>):
@@ -2102,7 +2115,7 @@ func @testInvalidWhileRegion_O_BO_CountMismatch(%arg0 : tensor<i32>) -> (tensor<
21022115
// -----
21032116

21042117
func @testInvalidWhileRegionMismatch_O_BO_TypeMismatch(%arg0 : tensor<i32>, %arg1: tensor<f32>) -> (tensor<i32>) {
2105-
// expected-error @+1 {{'tf.WhileRegion' op body result type tensor<f32> is incompatible with result type tensor<i32> at index 0}}
2118+
// expected-error @+1 {{'tf.WhileRegion' op result type tensor<i32> is incompatible with body result type tensor<f32> at index 0}}
21062119
%0 = "tf.WhileRegion"(%arg0) (
21072120
{
21082121
^bb0(%carg: tensor<i32>):
@@ -2207,6 +2220,23 @@ func @testInvalidWhileRegionConditionOutputType(%arg : tensor<i32>) -> (tensor<i
22072220
return %0 : tensor<i32>
22082221
}
22092222

2223+
// -----
2224+
2225+
// Test shape invariant 'WhileRegion' operation verifier with different operand
2226+
// and result shapes.
2227+
// CHECK-LABEL: func @testShapeInvariantWhileRegion
2228+
func @testShapeInvariantWhileRegion(%arg0: tensor<1x2x3xf32>) -> tensor<1x8x3xf32> {
2229+
%0 = "tf.WhileRegion"(%arg0) ( {
2230+
^cond(%carg0: tensor<1x?x3xf32>):
2231+
%1 = "tf.SomeCondOp"(%carg0) : (tensor<1x?x3xf32>) -> tensor<i1>
2232+
"tf.Yield"(%1) : (tensor<i1>) -> ()
2233+
}, {
2234+
^body(%barg0: tensor<1x?x3xf32>):
2235+
%1 = "tf.SomeBodyOp"(%barg0) : (tensor<1x?x3xf32>) -> tensor<1x?x3xf32>
2236+
"tf.Yield"(%1) : (tensor<1x?x3xf32>) -> ()
2237+
}) {is_stateless = false, shape_invariant} : (tensor<1x2x3xf32>) -> tensor<1x8x3xf32>
2238+
return %0 : tensor<1x8x3xf32>
2239+
}
22102240

22112241
// -----
22122242

0 commit comments

Comments
 (0)