Skip to content

Commit e5ea20d

Browse files
liufengdbtensorflower-gardener
authored andcommitted
Add the support of getting the length of tensor list
This patch added both the op definition and the Python translation to support getting the length of tensor list and comparing the results with a constant. PiperOrigin-RevId: 346157648 Change-Id: If7fcf42a1632fd08292426551e191fd9919d67a4
1 parent cc518a5 commit e5ea20d

File tree

5 files changed

+113
-10
lines changed

5 files changed

+113
-10
lines changed

tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,23 @@ struct RemoveRedundantGetElement : public OpRewritePattern<GetElementOp> {
437437
}
438438
};
439439

440+
struct RemoveRedundantGetLength : public OpRewritePattern<GetLengthOp> {
441+
using OpRewritePattern<GetLengthOp>::OpRewritePattern;
442+
443+
LogicalResult matchAndRewrite(GetLengthOp gl_op,
444+
PatternRewriter &rewriter) const override {
445+
auto preceding_build_list = llvm::dyn_cast_or_null<BuildListOp>(
446+
gl_op.tensor_list().getDefiningOp());
447+
if (!preceding_build_list) {
448+
return failure();
449+
}
450+
int64_t num_tensors = preceding_build_list.getNumOperands();
451+
rewriter.replaceOpWithNewOp<ConstantOp>(gl_op,
452+
rewriter.getIndexAttr(num_tensors));
453+
return success();
454+
}
455+
};
456+
440457
struct BuildConstantListAsAttr : public OpRewritePattern<BuildListOp> {
441458
using OpRewritePattern<BuildListOp>::OpRewritePattern;
442459

@@ -477,6 +494,11 @@ void GetElementOp::getCanonicalizationPatterns(
477494
results.insert<RemoveRedundantGetElement>(context);
478495
}
479496

497+
void GetLengthOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
498+
MLIRContext *context) {
499+
results.insert<RemoveRedundantGetLength>(context);
500+
}
501+
480502
void BuildListOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
481503
MLIRContext *context) {
482504
results.insert<BuildConstantListAsAttr>(context);

tensorflow/compiler/mlir/tfr/ir/tfr_ops.td

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,30 @@ def TFR_BuildListOp : TFR_Op<"build_list", [NoSideEffect]> {
349349
let hasCanonicalizer = 1;
350350
}
351351

352+
def TFR_GetLengthOp : TFR_Op<"get_length", [NoSideEffect]> {
353+
let description = [{
354+
The `get_length` operation returns the number of tensors for a
355+
tfr.tensor_list.
356+
357+
Example:
358+
359+
```mlir
360+
%2 = tfr.get_length(%1) : tfr.tensor -> index
361+
%2 = tfr.get_length %1 -> index
362+
```
363+
}];
364+
365+
let arguments = (ins TFR_TensorListType:$tensor_list);
366+
367+
let results = (outs Index:$out);
368+
369+
let hasCanonicalizer = 1;
370+
371+
let assemblyFormat = [{
372+
$tensor_list attr-dict `->` type($out)
373+
}];
374+
}
375+
352376
//===----------------------------------------------------------------------===//
353377
// Function related classes
354378
//===----------------------------------------------------------------------===//

tensorflow/compiler/mlir/tfr/python/tfr_gen.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -643,6 +643,15 @@ def _index_to_I64(self, value, ty):
643643
else:
644644
return value, ty
645645

646+
def _i64_to_index(self, value, ty):
647+
if ty == TFRTypes.I64:
648+
casted = self._ssa_name('casted')
649+
self._emit_with_loc('\n{} = index_cast {} : i64 to index'.format(
650+
casted, value))
651+
return casted, TFRTypes.INDEX
652+
else:
653+
return value, ty
654+
646655
def _value_to_tensor(self, value, ty, node):
647656
value, ty = self._index_to_I64(value, ty)
648657
cst_tensor = self._ssa_name('cst')
@@ -828,15 +837,19 @@ def visit_Call(self, node):
828837
if func_name == 'len':
829838
arg, ty = self.visit(node.args[0])
830839
ty = self._get_inferred_type(node.args[0], ty)
831-
assert ty == TFRTypes.TF_TENSOR_SHAPE_LIST, ty
832-
len_value = self._ssa_name('len')
833-
self._emit_with_loc(
834-
'\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
835-
len_value, arg), node)
836-
size_value = self._ssa_name('len_size')
837-
self._emit_with_loc(
838-
'\n{} = shape.size_to_index {} : !shape.size'.format(
839-
size_value, len_value), node)
840+
if ty == TFRTypes.TF_TENSOR_SHAPE_LIST:
841+
len_value = self._ssa_name('len')
842+
self._emit_with_loc(
843+
'\n{} = shape.rank {} : !shape.shape -> !shape.size'.format(
844+
len_value, arg), node)
845+
size_value = self._ssa_name('len_size')
846+
self._emit_with_loc(
847+
'\n{} = shape.size_to_index {} : !shape.size'.format(
848+
size_value, len_value), node)
849+
elif ty == TFRTypes.TENSOR_LIST:
850+
size_value = self._ssa_name('len')
851+
self._emit_with_loc(
852+
'\n{} = tfr.get_length {} -> index'.format(size_value, arg), node)
840853
return (size_value, TFRTypes.INDEX)
841854

