@@ -643,6 +643,15 @@ def _index_to_I64(self, value, ty):
643643 else :
644644 return value , ty
645645
646+ def _i64_to_index (self , value , ty ):
647+ if ty == TFRTypes .I64 :
648+ casted = self ._ssa_name ('casted' )
649+ self ._emit_with_loc ('\n {} = index_cast {} : i64 to index' .format (
650+ casted , value ))
651+ return casted , TFRTypes .INDEX
652+ else :
653+ return value , ty
654+
646655 def _value_to_tensor (self , value , ty , node ):
647656 value , ty = self ._index_to_I64 (value , ty )
648657 cst_tensor = self ._ssa_name ('cst' )
@@ -828,15 +837,19 @@ def visit_Call(self, node):
828837 if func_name == 'len' :
829838 arg , ty = self .visit (node .args [0 ])
830839 ty = self ._get_inferred_type (node .args [0 ], ty )
831- assert ty == TFRTypes .TF_TENSOR_SHAPE_LIST , ty
832- len_value = self ._ssa_name ('len' )
833- self ._emit_with_loc (
834- '\n {} = shape.rank {} : !shape.shape -> !shape.size' .format (
835- len_value , arg ), node )
836- size_value = self ._ssa_name ('len_size' )
837- self ._emit_with_loc (
838- '\n {} = shape.size_to_index {} : !shape.size' .format (
839- size_value , len_value ), node )
840+ if ty == TFRTypes .TF_TENSOR_SHAPE_LIST :
841+ len_value = self ._ssa_name ('len' )
842+ self ._emit_with_loc (
843+ '\n {} = shape.rank {} : !shape.shape -> !shape.size' .format (
844+ len_value , arg ), node )
845+ size_value = self ._ssa_name ('len_size' )
846+ self ._emit_with_loc (
847+ '\n {} = shape.size_to_index {} : !shape.size' .format (
848+ size_value , len_value ), node )
849+ elif ty == TFRTypes .TENSOR_LIST :
850+ size_value = self ._ssa_name ('len' )
851+ self ._emit_with_loc (
852+ '\n {} = tfr.get_length {} -> index' .format (size_value , arg ), node )
840853 return (size_value , TFRTypes .INDEX )
841854
842855 raise NotImplementedError ('call operator not recognized: {} {}' .format (
@@ -845,7 +858,7 @@ def visit_Call(self, node):
845858 def visit_Compare (self , node ):
846859 lhs , lhs_ty = self .visit (node .left )
847860 for op , right in zip (node .ops , node .comparators ):
848- rhs , _ = self .visit (right )
861+ rhs , rhs_ty = self .visit (right )
849862 if isinstance (op , ast .Eq ):
850863 pred = 'eq'
851864 elif isinstance (op , ast .Lt ):
@@ -870,6 +883,10 @@ def visit_Compare(self, node):
870883 code = 'cmpi'
871884 elif lhs_ty == TFRTypes .F32 :
872885 code = 'cmpf'
886+ elif lhs_ty == TFRTypes .INDEX :
887+ code = 'cmpi'
888+ # TODO(fengliuai): the reverse type inference should solve the issue.
889+ rhs , _ = self ._i64_to_index (rhs , rhs_ty )
873890 else :
874891 raise NotImplementedError ('Compare operand type not recognized' )
875892 self ._emit_with_loc (
0 commit comments