842855
raise NotImplementedError('call operator not recognized: {} {}'.format(
@@ -845,7 +858,7 @@ def visit_Call(self, node):
845858
def visit_Compare(self, node):
846859
lhs, lhs_ty = self.visit(node.left)
847860
for op, right in zip(node.ops, node.comparators):
848-
rhs, _ = self.visit(right)
861+
rhs, rhs_ty = self.visit(right)
849862
if isinstance(op, ast.Eq):
850863
pred = 'eq'
851864
elif isinstance(op, ast.Lt):
@@ -870,6 +883,10 @@ def visit_Compare(self, node):
870883
code = 'cmpi'
871884
elif lhs_ty == TFRTypes.F32:
872885
code = 'cmpf'
886+
elif lhs_ty == TFRTypes.INDEX:
887+
code = 'cmpi'
888+
# TODO(fengliuai): the reverse type inference should solve the issue.
889+
rhs, _ = self._i64_to_index(rhs, rhs_ty)
873890
else:
874891
raise NotImplementedError('Compare operand type not recognized')
875892
self._emit_with_loc(

tensorflow/compiler/mlir/tfr/python/tfr_gen_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def _tfr_control_flow_range_for(x):
126126
return x_sum
127127

128128

129+
@composite.Composite('TestInputNOp')
130+
def _tfr_control_flow_tensor_list_size(ins):
131+
n = len(ins)
132+
if n == 1:
133+
return ins[0]
134+
else:
135+
return math_ops.AddN(ins)
136+
137+
129138
#--- test fn for tf ops ---
130139

131140

@@ -403,6 +412,24 @@ def test_tfr_control_flow(self):
403412
CHECK-NEXT: %{{.*}} = constant true
404413
CHECK-NEXT: tfr.return %[[for_stmt]] : !tfr.tensor
405414
CHECK-NEXT: }
415+
416+
CHECK-LABEL: tfr.func @tf__test_input_n_op(%ins: !tfr.tensor_list) -> (!tfr.tensor) {
417+
CHECK-NEXT: %[[len:.*]] = tfr.get_length %ins -> index
418+
CHECK-NEXT: %[[cst:.*]] = constant 1 : i64
419+
CHECK-NEXT: %[[casted:.*]] = index_cast %[[cst]] : i64 to index
420+
CHECK-NEXT: %[[eq:.*]] = cmpi "eq", %[[len]], %[[casted]] : index
421+
CHECK-NEXT: %[[if:.*]] = scf.if %[[eq]] -> (!tfr.tensor) {
422+
CHECK-NEXT: %{{.*}} = constant true
423+
CHECK-NEXT: %{{.*}} = constant 0 : index
424+
CHECK-NEXT: %[[elt:.*]] = tfr.get_element %ins[%cst_2] : (!tfr.tensor_list, index) -> !tfr.tensor
425+
CHECK-NEXT: scf.yield %[[elt]] : !tfr.tensor
426+
CHECK-NEXT: } else {
427+
CHECK-NEXT: %{{.*}} = constant true
428+
CHECK-NEXT: %[[AddN:.*]] = tfr.call @tf__add_n(%ins) : (!tfr.tensor_list) -> (!tfr.tensor)
429+
CHECK-NEXT: scf.yield %[[AddN]] : !tfr.tensor
430+
CHECK-NEXT: }
431+
CHECK-NEXT: tfr.return %[[if_stmt]] : !tfr.tensor
432+
CHECK-NEXT: }
406433
"""
407434
self._check_code(mlir_code, mlir_code_exp)
408435

tensorflow/compiler/mlir/tfr/tests/ops.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,19 @@ func @build_const_list() -> !tfr.attr {
260260

261261
// -----
262262

263+
// CHECK-LABEL: get_length
264+
// CANON-LABEL: get_length
265+
func @get_length(%arg0: !tfr.tensor<A>, %arg1: !tfr.tensor<B>) -> index {
266+
%0 = "tfr.build_list"(%arg0, %arg1) : (!tfr.tensor<A>, !tfr.tensor<B>) -> !tfr.tensor_list
267+
%1 = "tfr.get_length"(%0) : (!tfr.tensor_list) -> index
268+
return %1 : index
269+
270+
// CANON-NEXT: %[[c:.*]] = constant 2 : index
271+
// CANON-NEXT: return %[[c]] : index
272+
}
273+
274+
// -----
275+
263276
// CHECK-LABEL: tfr.func
264277
tfr.func @External(%arg0: !tfr.tensor<A>,
265278
%arg1: !tfr.tensor_list<C>,

0 commit comments

Comments
 (0)