Support NMS batch and class handling using concats
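For reference, here is a minimal PyTorch-level sketch of what the decomposition below computes: loop over every (batch, class) pair, run single-class NMS, cap the kept indices at max_output_boxes_per_class, and concatenate [batch_index, class_index, box_index] rows. This sketch is an assumption for illustration only (the helper name onnx_nms_reference is made up, box-coordinate conventions are ignored, and score_threshold is not applied; the IR below asserts score_threshold <= min(scores) instead of filtering).

import torch
from torchvision.ops import nms

def onnx_nms_reference(boxes, scores, max_output_boxes_per_class, iou_threshold):
    # boxes:  [num_batches, num_boxes, 4]
    # scores: [num_batches, num_classes, num_boxes]
    rows = []
    for b in range(scores.shape[0]):            # batch loop (outer torch.prim.Loop)
        for c in range(scores.shape[1]):        # class loop (inner torch.prim.Loop)
            keep = nms(boxes[b], scores[b, c], iou_threshold)
            keep = keep[:max_output_boxes_per_class]          # cap per (batch, class)
            batch_col = torch.full((keep.numel(), 1), b, dtype=torch.int64)
            class_col = torch.full((keep.numel(), 1), c, dtype=torch.int64)
            # Concatenate columns into [batch_index, class_index, box_index] rows.
            rows.append(torch.cat([batch_col, class_col, keep.unsqueeze(1)], dim=1))
    return torch.cat(rows, dim=0)
# The IR instead writes each block of rows into a preallocated
# [num_batches * num_classes * max_output_boxes_per_class, 3] buffer with slice_scatter.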
func.func @main(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[120,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} {
  %none = torch.constant.none
  %int0 = torch.constant.int 0
  %int1 = torch.constant.int 1
  %int2 = torch.constant.int 2
  %int3 = torch.constant.int 3
  %int4 = torch.constant.int 4
  %float2.000000e00 = torch.constant.float 2.000000e+00
  %none_0 = torch.constant.none
  %true = torch.constant.bool true
  // score_threshold filtering is unsupported: require score_threshold <= min(scores).
  %0 = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float
  %1 = torch.aten.min %arg1 : !torch.vtensor<[3,5,8],f32> -> !torch.vtensor<[],f32>
  %2 = torch.aten.item %1 : !torch.vtensor<[],f32> -> !torch.float
  %3 = torch.aten.ge.float %2, %0 : !torch.float, !torch.float -> !torch.bool
  torch.runtime.assert %3, "unimplemented: score_threshold should be <= min(scores)"
  %float0.000000e00 = torch.constant.float 0.000000e+00
  %4 = torch.aten.item %arg3 : !torch.vtensor<[1],f32> -> !torch.float
  %5 = torch.aten.item %arg2 : !torch.vtensor<[1],si64> -> !torch.int
  %6 = torch.aten.size.int %arg1, %int0 : !torch.vtensor<[3,5,8],f32>, !torch.int -> !torch.int
  %7 = torch.aten.size.int %arg1, %int1 : !torch.vtensor<[3,5,8],f32>, !torch.int -> !torch.int
  %8 = torch.aten.mul.int %6, %7 : !torch.int, !torch.int -> !torch.int
  %9 = torch.aten.mul.int %8, %5 : !torch.int, !torch.int -> !torch.int
  // Preallocate the [num_batches * num_classes * max_output_boxes_per_class, 3] result buffer.
  %10 = torch.prim.ListConstruct %9, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
  %11 = torch.aten.empty.memory_format %10, %int4, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[120,3],si64>
  // Outer loop over batches; the loop carries the result buffer and a running row offset.
  %12:2 = torch.prim.Loop %6, %true, init(%11, %int0) {
  ^bb0(%arg5: !torch.int, %arg6: !torch.vtensor<[120,3],si64>, %arg7: !torch.int):
    %14 = torch.aten.select.int %arg0, %int0, %arg5 : !torch.vtensor<[3,8,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[8,4],f32>
    // Inner loop over classes.
    %15:2 = torch.prim.Loop %7, %true, init(%arg6, %arg7) {
    ^bb0(%arg8: !torch.int, %arg9: !torch.vtensor<[120,3],si64>, %arg10: !torch.int):
      %16 = torch.aten.select.int %arg1, %int0, %arg5 : !torch.vtensor<[3,5,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,8],f32>
      %17 = torch.aten.select.int %16, %int0, %arg8 : !torch.vtensor<[5,8],f32>, !torch.int, !torch.int -> !torch.vtensor<[8],f32>
      // Single-class NMS for this (batch, class) pair, capped at max_output_boxes_per_class.
      %18 = torch.torchvision.nms %14, %17, %4 : !torch.vtensor<[8,4],f32>, !torch.vtensor<[8],f32>, !torch.float -> !torch.vtensor<[?],si64>
      %19 = torch.aten.size.int %18, %int0 : !torch.vtensor<[?],si64>, !torch.int -> !torch.int
      %20 = torch.aten.gt.int %19, %5 : !torch.int, !torch.int -> !torch.bool
      %21 = torch.prim.If %20 -> (!torch.vtensor<[?],si64>) {
        %37 = torch.aten.slice.Tensor %18, %int0, %int0, %5, %int1 : !torch.vtensor<[?],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?],si64>
        torch.prim.If.yield %37 : !torch.vtensor<[?],si64>
      } else {
        torch.prim.If.yield %18 : !torch.vtensor<[?],si64>
      }
      // Build [batch_index, class_index, box_index] rows via concats along dim 1.
      %22 = torch.aten.unsqueeze %21, %int1 : !torch.vtensor<[?],si64>, !torch.int -> !torch.vtensor<[?,1],si64>
      %23 = torch.aten.size.int %22, %int0 : !torch.vtensor<[?,1],si64>, !torch.int -> !torch.int
      %24 = torch.prim.ListConstruct %23, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
      %25 = torch.aten.empty.memory_format %24, %int4, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,1],si64>
      %26 = torch.aten.fill.Scalar %25, %arg5 : !torch.vtensor<[?,1],si64>, !torch.int -> !torch.vtensor<[?,1],si64>
      %27 = torch.aten.empty.memory_format %24, %int4, %none_0, %none_0, %none_0, %none_0 : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,1],si64>
      %28 = torch.aten.fill.Scalar %27, %arg8 : !torch.vtensor<[?,1],si64>, !torch.int -> !torch.vtensor<[?,1],si64>
      %29 = torch.prim.ListConstruct %28, %22 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[?,1],si64>) -> !torch.list<vtensor>
      %30 = torch.aten.cat %29, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,2],si64>
      %31 = torch.prim.ListConstruct %26, %30 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[?,2],si64>) -> !torch.list<vtensor>
      %32 = torch.aten.cat %31, %int1 : !torch.list<vtensor>, !torch.int -> !torch.vtensor<[?,3],si64>
      // Write the new rows into the result at the running offset with slice_scatter.
      %33 = torch.aten.add.int %arg10, %23 : !torch.int, !torch.int -> !torch.int
      %false = torch.constant.bool false
      %34 = torch.aten.slice.Tensor %arg9, %int0, %arg10, %33, %int1 : !torch.vtensor<[120,3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,3],si64>
      %35 = torch.aten.copy %34, %32, %false : !torch.vtensor<[?,3],si64>, !torch.vtensor<[?,3],si64>, !torch.bool -> !torch.vtensor<[?,3],si64>
      %36 = torch.aten.slice_scatter %arg9, %35, %int0, %arg10, %33, %int1 : !torch.vtensor<[120,3],si64>, !torch.vtensor<[?,3],si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[120,3],si64>
      torch.prim.Loop.condition %true, iter(%36, %33 : !torch.vtensor<[120,3],si64>, !torch.int)
    } : (!torch.int, !torch.bool, !torch.vtensor<[120,3],si64>, !torch.int) -> (!torch.vtensor<[120,3],si64>, !torch.int)
    torch.prim.Loop.condition %true, iter(%15#0, %15#1 : !torch.vtensor<[120,3],si64>, !torch.int)
  } : (!torch.int, !torch.bool, !torch.vtensor<[120,3],si64>, !torch.int) -> (!torch.vtensor<[120,3],si64>, !torch.int)
  %13 = torch.tensor_static_info_cast %12#0 : !torch.vtensor<[120,3],si64> to !torch.vtensor<[120,3],si64>
  return %13 : !torch.vtensor<[120,3],si64>
}
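A note on the design choice above: rather than growing a dynamically shaped result, the decomposition preallocates a [num_batches * num_classes * max_output_boxes_per_class, 3] buffer (3 * 5 * 8 = 120 rows for these static shapes), carries a running row offset through both loops, and writes each (batch, class) block in with torch.aten.slice_scatter, so the output type stays static at [120,3]. Rows beyond the final offset are never written and keep whatever aten.empty.memory_format produced.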
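The second listing is the same function after partial lowering to linalg/tensor/arith: torch_c casts bridge to builtin tensors, the per-(batch, class) selects become tensor.extract_slice / tensor.collapse_shape, the score ordering inside the NMS lowering uses iree_linalg_ext.sort over a score/index pair, results are written back with tensor.insert_slice, and the batch/class loops remain torch.prim.Loop. Much of its bulk is the index normalization emitted for every aten.slice.Tensor; below is a minimal Python sketch of that pattern, assuming signed floor division matching arith.floordivsi (the helper name normalize_slice is hypothetical, for illustration only).

def normalize_slice(start, end, step, dim_size):
    # Wrap negative start/end around the dimension, then clamp into range,
    # mirroring the arith.select chains repeated throughout the IR below.
    start = start + dim_size if start < 0 else start
    start = min(max(start, 0), dim_size)
    end = end + dim_size if end < 0 else end
    end = min(max(end, -1), dim_size)
    # Number of elements the strided slice produces (never negative).
    sign = 1 if step >= 0 else -1
    length = max(0, (end - start + step - sign) // step)
    return start, length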
func.func @main(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[120,3],si64> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { | |
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,8,4],f32> -> tensor<3x8x4xf32> | |
%1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3,5,8],f32> -> tensor<3x5x8xf32> | |
%2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%4 = torch_c.to_builtin_tensor %arg4 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%5 = torch.vtensor.literal(dense<1> : tensor<8xsi64>) : !torch.vtensor<[8],si64> | |
%6 = torch.vtensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> | |
%7 = torch_c.to_builtin_tensor %6 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%8 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> | |
%9 = torch_c.to_builtin_tensor %8 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%int8 = torch.constant.int 8 | |
%false = torch.constant.bool false | |
%int2 = torch.constant.int 2 | |
%c2_i64 = arith.constant 2 : i64 | |
%int15 = torch.constant.int 15 | |
%int5 = torch.constant.int 5 | |
%none = torch.constant.none | |
%int0 = torch.constant.int 0 | |
%c0_i64 = arith.constant 0 : i64 | |
%int1 = torch.constant.int 1 | |
%c1_i64 = arith.constant 1 : i64 | |
%int3 = torch.constant.int 3 | |
%int4 = torch.constant.int 4 | |
%c4_i64 = arith.constant 4 : i64 | |
%true = torch.constant.bool true | |
%c0 = arith.constant 0 : index | |
%extracted = tensor.extract %4[%c0] : tensor<1xf32> | |
%10 = arith.extf %extracted : f32 to f64 | |
%11 = torch_c.from_f64 %10 | |
%cst = arith.constant 0x7F800000 : f32 | |
%c1 = arith.constant 1 : index | |
%c0_0 = arith.constant 0 : index | |
%dim = tensor.dim %1, %c0_0 : tensor<3x5x8xf32> | |
%c1_1 = arith.constant 1 : index | |
%dim_2 = tensor.dim %1, %c1_1 : tensor<3x5x8xf32> | |
%c2 = arith.constant 2 : index | |
%dim_3 = tensor.dim %1, %c2 : tensor<3x5x8xf32> | |
%12 = tensor.empty() : tensor<f32> | |
%13 = linalg.fill ins(%cst : f32) outs(%12 : tensor<f32>) -> tensor<f32> | |
%14 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>], iterator_types = ["reduction", "reduction", "reduction"]} ins(%1 : tensor<3x5x8xf32>) outs(%13 : tensor<f32>) { | |
^bb0(%in: f32, %out: f32): | |
%31 = arith.minimumf %in, %out : f32 | |
linalg.yield %31 : f32 | |
} -> tensor<f32> | |
%cast = tensor.cast %14 : tensor<f32> to tensor<f32> | |
%15 = torch_c.from_builtin_tensor %cast : tensor<f32> -> !torch.vtensor<[],f32> | |
%16 = torch_c.to_builtin_tensor %15 : !torch.vtensor<[],f32> -> tensor<f32> | |
%c0_4 = arith.constant 0 : index | |
%extracted_5 = tensor.extract %16[] : tensor<f32> | |
%17 = arith.extf %extracted_5 : f32 to f64 | |
%18 = torch_c.from_f64 %17 | |
%19 = torch.aten.ge.float %18, %11 : !torch.float, !torch.float -> !torch.bool | |
torch.runtime.assert %19, "unimplemented: score_threshold should be <= min(scores)" | |
%c0_6 = arith.constant 0 : index | |
%extracted_7 = tensor.extract %3[%c0_6] : tensor<1xf32> | |
%20 = arith.extf %extracted_7 : f32 to f64 | |
%21 = torch_c.from_f64 %20 | |
%c0_8 = arith.constant 0 : index | |
%extracted_9 = tensor.extract %2[%c0_8] : tensor<1xi64> | |
%22 = torch_c.from_i64 %extracted_9 | |
%23 = torch_c.to_i64 %22 | |
%24 = torch.aten.mul.int %int15, %22 : !torch.int, !torch.int -> !torch.int | |
%25 = torch.prim.ListConstruct %24, %int3 : (!torch.int, !torch.int) -> !torch.list<int> | |
%26 = torch_c.to_i64 %24 | |
%c3_i64 = arith.constant 3 : i64 | |
%27 = arith.index_cast %26 : i64 to index | |
%c3 = arith.constant 3 : index | |
%28 = tensor.empty(%27) : tensor<?x3xi64> | |
%cast_10 = tensor.cast %28 : tensor<?x3xi64> to tensor<120x3xi64> | |
%29 = torch_c.from_builtin_tensor %cast_10 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%30:2 = torch.prim.Loop %int3, %true, init(%29, %int0) { | |
^bb0(%arg5: !torch.int, %arg6: !torch.vtensor<[120,3],si64>, %arg7: !torch.int): | |
%31 = torch_c.to_i64 %arg5 | |
%32 = torch.aten.lt.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool | |
%33 = torch.aten.Int.bool %32 : !torch.bool -> !torch.int | |
%34 = torch.aten.mul.int %33, %int3 : !torch.int, !torch.int -> !torch.int | |
%35 = torch.aten.add.int %arg5, %34 : !torch.int, !torch.int -> !torch.int | |
%36 = torch_c.to_i64 %35 | |
%37 = torch.aten.add.int %35, %int1 : !torch.int, !torch.int -> !torch.int | |
%38 = torch_c.to_i64 %37 | |
%c0_11 = arith.constant 0 : index | |
%c1_12 = arith.constant 1 : index | |
%c-1 = arith.constant -1 : index | |
%c0_13 = arith.constant 0 : index | |
%c3_14 = arith.constant 3 : index | |
%c1_15 = arith.constant 1 : index | |
%c8 = arith.constant 8 : index | |
%c2_16 = arith.constant 2 : index | |
%c4 = arith.constant 4 : index | |
%39 = arith.index_cast %c1_i64 : i64 to index | |
%c3_i64_17 = arith.constant 3 : i64 | |
%40 = arith.addi %36, %c3_i64_17 : i64 | |
%c0_i64_18 = arith.constant 0 : i64 | |
%41 = arith.cmpi sge, %36, %c0_i64_18 : i64 | |
%42 = arith.select %41, %36, %40 : i64 | |
%c0_i64_19 = arith.constant 0 : i64 | |
%43 = arith.cmpi slt, %42, %c0_i64_19 : i64 | |
%44 = arith.select %43, %c0_i64_19, %42 : i64 | |
%45 = arith.cmpi sgt, %44, %c3_i64_17 : i64 | |
%46 = arith.select %45, %c3_i64_17, %44 : i64 | |
%47 = arith.index_cast %46 : i64 to index | |
%48 = arith.index_cast %38 : i64 to index | |
%49 = arith.cmpi slt, %48, %c0_11 : index | |
%50 = arith.addi %48, %c3_14 : index | |
%51 = arith.select %49, %50, %48 : index | |
%52 = arith.cmpi slt, %51, %c0_11 : index | |
%53 = arith.select %52, %c-1, %51 : index | |
%54 = arith.cmpi sgt, %53, %c3_14 : index | |
%55 = arith.select %54, %c3_14, %53 : index | |
%c0_20 = arith.constant 0 : index | |
%c3_21 = arith.constant 3 : index | |
%c1_22 = arith.constant 1 : index | |
%c8_23 = arith.constant 8 : index | |
%c2_24 = arith.constant 2 : index | |
%c4_25 = arith.constant 4 : index | |
%56 = arith.subi %55, %47 : index | |
%57 = arith.cmpi sge, %39, %c0_11 : index | |
%58 = arith.select %57, %c1_12, %c-1 : index | |
%59 = arith.addi %56, %39 : index | |
%60 = arith.subi %59, %58 : index | |
%61 = arith.floordivsi %60, %39 : index | |
%62 = arith.cmpi slt, %61, %c0_11 : index | |
%63 = arith.select %62, %c0_11, %61 : index | |
%c1_26 = arith.constant 1 : index | |
%c0_27 = arith.constant 0 : index | |
%c3_28 = arith.constant 3 : index | |
%c1_29 = arith.constant 1 : index | |
%c8_30 = arith.constant 8 : index | |
%c2_31 = arith.constant 2 : index | |
%c4_32 = arith.constant 4 : index | |
%64 = arith.subi %c3_28, %c1_26 : index | |
%c0_33 = arith.constant 0 : index | |
%c3_34 = arith.constant 3 : index | |
%c1_35 = arith.constant 1 : index | |
%c8_36 = arith.constant 8 : index | |
%c2_37 = arith.constant 2 : index | |
%c4_38 = arith.constant 4 : index | |
%65 = tensor.empty() : tensor<3x8x4xf32> | |
%cst_39 = arith.constant 0.000000e+00 : f32 | |
%66 = linalg.fill ins(%cst_39 : f32) outs(%65 : tensor<3x8x4xf32>) -> tensor<3x8x4xf32> | |
%67 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x8x4xf32>) outs(%66 : tensor<3x8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%75 = linalg.index 0 : index | |
%76 = linalg.index 1 : index | |
%77 = linalg.index 2 : index | |
%78 = arith.subi %64, %75 : index | |
%extracted_44 = tensor.extract %0[%78, %76, %77] : tensor<3x8x4xf32> | |
linalg.yield %extracted_44 : f32 | |
} -> tensor<3x8x4xf32> | |
%c0_40 = arith.constant 0 : index | |
%c0_41 = arith.constant 0 : index | |
%68 = arith.cmpi slt, %39, %c0_41 : index | |
%69 = math.absi %39 : index | |
%70 = arith.muli %63, %69 : index | |
%dim_42 = tensor.dim %0, %c0_40 : tensor<3x8x4xf32> | |
%71 = arith.subi %dim_42, %70 : index | |
%72 = arith.select %68, %71, %47 : index | |
%73 = arith.select %68, %67, %0 : tensor<3x8x4xf32> | |
%extracted_slice = tensor.extract_slice %73[%72, %c0_11, %c0_11] [%63, %c8_23, %c4_25] [%69, %c1_12, %c1_12] : tensor<3x8x4xf32> to tensor<?x?x?xf32> | |
%cast_43 = tensor.cast %extracted_slice : tensor<?x?x?xf32> to tensor<1x8x4xf32> | |
%collapsed = tensor.collapse_shape %cast_43 [[0, 1], [2]] : tensor<1x8x4xf32> into tensor<8x4xf32> | |
%74:2 = torch.prim.Loop %int5, %true, init(%arg6, %arg7) { | |
^bb0(%arg8: !torch.int, %arg9: !torch.vtensor<[120,3],si64>, %arg10: !torch.int): | |
%75 = torch_c.to_i64 %arg10 | |
%76 = torch_c.to_builtin_tensor %arg9 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
%77 = torch_c.to_i64 %arg8 | |
%78 = torch.aten.lt.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool | |
%79 = torch.aten.Int.bool %78 : !torch.bool -> !torch.int | |
%80 = torch.aten.mul.int %79, %int3 : !torch.int, !torch.int -> !torch.int | |
%81 = torch.aten.add.int %arg5, %80 : !torch.int, !torch.int -> !torch.int | |
%82 = torch_c.to_i64 %81 | |
%83 = torch.aten.add.int %81, %int1 : !torch.int, !torch.int -> !torch.int | |
%84 = torch_c.to_i64 %83 | |
%c0_44 = arith.constant 0 : index | |
%c1_45 = arith.constant 1 : index | |
%c-1_46 = arith.constant -1 : index | |
%c0_47 = arith.constant 0 : index | |
%c3_48 = arith.constant 3 : index | |
%c1_49 = arith.constant 1 : index | |
%c5 = arith.constant 5 : index | |
%c2_50 = arith.constant 2 : index | |
%c8_51 = arith.constant 8 : index | |
%85 = arith.index_cast %c1_i64 : i64 to index | |
%c3_i64_52 = arith.constant 3 : i64 | |
%86 = arith.addi %82, %c3_i64_52 : i64 | |
%c0_i64_53 = arith.constant 0 : i64 | |
%87 = arith.cmpi sge, %82, %c0_i64_53 : i64 | |
%88 = arith.select %87, %82, %86 : i64 | |
%c0_i64_54 = arith.constant 0 : i64 | |
%89 = arith.cmpi slt, %88, %c0_i64_54 : i64 | |
%90 = arith.select %89, %c0_i64_54, %88 : i64 | |
%91 = arith.cmpi sgt, %90, %c3_i64_52 : i64 | |
%92 = arith.select %91, %c3_i64_52, %90 : i64 | |
%93 = arith.index_cast %92 : i64 to index | |
%94 = arith.index_cast %84 : i64 to index | |
%95 = arith.cmpi slt, %94, %c0_44 : index | |
%96 = arith.addi %94, %c3_48 : index | |
%97 = arith.select %95, %96, %94 : index | |
%98 = arith.cmpi slt, %97, %c0_44 : index | |
%99 = arith.select %98, %c-1_46, %97 : index | |
%100 = arith.cmpi sgt, %99, %c3_48 : index | |
%101 = arith.select %100, %c3_48, %99 : index | |
%c0_55 = arith.constant 0 : index | |
%c3_56 = arith.constant 3 : index | |
%c1_57 = arith.constant 1 : index | |
%c5_58 = arith.constant 5 : index | |
%c2_59 = arith.constant 2 : index | |
%c8_60 = arith.constant 8 : index | |
%102 = arith.subi %101, %93 : index | |
%103 = arith.cmpi sge, %85, %c0_44 : index | |
%104 = arith.select %103, %c1_45, %c-1_46 : index | |
%105 = arith.addi %102, %85 : index | |
%106 = arith.subi %105, %104 : index | |
%107 = arith.floordivsi %106, %85 : index | |
%108 = arith.cmpi slt, %107, %c0_44 : index | |
%109 = arith.select %108, %c0_44, %107 : index | |
%c1_61 = arith.constant 1 : index | |
%c0_62 = arith.constant 0 : index | |
%c3_63 = arith.constant 3 : index | |
%c1_64 = arith.constant 1 : index | |
%c5_65 = arith.constant 5 : index | |
%c2_66 = arith.constant 2 : index | |
%c8_67 = arith.constant 8 : index | |
%110 = arith.subi %c3_63, %c1_61 : index | |
%c0_68 = arith.constant 0 : index | |
%c3_69 = arith.constant 3 : index | |
%c1_70 = arith.constant 1 : index | |
%c5_71 = arith.constant 5 : index | |
%c2_72 = arith.constant 2 : index | |
%c8_73 = arith.constant 8 : index | |
%111 = tensor.empty() : tensor<3x5x8xf32> | |
%cst_74 = arith.constant 0.000000e+00 : f32 | |
%112 = linalg.fill ins(%cst_74 : f32) outs(%111 : tensor<3x5x8xf32>) -> tensor<3x5x8xf32> | |
%113 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<3x5x8xf32>) outs(%112 : tensor<3x5x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%395 = linalg.index 0 : index | |
%396 = linalg.index 1 : index | |
%397 = linalg.index 2 : index | |
%398 = arith.subi %110, %395 : index | |
%extracted_283 = tensor.extract %1[%398, %396, %397] : tensor<3x5x8xf32> | |
linalg.yield %extracted_283 : f32 | |
} -> tensor<3x5x8xf32> | |
%c0_75 = arith.constant 0 : index | |
%c0_76 = arith.constant 0 : index | |
%114 = arith.cmpi slt, %85, %c0_76 : index | |
%115 = math.absi %85 : index | |
%116 = arith.muli %109, %115 : index | |
%dim_77 = tensor.dim %1, %c0_75 : tensor<3x5x8xf32> | |
%117 = arith.subi %dim_77, %116 : index | |
%118 = arith.select %114, %117, %93 : index | |
%119 = arith.select %114, %113, %1 : tensor<3x5x8xf32> | |
%extracted_slice_78 = tensor.extract_slice %119[%118, %c0_44, %c0_44] [%109, %c5_58, %c8_60] [%115, %c1_45, %c1_45] : tensor<3x5x8xf32> to tensor<?x?x?xf32> | |
%cast_79 = tensor.cast %extracted_slice_78 : tensor<?x?x?xf32> to tensor<1x5x8xf32> | |
%collapsed_80 = tensor.collapse_shape %cast_79 [[0, 1], [2]] : tensor<1x5x8xf32> into tensor<5x8xf32> | |
%120 = torch.aten.lt.int %arg8, %int0 : !torch.int, !torch.int -> !torch.bool | |
%121 = torch.aten.Int.bool %120 : !torch.bool -> !torch.int | |
%122 = torch.aten.mul.int %121, %int5 : !torch.int, !torch.int -> !torch.int | |
%123 = torch.aten.add.int %arg8, %122 : !torch.int, !torch.int -> !torch.int | |
%124 = torch_c.to_i64 %123 | |
%125 = torch.aten.add.int %123, %int1 : !torch.int, !torch.int -> !torch.int | |
%126 = torch_c.to_i64 %125 | |
%c0_81 = arith.constant 0 : index | |
%c1_82 = arith.constant 1 : index | |
%c-1_83 = arith.constant -1 : index | |
%c0_84 = arith.constant 0 : index | |
%c5_85 = arith.constant 5 : index | |
%c1_86 = arith.constant 1 : index | |
%c8_87 = arith.constant 8 : index | |
%127 = arith.index_cast %c1_i64 : i64 to index | |
%c5_i64 = arith.constant 5 : i64 | |
%128 = arith.addi %124, %c5_i64 : i64 | |
%c0_i64_88 = arith.constant 0 : i64 | |
%129 = arith.cmpi sge, %124, %c0_i64_88 : i64 | |
%130 = arith.select %129, %124, %128 : i64 | |
%c0_i64_89 = arith.constant 0 : i64 | |
%131 = arith.cmpi slt, %130, %c0_i64_89 : i64 | |
%132 = arith.select %131, %c0_i64_89, %130 : i64 | |
%133 = arith.cmpi sgt, %132, %c5_i64 : i64 | |
%134 = arith.select %133, %c5_i64, %132 : i64 | |
%135 = arith.index_cast %134 : i64 to index | |
%136 = arith.index_cast %126 : i64 to index | |
%137 = arith.cmpi slt, %136, %c0_81 : index | |
%138 = arith.addi %136, %c5_85 : index | |
%139 = arith.select %137, %138, %136 : index | |
%140 = arith.cmpi slt, %139, %c0_81 : index | |
%141 = arith.select %140, %c-1_83, %139 : index | |
%142 = arith.cmpi sgt, %141, %c5_85 : index | |
%143 = arith.select %142, %c5_85, %141 : index | |
%c0_90 = arith.constant 0 : index | |
%c5_91 = arith.constant 5 : index | |
%c1_92 = arith.constant 1 : index | |
%c8_93 = arith.constant 8 : index | |
%144 = arith.subi %143, %135 : index | |
%145 = arith.cmpi sge, %127, %c0_81 : index | |
%146 = arith.select %145, %c1_82, %c-1_83 : index | |
%147 = arith.addi %144, %127 : index | |
%148 = arith.subi %147, %146 : index | |
%149 = arith.floordivsi %148, %127 : index | |
%150 = arith.cmpi slt, %149, %c0_81 : index | |
%151 = arith.select %150, %c0_81, %149 : index | |
%c1_94 = arith.constant 1 : index | |
%c0_95 = arith.constant 0 : index | |
%c5_96 = arith.constant 5 : index | |
%c1_97 = arith.constant 1 : index | |
%c8_98 = arith.constant 8 : index | |
%152 = arith.subi %c5_96, %c1_94 : index | |
%c0_99 = arith.constant 0 : index | |
%c5_100 = arith.constant 5 : index | |
%c1_101 = arith.constant 1 : index | |
%c8_102 = arith.constant 8 : index | |
%153 = tensor.empty() : tensor<5x8xf32> | |
%cst_103 = arith.constant 0.000000e+00 : f32 | |
%154 = linalg.fill ins(%cst_103 : f32) outs(%153 : tensor<5x8xf32>) -> tensor<5x8xf32> | |
%155 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_80 : tensor<5x8xf32>) outs(%154 : tensor<5x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%395 = linalg.index 0 : index | |
%396 = linalg.index 1 : index | |
%397 = arith.subi %152, %395 : index | |
%extracted_283 = tensor.extract %collapsed_80[%397, %396] : tensor<5x8xf32> | |
linalg.yield %extracted_283 : f32 | |
} -> tensor<5x8xf32> | |
%c0_104 = arith.constant 0 : index | |
%c0_105 = arith.constant 0 : index | |
%156 = arith.cmpi slt, %127, %c0_105 : index | |
%157 = math.absi %127 : index | |
%158 = arith.muli %151, %157 : index | |
%dim_106 = tensor.dim %collapsed_80, %c0_104 : tensor<5x8xf32> | |
%159 = arith.subi %dim_106, %158 : index | |
%160 = arith.select %156, %159, %135 : index | |
%161 = arith.select %156, %155, %collapsed_80 : tensor<5x8xf32> | |
%extracted_slice_107 = tensor.extract_slice %161[%160, %c0_81] [%151, %c8_93] [%157, %c1_82] : tensor<5x8xf32> to tensor<?x?xf32> | |
%cast_108 = tensor.cast %extracted_slice_107 : tensor<?x?xf32> to tensor<1x8xf32> | |
%collapsed_109 = tensor.collapse_shape %cast_108 [[0, 1]] : tensor<1x8xf32> into tensor<8xf32> | |
%162 = torch_c.from_builtin_tensor %collapsed_109 : tensor<8xf32> -> !torch.vtensor<[8],f32> | |
%163 = torch_c.to_builtin_tensor %162 : !torch.vtensor<[8],f32> -> tensor<8xf32> | |
%c0_110 = arith.constant 0 : index | |
%c1_111 = arith.constant 1 : index | |
%c-1_112 = arith.constant -1 : index | |
%c0_113 = arith.constant 0 : index | |
%c8_114 = arith.constant 8 : index | |
%c1_115 = arith.constant 1 : index | |
%c4_116 = arith.constant 4 : index | |
%164 = arith.index_cast %c1_i64 : i64 to index | |
%c4_i64_117 = arith.constant 4 : i64 | |
%165 = arith.addi %c0_i64, %c4_i64_117 : i64 | |
%c0_i64_118 = arith.constant 0 : i64 | |
%166 = arith.cmpi sge, %c0_i64, %c0_i64_118 : i64 | |
%167 = arith.select %166, %c0_i64, %165 : i64 | |
%c0_i64_119 = arith.constant 0 : i64 | |
%168 = arith.cmpi slt, %167, %c0_i64_119 : i64 | |
%169 = arith.select %168, %c0_i64_119, %167 : i64 | |
%170 = arith.cmpi sgt, %169, %c4_i64_117 : i64 | |
%171 = arith.select %170, %c4_i64_117, %169 : i64 | |
%172 = arith.index_cast %171 : i64 to index | |
%173 = arith.index_cast %c2_i64 : i64 to index | |
%174 = arith.cmpi slt, %173, %c0_110 : index | |
%175 = arith.addi %173, %c4_116 : index | |
%176 = arith.select %174, %175, %173 : index | |
%177 = arith.cmpi slt, %176, %c0_110 : index | |
%178 = arith.select %177, %c-1_112, %176 : index | |
%179 = arith.cmpi sgt, %178, %c4_116 : index | |
%180 = arith.select %179, %c4_116, %178 : index | |
%c0_120 = arith.constant 0 : index | |
%c8_121 = arith.constant 8 : index | |
%c1_122 = arith.constant 1 : index | |
%c4_123 = arith.constant 4 : index | |
%181 = arith.subi %180, %172 : index | |
%182 = arith.cmpi sge, %164, %c0_110 : index | |
%183 = arith.select %182, %c1_111, %c-1_112 : index | |
%184 = arith.addi %181, %164 : index | |
%185 = arith.subi %184, %183 : index | |
%186 = arith.floordivsi %185, %164 : index | |
%187 = arith.cmpi slt, %186, %c0_110 : index | |
%188 = arith.select %187, %c0_110, %186 : index | |
%c1_124 = arith.constant 1 : index | |
%c0_125 = arith.constant 0 : index | |
%c8_126 = arith.constant 8 : index | |
%c1_127 = arith.constant 1 : index | |
%c4_128 = arith.constant 4 : index | |
%189 = arith.subi %c4_128, %c1_124 : index | |
%c0_129 = arith.constant 0 : index | |
%c8_130 = arith.constant 8 : index | |
%c1_131 = arith.constant 1 : index | |
%c4_132 = arith.constant 4 : index | |
%190 = tensor.empty() : tensor<8x4xf32> | |
%cst_133 = arith.constant 0.000000e+00 : f32 | |
%191 = linalg.fill ins(%cst_133 : f32) outs(%190 : tensor<8x4xf32>) -> tensor<8x4xf32> | |
%192 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8x4xf32>) outs(%191 : tensor<8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%395 = linalg.index 0 : index | |
%396 = linalg.index 1 : index | |
%397 = arith.subi %189, %396 : index | |
%extracted_283 = tensor.extract %collapsed[%395, %397] : tensor<8x4xf32> | |
linalg.yield %extracted_283 : f32 | |
} -> tensor<8x4xf32> | |
%c1_134 = arith.constant 1 : index | |
%c0_135 = arith.constant 0 : index | |
%193 = arith.cmpi slt, %164, %c0_135 : index | |
%194 = math.absi %164 : index | |
%195 = arith.muli %188, %194 : index | |
%dim_136 = tensor.dim %collapsed, %c1_134 : tensor<8x4xf32> | |
%196 = arith.subi %dim_136, %195 : index | |
%197 = arith.select %193, %196, %172 : index | |
%198 = arith.select %193, %192, %collapsed : tensor<8x4xf32> | |
%extracted_slice_137 = tensor.extract_slice %198[%c0_110, %197] [%c8_121, %188] [%c1_111, %194] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_138 = tensor.cast %extracted_slice_137 : tensor<?x?xf32> to tensor<8x2xf32> | |
%c0_139 = arith.constant 0 : index | |
%c1_140 = arith.constant 1 : index | |
%c-1_141 = arith.constant -1 : index | |
%c0_142 = arith.constant 0 : index | |
%c8_143 = arith.constant 8 : index | |
%c1_144 = arith.constant 1 : index | |
%c4_145 = arith.constant 4 : index | |
%199 = arith.index_cast %c1_i64 : i64 to index | |
%c4_i64_146 = arith.constant 4 : i64 | |
%200 = arith.addi %c2_i64, %c4_i64_146 : i64 | |
%c0_i64_147 = arith.constant 0 : i64 | |
%201 = arith.cmpi sge, %c2_i64, %c0_i64_147 : i64 | |
%202 = arith.select %201, %c2_i64, %200 : i64 | |
%c0_i64_148 = arith.constant 0 : i64 | |
%203 = arith.cmpi slt, %202, %c0_i64_148 : i64 | |
%204 = arith.select %203, %c0_i64_148, %202 : i64 | |
%205 = arith.cmpi sgt, %204, %c4_i64_146 : i64 | |
%206 = arith.select %205, %c4_i64_146, %204 : i64 | |
%207 = arith.index_cast %206 : i64 to index | |
%208 = arith.index_cast %c4_i64 : i64 to index | |
%209 = arith.cmpi slt, %208, %c0_139 : index | |
%210 = arith.addi %208, %c4_145 : index | |
%211 = arith.select %209, %210, %208 : index | |
%212 = arith.cmpi slt, %211, %c0_139 : index | |
%213 = arith.select %212, %c-1_141, %211 : index | |
%214 = arith.cmpi sgt, %213, %c4_145 : index | |
%215 = arith.select %214, %c4_145, %213 : index | |
%c0_149 = arith.constant 0 : index | |
%c8_150 = arith.constant 8 : index | |
%c1_151 = arith.constant 1 : index | |
%c4_152 = arith.constant 4 : index | |
%216 = arith.subi %215, %207 : index | |
%217 = arith.cmpi sge, %199, %c0_139 : index | |
%218 = arith.select %217, %c1_140, %c-1_141 : index | |
%219 = arith.addi %216, %199 : index | |
%220 = arith.subi %219, %218 : index | |
%221 = arith.floordivsi %220, %199 : index | |
%222 = arith.cmpi slt, %221, %c0_139 : index | |
%223 = arith.select %222, %c0_139, %221 : index | |
%c1_153 = arith.constant 1 : index | |
%c0_154 = arith.constant 0 : index | |
%c8_155 = arith.constant 8 : index | |
%c1_156 = arith.constant 1 : index | |
%c4_157 = arith.constant 4 : index | |
%224 = arith.subi %c4_157, %c1_153 : index | |
%c0_158 = arith.constant 0 : index | |
%c8_159 = arith.constant 8 : index | |
%c1_160 = arith.constant 1 : index | |
%c4_161 = arith.constant 4 : index | |
%225 = tensor.empty() : tensor<8x4xf32> | |
%cst_162 = arith.constant 0.000000e+00 : f32 | |
%226 = linalg.fill ins(%cst_162 : f32) outs(%225 : tensor<8x4xf32>) -> tensor<8x4xf32> | |
%227 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8x4xf32>) outs(%226 : tensor<8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%395 = linalg.index 0 : index | |
%396 = linalg.index 1 : index | |
%397 = arith.subi %224, %396 : index | |
%extracted_283 = tensor.extract %collapsed[%395, %397] : tensor<8x4xf32> | |
linalg.yield %extracted_283 : f32 | |
} -> tensor<8x4xf32> | |
%c1_163 = arith.constant 1 : index | |
%c0_164 = arith.constant 0 : index | |
%228 = arith.cmpi slt, %199, %c0_164 : index | |
%229 = math.absi %199 : index | |
%230 = arith.muli %223, %229 : index | |
%dim_165 = tensor.dim %collapsed, %c1_163 : tensor<8x4xf32> | |
%231 = arith.subi %dim_165, %230 : index | |
%232 = arith.select %228, %231, %207 : index | |
%233 = arith.select %228, %227, %collapsed : tensor<8x4xf32> | |
%extracted_slice_166 = tensor.extract_slice %233[%c0_139, %232] [%c8_150, %223] [%c1_140, %229] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_167 = tensor.cast %extracted_slice_166 : tensor<?x?xf32> to tensor<8x2xf32> | |
%c1_168 = arith.constant 1 : index | |
%c0_169 = arith.constant 0 : index | |
%c8_170 = arith.constant 8 : index | |
%c1_171 = arith.constant 1 : index | |
%c2_172 = arith.constant 2 : index | |
%c0_173 = arith.constant 0 : index | |
%c8_174 = arith.constant 8 : index | |
%c1_175 = arith.constant 1 : index | |
%c2_176 = arith.constant 2 : index | |
%234 = tensor.empty() : tensor<8x2xf32> | |
%235 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_167, %cast_138 : tensor<8x2xf32>, tensor<8x2xf32>) outs(%234 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_283: f32, %out: f32): | |
%395 = arith.sitofp %c1_i64 : i64 to f32 | |
%396 = arith.mulf %in_283, %395 : f32 | |
%397 = arith.subf %in, %396 : f32 | |
linalg.yield %397 : f32 | |
} -> tensor<8x2xf32> | |
%cast_177 = tensor.cast %235 : tensor<8x2xf32> to tensor<8x2xf32> | |
%cst_178 = arith.constant 1.000000e+00 : f32 | |
%c1_179 = arith.constant 1 : index | |
%c0_180 = arith.constant 0 : index | |
%dim_181 = tensor.dim %cast_177, %c0_180 : tensor<8x2xf32> | |
%c1_182 = arith.constant 1 : index | |
%dim_183 = tensor.dim %cast_177, %c1_182 : tensor<8x2xf32> | |
%236 = tensor.empty(%dim_181) : tensor<?xf32> | |
%237 = linalg.fill ins(%cst_178 : f32) outs(%236 : tensor<?xf32>) -> tensor<?xf32> | |
%238 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%cast_177 : tensor<8x2xf32>) outs(%237 : tensor<?xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%395 = arith.mulf %in, %out : f32 | |
linalg.yield %395 : f32 | |
} -> tensor<?xf32> | |
%cast_184 = tensor.cast %238 : tensor<?xf32> to tensor<8xf32> | |
%239 = tensor.empty() : tensor<8xi64> | |
%240 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%239 : tensor<8xi64>) { | |
^bb0(%out: i64): | |
%395 = linalg.index 0 : index | |
%396 = arith.index_cast %395 : index to i64 | |
linalg.yield %396 : i64 | |
} -> tensor<8xi64> | |
%241:2 = iree_linalg_ext.sort dimension(0) outs(%163, %240 : tensor<8xf32>, tensor<8xi64>) { | |
^bb0(%arg11: f32, %arg12: f32, %arg13: i64, %arg14: i64): | |
%395 = arith.cmpf oge, %arg11, %arg12 : f32 | |
iree_linalg_ext.yield %395 : i1 | |
} -> tensor<8xf32>, tensor<8xi64> | |
%242 = torch_c.from_builtin_tensor %241#1 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%243 = torch_c.to_builtin_tensor %242 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%244 = torch.prim.ListConstruct %int8 : (!torch.int) -> !torch.list<int> | |
%c8_i64 = arith.constant 8 : i64 | |
%c8_185 = arith.constant 8 : index | |
%245 = tensor.empty() : tensor<8xi64> | |
%cast_186 = tensor.cast %245 : tensor<8xi64> to tensor<8xi64> | |
%246 = torch_c.from_builtin_tensor %cast_186 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%247:3 = torch.prim.Loop %int8, %true, init(%5, %246, %int0) { | |
^bb0(%arg11: !torch.int, %arg12: !torch.vtensor<[8],si64>, %arg13: !torch.vtensor<[8],si64>, %arg14: !torch.int): | |
%395 = torch_c.to_i64 %arg11 | |
%396 = torch_c.to_i64 %arg14 | |
%397 = torch_c.to_builtin_tensor %arg13 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%398 = torch_c.to_builtin_tensor %arg12 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%399 = torch.aten.lt.int %arg11, %int0 : !torch.int, !torch.int -> !torch.bool | |
%400 = torch.aten.Int.bool %399 : !torch.bool -> !torch.int | |
%401 = torch.aten.mul.int %400, %int8 : !torch.int, !torch.int -> !torch.int | |
%402 = torch.aten.add.int %arg11, %401 : !torch.int, !torch.int -> !torch.int | |
%403 = torch_c.to_i64 %402 | |
%404 = torch.aten.add.int %402, %int1 : !torch.int, !torch.int -> !torch.int | |
%405 = torch_c.to_i64 %404 | |
%c0_283 = arith.constant 0 : index | |
%c1_284 = arith.constant 1 : index | |
%c-1_285 = arith.constant -1 : index | |
%c0_286 = arith.constant 0 : index | |
%c8_287 = arith.constant 8 : index | |
%406 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_288 = arith.constant 8 : i64 | |
%407 = arith.addi %403, %c8_i64_288 : i64 | |
%c0_i64_289 = arith.constant 0 : i64 | |
%408 = arith.cmpi sge, %403, %c0_i64_289 : i64 | |
%409 = arith.select %408, %403, %407 : i64 | |
%c0_i64_290 = arith.constant 0 : i64 | |
%410 = arith.cmpi slt, %409, %c0_i64_290 : i64 | |
%411 = arith.select %410, %c0_i64_290, %409 : i64 | |
%412 = arith.cmpi sgt, %411, %c8_i64_288 : i64 | |
%413 = arith.select %412, %c8_i64_288, %411 : i64 | |
%414 = arith.index_cast %413 : i64 to index | |
%415 = arith.index_cast %405 : i64 to index | |
%416 = arith.cmpi slt, %415, %c0_283 : index | |
%417 = arith.addi %415, %c8_287 : index | |
%418 = arith.select %416, %417, %415 : index | |
%419 = arith.cmpi slt, %418, %c0_283 : index | |
%420 = arith.select %419, %c-1_285, %418 : index | |
%421 = arith.cmpi sgt, %420, %c8_287 : index | |
%422 = arith.select %421, %c8_287, %420 : index | |
%c0_291 = arith.constant 0 : index | |
%c8_292 = arith.constant 8 : index | |
%423 = arith.subi %422, %414 : index | |
%424 = arith.cmpi sge, %406, %c0_283 : index | |
%425 = arith.select %424, %c1_284, %c-1_285 : index | |
%426 = arith.addi %423, %406 : index | |
%427 = arith.subi %426, %425 : index | |
%428 = arith.floordivsi %427, %406 : index | |
%429 = arith.cmpi slt, %428, %c0_283 : index | |
%430 = arith.select %429, %c0_283, %428 : index | |
%c1_293 = arith.constant 1 : index | |
%c0_294 = arith.constant 0 : index | |
%c8_295 = arith.constant 8 : index | |
%431 = arith.subi %c8_295, %c1_293 : index | |
%c0_296 = arith.constant 0 : index | |
%c8_297 = arith.constant 8 : index | |
%432 = tensor.empty() : tensor<8xi64> | |
%c0_i64_298 = arith.constant 0 : i64 | |
%433 = linalg.fill ins(%c0_i64_298 : i64) outs(%432 : tensor<8xi64>) -> tensor<8xi64> | |
%434 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%398 : tensor<8xi64>) outs(%433 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%446 = linalg.index 0 : index | |
%447 = arith.subi %431, %446 : index | |
%extracted_306 = tensor.extract %398[%447] : tensor<8xi64> | |
linalg.yield %extracted_306 : i64 | |
} -> tensor<8xi64> | |
%c0_299 = arith.constant 0 : index | |
%c0_300 = arith.constant 0 : index | |
%435 = arith.cmpi slt, %406, %c0_300 : index | |
%436 = math.absi %406 : index | |
%437 = arith.muli %430, %436 : index | |
%dim_301 = tensor.dim %398, %c0_299 : tensor<8xi64> | |
%438 = arith.subi %dim_301, %437 : index | |
%439 = arith.select %435, %438, %414 : index | |
%440 = arith.select %435, %434, %398 : tensor<8xi64> | |
%extracted_slice_302 = tensor.extract_slice %440[%439] [%430] [%436] : tensor<8xi64> to tensor<?xi64> | |
%cast_303 = tensor.cast %extracted_slice_302 : tensor<?xi64> to tensor<1xi64> | |
%441 = torch_c.from_builtin_tensor %cast_303 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%442 = torch_c.to_builtin_tensor %441 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%c0_304 = arith.constant 0 : index | |
%extracted_305 = tensor.extract %442[%c0_304] : tensor<1xi64> | |
%443 = torch_c.from_i64 %extracted_305 | |
%444 = torch.aten.Bool.int %443 : !torch.int -> !torch.bool | |
%445:3 = torch.prim.If %444 -> (!torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int) { | |
%446 = torch.aten.lt.int %arg11, %int0 : !torch.int, !torch.int -> !torch.bool | |
%447 = torch.aten.Int.bool %446 : !torch.bool -> !torch.int | |
%448 = torch.aten.mul.int %447, %int8 : !torch.int, !torch.int -> !torch.int | |
%449 = torch.aten.add.int %arg11, %448 : !torch.int, !torch.int -> !torch.int | |
%450 = torch_c.to_i64 %449 | |
%451 = torch.aten.add.int %449, %int1 : !torch.int, !torch.int -> !torch.int | |
%452 = torch_c.to_i64 %451 | |
%c0_306 = arith.constant 0 : index | |
%c1_307 = arith.constant 1 : index | |
%c-1_308 = arith.constant -1 : index | |
%c0_309 = arith.constant 0 : index | |
%c8_310 = arith.constant 8 : index | |
%453 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_311 = arith.constant 8 : i64 | |
%454 = arith.addi %450, %c8_i64_311 : i64 | |
%c0_i64_312 = arith.constant 0 : i64 | |
%455 = arith.cmpi sge, %450, %c0_i64_312 : i64 | |
%456 = arith.select %455, %450, %454 : i64 | |
%c0_i64_313 = arith.constant 0 : i64 | |
%457 = arith.cmpi slt, %456, %c0_i64_313 : i64 | |
%458 = arith.select %457, %c0_i64_313, %456 : i64 | |
%459 = arith.cmpi sgt, %458, %c8_i64_311 : i64 | |
%460 = arith.select %459, %c8_i64_311, %458 : i64 | |
%461 = arith.index_cast %460 : i64 to index | |
%462 = arith.index_cast %452 : i64 to index | |
%463 = arith.cmpi slt, %462, %c0_306 : index | |
%464 = arith.addi %462, %c8_310 : index | |
%465 = arith.select %463, %464, %462 : index | |
%466 = arith.cmpi slt, %465, %c0_306 : index | |
%467 = arith.select %466, %c-1_308, %465 : index | |
%468 = arith.cmpi sgt, %467, %c8_310 : index | |
%469 = arith.select %468, %c8_310, %467 : index | |
%c0_314 = arith.constant 0 : index | |
%c8_315 = arith.constant 8 : index | |
%470 = arith.subi %469, %461 : index | |
%471 = arith.cmpi sge, %453, %c0_306 : index | |
%472 = arith.select %471, %c1_307, %c-1_308 : index | |
%473 = arith.addi %470, %453 : index | |
%474 = arith.subi %473, %472 : index | |
%475 = arith.floordivsi %474, %453 : index | |
%476 = arith.cmpi slt, %475, %c0_306 : index | |
%477 = arith.select %476, %c0_306, %475 : index | |
%c1_316 = arith.constant 1 : index | |
%c0_317 = arith.constant 0 : index | |
%c8_318 = arith.constant 8 : index | |
%478 = arith.subi %c8_318, %c1_316 : index | |
%c0_319 = arith.constant 0 : index | |
%c8_320 = arith.constant 8 : index | |
%479 = tensor.empty() : tensor<8xi64> | |
%c0_i64_321 = arith.constant 0 : i64 | |
%480 = linalg.fill ins(%c0_i64_321 : i64) outs(%479 : tensor<8xi64>) -> tensor<8xi64> | |
%481 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%243 : tensor<8xi64>) outs(%480 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%682 = linalg.index 0 : index | |
%683 = arith.subi %478, %682 : index | |
%extracted_505 = tensor.extract %243[%683] : tensor<8xi64> | |
linalg.yield %extracted_505 : i64 | |
} -> tensor<8xi64> | |
%c0_322 = arith.constant 0 : index | |
%c0_323 = arith.constant 0 : index | |
%482 = arith.cmpi slt, %453, %c0_323 : index | |
%483 = math.absi %453 : index | |
%484 = arith.muli %477, %483 : index | |
%dim_324 = tensor.dim %243, %c0_322 : tensor<8xi64> | |
%485 = arith.subi %dim_324, %484 : index | |
%486 = arith.select %482, %485, %461 : index | |
%487 = arith.select %482, %481, %243 : tensor<8xi64> | |
%extracted_slice_325 = tensor.extract_slice %487[%486] [%477] [%483] : tensor<8xi64> to tensor<?xi64> | |
%cast_326 = tensor.cast %extracted_slice_325 : tensor<?xi64> to tensor<1xi64> | |
%488 = torch_c.from_builtin_tensor %cast_326 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%489 = torch_c.to_builtin_tensor %488 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%490 = torch.aten.add.int %arg14, %int1 : !torch.int, !torch.int -> !torch.int | |
%491 = torch_c.to_i64 %490 | |
%c0_327 = arith.constant 0 : index | |
%c1_328 = arith.constant 1 : index | |
%c-1_329 = arith.constant -1 : index | |
%c0_330 = arith.constant 0 : index | |
%c8_331 = arith.constant 8 : index | |
%492 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_332 = arith.constant 8 : i64 | |
%493 = arith.addi %396, %c8_i64_332 : i64 | |
%c0_i64_333 = arith.constant 0 : i64 | |
%494 = arith.cmpi sge, %396, %c0_i64_333 : i64 | |
%495 = arith.select %494, %396, %493 : i64 | |
%c0_i64_334 = arith.constant 0 : i64 | |
%496 = arith.cmpi slt, %495, %c0_i64_334 : i64 | |
%497 = arith.select %496, %c0_i64_334, %495 : i64 | |
%498 = arith.cmpi sgt, %497, %c8_i64_332 : i64 | |
%499 = arith.select %498, %c8_i64_332, %497 : i64 | |
%500 = arith.index_cast %499 : i64 to index | |
%501 = arith.index_cast %491 : i64 to index | |
%502 = arith.cmpi slt, %501, %c0_327 : index | |
%503 = arith.addi %501, %c8_331 : index | |
%504 = arith.select %502, %503, %501 : index | |
%505 = arith.cmpi slt, %504, %c0_327 : index | |
%506 = arith.select %505, %c-1_329, %504 : index | |
%507 = arith.cmpi sgt, %506, %c8_331 : index | |
%508 = arith.select %507, %c8_331, %506 : index | |
%c0_335 = arith.constant 0 : index | |
%c8_336 = arith.constant 8 : index | |
%509 = arith.subi %508, %500 : index | |
%510 = arith.cmpi sge, %492, %c0_327 : index | |
%511 = arith.select %510, %c1_328, %c-1_329 : index | |
%512 = arith.addi %509, %492 : index | |
%513 = arith.subi %512, %511 : index | |
%514 = arith.floordivsi %513, %492 : index | |
%515 = arith.cmpi slt, %514, %c0_327 : index | |
%516 = arith.select %515, %c0_327, %514 : index | |
%cast_337 = tensor.cast %cast_326 : tensor<1xi64> to tensor<?xi64> | |
%inserted_slice_338 = tensor.insert_slice %cast_337 into %397[%500] [%516] [%492] : tensor<?xi64> into tensor<8xi64> | |
%cast_339 = tensor.cast %inserted_slice_338 : tensor<8xi64> to tensor<8xi64> | |
%517 = torch_c.from_builtin_tensor %cast_339 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%c0_340 = arith.constant 0 : index | |
%extracted_341 = tensor.extract %489[%c0_340] : tensor<1xi64> | |
%518 = torch_c.from_i64 %extracted_341 | |
%519 = torch_c.to_i64 %518 | |
%520 = torch.aten.add.int %518, %int1 : !torch.int, !torch.int -> !torch.int | |
%521 = torch_c.to_i64 %520 | |
%c0_342 = arith.constant 0 : index | |
%c1_343 = arith.constant 1 : index | |
%c-1_344 = arith.constant -1 : index | |
%c0_345 = arith.constant 0 : index | |
%c8_346 = arith.constant 8 : index | |
%c1_347 = arith.constant 1 : index | |
%c4_348 = arith.constant 4 : index | |
%522 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_349 = arith.constant 8 : i64 | |
%523 = arith.addi %519, %c8_i64_349 : i64 | |
%c0_i64_350 = arith.constant 0 : i64 | |
%524 = arith.cmpi sge, %519, %c0_i64_350 : i64 | |
%525 = arith.select %524, %519, %523 : i64 | |
%c0_i64_351 = arith.constant 0 : i64 | |
%526 = arith.cmpi slt, %525, %c0_i64_351 : i64 | |
%527 = arith.select %526, %c0_i64_351, %525 : i64 | |
%528 = arith.cmpi sgt, %527, %c8_i64_349 : i64 | |
%529 = arith.select %528, %c8_i64_349, %527 : i64 | |
%530 = arith.index_cast %529 : i64 to index | |
%531 = arith.index_cast %521 : i64 to index | |
%532 = arith.cmpi slt, %531, %c0_342 : index | |
%533 = arith.addi %531, %c8_346 : index | |
%534 = arith.select %532, %533, %531 : index | |
%535 = arith.cmpi slt, %534, %c0_342 : index | |
%536 = arith.select %535, %c-1_344, %534 : index | |
%537 = arith.cmpi sgt, %536, %c8_346 : index | |
%538 = arith.select %537, %c8_346, %536 : index | |
%c0_352 = arith.constant 0 : index | |
%c8_353 = arith.constant 8 : index | |
%c1_354 = arith.constant 1 : index | |
%c4_355 = arith.constant 4 : index | |
%539 = arith.subi %538, %530 : index | |
%540 = arith.cmpi sge, %522, %c0_342 : index | |
%541 = arith.select %540, %c1_343, %c-1_344 : index | |
%542 = arith.addi %539, %522 : index | |
%543 = arith.subi %542, %541 : index | |
%544 = arith.floordivsi %543, %522 : index | |
%545 = arith.cmpi slt, %544, %c0_342 : index | |
%546 = arith.select %545, %c0_342, %544 : index | |
%c1_356 = arith.constant 1 : index | |
%c0_357 = arith.constant 0 : index | |
%c8_358 = arith.constant 8 : index | |
%c1_359 = arith.constant 1 : index | |
%c4_360 = arith.constant 4 : index | |
%547 = arith.subi %c8_358, %c1_356 : index | |
%c0_361 = arith.constant 0 : index | |
%c8_362 = arith.constant 8 : index | |
%c1_363 = arith.constant 1 : index | |
%c4_364 = arith.constant 4 : index | |
%548 = tensor.empty() : tensor<8x4xf32> | |
%cst_365 = arith.constant 0.000000e+00 : f32 | |
%549 = linalg.fill ins(%cst_365 : f32) outs(%548 : tensor<8x4xf32>) -> tensor<8x4xf32> | |
%550 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8x4xf32>) outs(%549 : tensor<8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%682 = linalg.index 0 : index | |
%683 = linalg.index 1 : index | |
%684 = arith.subi %547, %682 : index | |
%extracted_505 = tensor.extract %collapsed[%684, %683] : tensor<8x4xf32> | |
linalg.yield %extracted_505 : f32 | |
} -> tensor<8x4xf32> | |
%c0_366 = arith.constant 0 : index | |
%c0_367 = arith.constant 0 : index | |
%551 = arith.cmpi slt, %522, %c0_367 : index | |
%552 = math.absi %522 : index | |
%553 = arith.muli %546, %552 : index | |
%dim_368 = tensor.dim %collapsed, %c0_366 : tensor<8x4xf32> | |
%554 = arith.subi %dim_368, %553 : index | |
%555 = arith.select %551, %554, %530 : index | |
%556 = arith.select %551, %550, %collapsed : tensor<8x4xf32> | |
%extracted_slice_369 = tensor.extract_slice %556[%555, %c0_342] [%546, %c4_355] [%552, %c1_343] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_370 = tensor.cast %extracted_slice_369 : tensor<?x?xf32> to tensor<1x4xf32> | |
%c0_371 = arith.constant 0 : index | |
%c1_372 = arith.constant 1 : index | |
%c-1_373 = arith.constant -1 : index | |
%c0_374 = arith.constant 0 : index | |
%c1_375 = arith.constant 1 : index | |
%c1_376 = arith.constant 1 : index | |
%c4_377 = arith.constant 4 : index | |
%557 = arith.index_cast %c1_i64 : i64 to index | |
%c4_i64_378 = arith.constant 4 : i64 | |
%558 = arith.addi %c0_i64, %c4_i64_378 : i64 | |
%c0_i64_379 = arith.constant 0 : i64 | |
%559 = arith.cmpi sge, %c0_i64, %c0_i64_379 : i64 | |
%560 = arith.select %559, %c0_i64, %558 : i64 | |
%c0_i64_380 = arith.constant 0 : i64 | |
%561 = arith.cmpi slt, %560, %c0_i64_380 : i64 | |
%562 = arith.select %561, %c0_i64_380, %560 : i64 | |
%563 = arith.cmpi sgt, %562, %c4_i64_378 : i64 | |
%564 = arith.select %563, %c4_i64_378, %562 : i64 | |
%565 = arith.index_cast %564 : i64 to index | |
%566 = arith.index_cast %c2_i64 : i64 to index | |
%567 = arith.cmpi slt, %566, %c0_371 : index | |
%568 = arith.addi %566, %c4_377 : index | |
%569 = arith.select %567, %568, %566 : index | |
%570 = arith.cmpi slt, %569, %c0_371 : index | |
%571 = arith.select %570, %c-1_373, %569 : index | |
%572 = arith.cmpi sgt, %571, %c4_377 : index | |
%573 = arith.select %572, %c4_377, %571 : index | |
%c0_381 = arith.constant 0 : index | |
%c1_382 = arith.constant 1 : index | |
%c1_383 = arith.constant 1 : index | |
%c4_384 = arith.constant 4 : index | |
%574 = arith.subi %573, %565 : index | |
%575 = arith.cmpi sge, %557, %c0_371 : index | |
%576 = arith.select %575, %c1_372, %c-1_373 : index | |
%577 = arith.addi %574, %557 : index | |
%578 = arith.subi %577, %576 : index | |
%579 = arith.floordivsi %578, %557 : index | |
%580 = arith.cmpi slt, %579, %c0_371 : index | |
%581 = arith.select %580, %c0_371, %579 : index | |
%c1_385 = arith.constant 1 : index | |
%c0_386 = arith.constant 0 : index | |
%c1_387 = arith.constant 1 : index | |
%c1_388 = arith.constant 1 : index | |
%c4_389 = arith.constant 4 : index | |
%582 = arith.subi %c4_389, %c1_385 : index | |
%c0_390 = arith.constant 0 : index | |
%c1_391 = arith.constant 1 : index | |
%c1_392 = arith.constant 1 : index | |
%c4_393 = arith.constant 4 : index | |
%583 = tensor.empty() : tensor<1x4xf32> | |
%cst_394 = arith.constant 0.000000e+00 : f32 | |
%584 = linalg.fill ins(%cst_394 : f32) outs(%583 : tensor<1x4xf32>) -> tensor<1x4xf32> | |
%585 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_370 : tensor<1x4xf32>) outs(%584 : tensor<1x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%682 = linalg.index 0 : index | |
%683 = linalg.index 1 : index | |
%684 = arith.subi %582, %683 : index | |
%extracted_505 = tensor.extract %cast_370[%682, %684] : tensor<1x4xf32> | |
linalg.yield %extracted_505 : f32 | |
} -> tensor<1x4xf32> | |
%c1_395 = arith.constant 1 : index | |
%c0_396 = arith.constant 0 : index | |
%586 = arith.cmpi slt, %557, %c0_396 : index | |
%587 = math.absi %557 : index | |
%588 = arith.muli %581, %587 : index | |
%dim_397 = tensor.dim %cast_370, %c1_395 : tensor<1x4xf32> | |
%589 = arith.subi %dim_397, %588 : index | |
%590 = arith.select %586, %589, %565 : index | |
%591 = arith.select %586, %585, %cast_370 : tensor<1x4xf32> | |
%extracted_slice_398 = tensor.extract_slice %591[%c0_371, %590] [%c1_382, %581] [%c1_372, %587] : tensor<1x4xf32> to tensor<?x?xf32> | |
%cast_399 = tensor.cast %extracted_slice_398 : tensor<?x?xf32> to tensor<1x2xf32> | |
%c0_400 = arith.constant 0 : index | |
%c1_401 = arith.constant 1 : index | |
%c-1_402 = arith.constant -1 : index | |
%c0_403 = arith.constant 0 : index | |
%c1_404 = arith.constant 1 : index | |
%c1_405 = arith.constant 1 : index | |
%c4_406 = arith.constant 4 : index | |
%592 = arith.index_cast %c1_i64 : i64 to index | |
%c4_i64_407 = arith.constant 4 : i64 | |
%593 = arith.addi %c2_i64, %c4_i64_407 : i64 | |
%c0_i64_408 = arith.constant 0 : i64 | |
%594 = arith.cmpi sge, %c2_i64, %c0_i64_408 : i64 | |
%595 = arith.select %594, %c2_i64, %593 : i64 | |
%c0_i64_409 = arith.constant 0 : i64 | |
%596 = arith.cmpi slt, %595, %c0_i64_409 : i64 | |
%597 = arith.select %596, %c0_i64_409, %595 : i64 | |
%598 = arith.cmpi sgt, %597, %c4_i64_407 : i64 | |
%599 = arith.select %598, %c4_i64_407, %597 : i64 | |
%600 = arith.index_cast %599 : i64 to index | |
%601 = arith.index_cast %c4_i64 : i64 to index | |
%602 = arith.cmpi slt, %601, %c0_400 : index | |
%603 = arith.addi %601, %c4_406 : index | |
%604 = arith.select %602, %603, %601 : index | |
%605 = arith.cmpi slt, %604, %c0_400 : index | |
%606 = arith.select %605, %c-1_402, %604 : index | |
%607 = arith.cmpi sgt, %606, %c4_406 : index | |
%608 = arith.select %607, %c4_406, %606 : index | |
%c0_410 = arith.constant 0 : index | |
%c1_411 = arith.constant 1 : index | |
%c1_412 = arith.constant 1 : index | |
%c4_413 = arith.constant 4 : index | |
%609 = arith.subi %608, %600 : index | |
%610 = arith.cmpi sge, %592, %c0_400 : index | |
%611 = arith.select %610, %c1_401, %c-1_402 : index | |
%612 = arith.addi %609, %592 : index | |
%613 = arith.subi %612, %611 : index | |
%614 = arith.floordivsi %613, %592 : index | |
%615 = arith.cmpi slt, %614, %c0_400 : index | |
%616 = arith.select %615, %c0_400, %614 : index | |
%c1_414 = arith.constant 1 : index | |
%c0_415 = arith.constant 0 : index | |
%c1_416 = arith.constant 1 : index | |
%c1_417 = arith.constant 1 : index | |
%c4_418 = arith.constant 4 : index | |
%617 = arith.subi %c4_418, %c1_414 : index | |
%c0_419 = arith.constant 0 : index | |
%c1_420 = arith.constant 1 : index | |
%c1_421 = arith.constant 1 : index | |
%c4_422 = arith.constant 4 : index | |
%618 = tensor.empty() : tensor<1x4xf32> | |
%cst_423 = arith.constant 0.000000e+00 : f32 | |
%619 = linalg.fill ins(%cst_423 : f32) outs(%618 : tensor<1x4xf32>) -> tensor<1x4xf32> | |
%620 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_370 : tensor<1x4xf32>) outs(%619 : tensor<1x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%682 = linalg.index 0 : index | |
%683 = linalg.index 1 : index | |
%684 = arith.subi %617, %683 : index | |
%extracted_505 = tensor.extract %cast_370[%682, %684] : tensor<1x4xf32> | |
linalg.yield %extracted_505 : f32 | |
} -> tensor<1x4xf32> | |
%c1_424 = arith.constant 1 : index | |
%c0_425 = arith.constant 0 : index | |
%621 = arith.cmpi slt, %592, %c0_425 : index | |
%622 = math.absi %592 : index | |
%623 = arith.muli %616, %622 : index | |
%dim_426 = tensor.dim %cast_370, %c1_424 : tensor<1x4xf32> | |
%624 = arith.subi %dim_426, %623 : index | |
%625 = arith.select %621, %624, %600 : index | |
%626 = arith.select %621, %620, %cast_370 : tensor<1x4xf32> | |
%extracted_slice_427 = tensor.extract_slice %626[%c0_400, %625] [%c1_411, %616] [%c1_401, %622] : tensor<1x4xf32> to tensor<?x?xf32> | |
%cast_428 = tensor.cast %extracted_slice_427 : tensor<?x?xf32> to tensor<1x2xf32> | |
%c1_429 = arith.constant 1 : index | |
%c0_430 = arith.constant 0 : index | |
%c8_431 = arith.constant 8 : index | |
%c1_432 = arith.constant 1 : index | |
%c2_433 = arith.constant 2 : index | |
%c1_434 = arith.constant 1 : index | |
%c2_435 = arith.constant 2 : index | |
%627 = tensor.empty() : tensor<8x2xf32> | |
%628 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_138, %cast_399 : tensor<8x2xf32>, tensor<1x2xf32>) outs(%627 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.cmpf ogt, %in, %in_505 : f32 | |
%683 = arith.select %682, %in, %in_505 : f32 | |
linalg.yield %683 : f32 | |
} -> tensor<8x2xf32> | |
%cast_436 = tensor.cast %628 : tensor<8x2xf32> to tensor<8x2xf32> | |
%c1_437 = arith.constant 1 : index | |
%c0_438 = arith.constant 0 : index | |
%c8_439 = arith.constant 8 : index | |
%c1_440 = arith.constant 1 : index | |
%c2_441 = arith.constant 2 : index | |
%c1_442 = arith.constant 1 : index | |
%c2_443 = arith.constant 2 : index | |
%629 = tensor.empty() : tensor<8x2xf32> | |
%630 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_167, %cast_428 : tensor<8x2xf32>, tensor<1x2xf32>) outs(%629 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.cmpf olt, %in, %in_505 : f32 | |
%683 = arith.select %682, %in, %in_505 : f32 | |
linalg.yield %683 : f32 | |
} -> tensor<8x2xf32> | |
%cast_444 = tensor.cast %630 : tensor<8x2xf32> to tensor<8x2xf32> | |
%c1_445 = arith.constant 1 : index | |
%c0_446 = arith.constant 0 : index | |
%c8_447 = arith.constant 8 : index | |
%c1_448 = arith.constant 1 : index | |
%c2_449 = arith.constant 2 : index | |
%c0_450 = arith.constant 0 : index | |
%c8_451 = arith.constant 8 : index | |
%c1_452 = arith.constant 1 : index | |
%c2_453 = arith.constant 2 : index | |
%631 = tensor.empty() : tensor<8x2xf32> | |
%632 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_444, %cast_436 : tensor<8x2xf32>, tensor<8x2xf32>) outs(%631 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.sitofp %c1_i64 : i64 to f32 | |
%683 = arith.mulf %in_505, %682 : f32 | |
%684 = arith.subf %in, %683 : f32 | |
linalg.yield %684 : f32 | |
} -> tensor<8x2xf32> | |
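// Intersection extents per box: upper corner minus lower corner; the sitofp/mulf by %c1_i64 is the lowered alpha=1 scaling of aten.sub.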
%cast_454 = tensor.cast %632 : tensor<8x2xf32> to tensor<8x2xf32> | |
%c1_455 = arith.constant 1 : index | |
%c0_456 = arith.constant 0 : index | |
%c8_457 = arith.constant 8 : index | |
%c1_458 = arith.constant 1 : index | |
%c2_459 = arith.constant 2 : index | |
%633 = tensor.empty() : tensor<8x2xf32> | |
%634 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_454, %7 : tensor<8x2xf32>, tensor<1xf32>) outs(%633 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.cmpf ogt, %in, %in_505 : f32 | |
%683 = arith.select %682, %in, %in_505 : f32 | |
linalg.yield %683 : f32 | |
} -> tensor<8x2xf32> | |
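// Clamp the extents at the value held in %7 (likely the 0.0 literal), so non-overlapping boxes contribute no area.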
%cast_460 = tensor.cast %634 : tensor<8x2xf32> to tensor<8x2xf32> | |
%cst_461 = arith.constant 1.000000e+00 : f32 | |
%c1_462 = arith.constant 1 : index | |
%c0_463 = arith.constant 0 : index | |
%dim_464 = tensor.dim %cast_460, %c0_463 : tensor<8x2xf32> | |
%c1_465 = arith.constant 1 : index | |
%dim_466 = tensor.dim %cast_460, %c1_465 : tensor<8x2xf32> | |
%635 = tensor.empty(%dim_464) : tensor<?xf32> | |
%636 = linalg.fill ins(%cst_461 : f32) outs(%635 : tensor<?xf32>) -> tensor<?xf32> | |
%637 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%cast_460 : tensor<8x2xf32>) outs(%636 : tensor<?xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%682 = arith.mulf %in, %out : f32 | |
linalg.yield %682 : f32 | |
} -> tensor<?xf32> | |
%cast_467 = tensor.cast %637 : tensor<?xf32> to tensor<8xf32> | |
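// Product reduction over dim 1 (width x height): intersection area of each of the 8 boxes with the selected box.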
%638 = torch.aten.add.int %arg11, %int1 : !torch.int, !torch.int -> !torch.int | |
%639 = torch_c.to_i64 %638 | |
%c0_468 = arith.constant 0 : index | |
%c1_469 = arith.constant 1 : index | |
%c-1_470 = arith.constant -1 : index | |
%c0_471 = arith.constant 0 : index | |
%c8_472 = arith.constant 8 : index | |
%640 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_473 = arith.constant 8 : i64 | |
%641 = arith.addi %395, %c8_i64_473 : i64 | |
%c0_i64_474 = arith.constant 0 : i64 | |
%642 = arith.cmpi sge, %395, %c0_i64_474 : i64 | |
%643 = arith.select %642, %395, %641 : i64 | |
%c0_i64_475 = arith.constant 0 : i64 | |
%644 = arith.cmpi slt, %643, %c0_i64_475 : i64 | |
%645 = arith.select %644, %c0_i64_475, %643 : i64 | |
%646 = arith.cmpi sgt, %645, %c8_i64_473 : i64 | |
%647 = arith.select %646, %c8_i64_473, %645 : i64 | |
%648 = arith.index_cast %647 : i64 to index | |
%649 = arith.index_cast %639 : i64 to index | |
%650 = arith.cmpi slt, %649, %c0_468 : index | |
%651 = arith.addi %649, %c8_472 : index | |
%652 = arith.select %650, %651, %649 : index | |
%653 = arith.cmpi slt, %652, %c0_468 : index | |
%654 = arith.select %653, %c-1_470, %652 : index | |
%655 = arith.cmpi sgt, %654, %c8_472 : index | |
%656 = arith.select %655, %c8_472, %654 : index | |
%c0_476 = arith.constant 0 : index | |
%c8_477 = arith.constant 8 : index | |
%657 = arith.subi %656, %648 : index | |
%658 = arith.cmpi sge, %640, %c0_468 : index | |
%659 = arith.select %658, %c1_469, %c-1_470 : index | |
%660 = arith.addi %657, %640 : index | |
%661 = arith.subi %660, %659 : index | |
%662 = arith.floordivsi %661, %640 : index | |
%663 = arith.cmpi slt, %662, %c0_468 : index | |
%664 = arith.select %663, %c0_468, %662 : index | |
%c1_478 = arith.constant 1 : index | |
%c0_479 = arith.constant 0 : index | |
%c8_480 = arith.constant 8 : index | |
%665 = arith.subi %c8_480, %c1_478 : index | |
%c0_481 = arith.constant 0 : index | |
%c8_482 = arith.constant 8 : index | |
%666 = tensor.empty() : tensor<8xf32> | |
%cst_483 = arith.constant 0.000000e+00 : f32 | |
%667 = linalg.fill ins(%cst_483 : f32) outs(%666 : tensor<8xf32>) -> tensor<8xf32> | |
%668 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_184 : tensor<8xf32>) outs(%667 : tensor<8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%682 = linalg.index 0 : index | |
%683 = arith.subi %665, %682 : index | |
%extracted_505 = tensor.extract %cast_184[%683] : tensor<8xf32> | |
linalg.yield %extracted_505 : f32 | |
} -> tensor<8xf32> | |
%c0_484 = arith.constant 0 : index | |
%c0_485 = arith.constant 0 : index | |
%669 = arith.cmpi slt, %640, %c0_485 : index | |
%670 = math.absi %640 : index | |
%671 = arith.muli %664, %670 : index | |
%dim_486 = tensor.dim %cast_184, %c0_484 : tensor<8xf32> | |
%672 = arith.subi %dim_486, %671 : index | |
%673 = arith.select %669, %672, %648 : index | |
%674 = arith.select %669, %668, %cast_184 : tensor<8xf32> | |
%extracted_slice_487 = tensor.extract_slice %674[%673] [%664] [%670] : tensor<8xf32> to tensor<?xf32> | |
%cast_488 = tensor.cast %extracted_slice_487 : tensor<?xf32> to tensor<1xf32> | |
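// Strided-slice machinery extracting a single element of %cast_184, likely the precomputed area of the box selected in this iteration.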
%c1_489 = arith.constant 1 : index | |
%c0_490 = arith.constant 0 : index | |
%c8_491 = arith.constant 8 : index | |
%675 = tensor.empty() : tensor<8xf32> | |
%676 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_184, %cast_488 : tensor<8xf32>, tensor<1xf32>) outs(%675 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.sitofp %c1_i64 : i64 to f32 | |
%683 = arith.mulf %in_505, %682 : f32 | |
%684 = arith.addf %in, %683 : f32 | |
linalg.yield %684 : f32 | |
} -> tensor<8xf32> | |
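// Add that selected area to every entry of %cast_184: per-box area sums for the IoU denominator.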
%cast_492 = tensor.cast %676 : tensor<8xf32> to tensor<8xf32> | |
%c1_493 = arith.constant 1 : index | |
%c0_494 = arith.constant 0 : index | |
%c8_495 = arith.constant 8 : index | |
%c0_496 = arith.constant 0 : index | |
%c8_497 = arith.constant 8 : index | |
%677 = tensor.empty() : tensor<8xf32> | |
%678 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_492, %cast_467 : tensor<8xf32>, tensor<8xf32>) outs(%677 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.sitofp %c1_i64 : i64 to f32 | |
%683 = arith.mulf %in_505, %682 : f32 | |
%684 = arith.subf %in, %683 : f32 | |
linalg.yield %684 : f32 | |
} -> tensor<8xf32> | |
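// Union area per box: (area_i + area_selected) minus the intersection computed above.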
%cast_498 = tensor.cast %678 : tensor<8xf32> to tensor<8xf32> | |
%c1_499 = arith.constant 1 : index | |
%c0_500 = arith.constant 0 : index | |
%c8_501 = arith.constant 8 : index | |
%c0_502 = arith.constant 0 : index | |
%c8_503 = arith.constant 8 : index | |
%679 = tensor.empty() : tensor<8xf32> | |
%680 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_467, %cast_498 : tensor<8xf32>, tensor<8xf32>) outs(%679 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_505: f32, %out: f32): | |
%682 = arith.divf %in, %in_505 : f32 | |
linalg.yield %682 : f32 | |
} -> tensor<8xf32> | |
%cast_504 = tensor.cast %680 : tensor<8xf32> to tensor<8xf32> | |
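// IoU for all 8 candidates: intersection / union. The prim.Loop below visits the boxes ranked after the current one and, when IoU exceeds the threshold (%21), likely overwrites their entry in the keep mask %arg16 with the zero literal %9, suppressing them.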
%681 = torch.prim.Loop %int8, %true, init(%arg12) { | |
^bb0(%arg15: !torch.int, %arg16: !torch.vtensor<[8],si64>): | |
%682 = torch_c.to_builtin_tensor %arg16 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%683 = torch.aten.add.int %arg15, %arg11 : !torch.int, !torch.int -> !torch.int | |
%684 = torch.aten.add.int %683, %int1 : !torch.int, !torch.int -> !torch.int | |
%685 = torch_c.to_i64 %684 | |
%686 = torch.aten.lt.int %684, %int8 : !torch.int, !torch.int -> !torch.bool | |
%687 = torch.prim.If %686 -> (!torch.vtensor<[8],si64>) { | |
%688 = torch.aten.lt.int %684, %int0 : !torch.int, !torch.int -> !torch.bool | |
%689 = torch.aten.Int.bool %688 : !torch.bool -> !torch.int | |
%690 = torch.aten.mul.int %689, %int8 : !torch.int, !torch.int -> !torch.int | |
%691 = torch.aten.add.int %684, %690 : !torch.int, !torch.int -> !torch.int | |
%692 = torch_c.to_i64 %691 | |
%693 = torch.aten.add.int %691, %int1 : !torch.int, !torch.int -> !torch.int | |
%694 = torch_c.to_i64 %693 | |
%c0_505 = arith.constant 0 : index | |
%c1_506 = arith.constant 1 : index | |
%c-1_507 = arith.constant -1 : index | |
%c0_508 = arith.constant 0 : index | |
%c8_509 = arith.constant 8 : index | |
%695 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_510 = arith.constant 8 : i64 | |
%696 = arith.addi %692, %c8_i64_510 : i64 | |
%c0_i64_511 = arith.constant 0 : i64 | |
%697 = arith.cmpi sge, %692, %c0_i64_511 : i64 | |
%698 = arith.select %697, %692, %696 : i64 | |
%c0_i64_512 = arith.constant 0 : i64 | |
%699 = arith.cmpi slt, %698, %c0_i64_512 : i64 | |
%700 = arith.select %699, %c0_i64_512, %698 : i64 | |
%701 = arith.cmpi sgt, %700, %c8_i64_510 : i64 | |
%702 = arith.select %701, %c8_i64_510, %700 : i64 | |
%703 = arith.index_cast %702 : i64 to index | |
%704 = arith.index_cast %694 : i64 to index | |
%705 = arith.cmpi slt, %704, %c0_505 : index | |
%706 = arith.addi %704, %c8_509 : index | |
%707 = arith.select %705, %706, %704 : index | |
%708 = arith.cmpi slt, %707, %c0_505 : index | |
%709 = arith.select %708, %c-1_507, %707 : index | |
%710 = arith.cmpi sgt, %709, %c8_509 : index | |
%711 = arith.select %710, %c8_509, %709 : index | |
%c0_513 = arith.constant 0 : index | |
%c8_514 = arith.constant 8 : index | |
%712 = arith.subi %711, %703 : index | |
%713 = arith.cmpi sge, %695, %c0_505 : index | |
%714 = arith.select %713, %c1_506, %c-1_507 : index | |
%715 = arith.addi %712, %695 : index | |
%716 = arith.subi %715, %714 : index | |
%717 = arith.floordivsi %716, %695 : index | |
%718 = arith.cmpi slt, %717, %c0_505 : index | |
%719 = arith.select %718, %c0_505, %717 : index | |
%c1_515 = arith.constant 1 : index | |
%c0_516 = arith.constant 0 : index | |
%c8_517 = arith.constant 8 : index | |
%720 = arith.subi %c8_517, %c1_515 : index | |
%c0_518 = arith.constant 0 : index | |
%c8_519 = arith.constant 8 : index | |
%721 = tensor.empty() : tensor<8xi64> | |
%c0_i64_520 = arith.constant 0 : i64 | |
%722 = linalg.fill ins(%c0_i64_520 : i64) outs(%721 : tensor<8xi64>) -> tensor<8xi64> | |
%723 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%243 : tensor<8xi64>) outs(%722 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%777 = linalg.index 0 : index | |
%778 = arith.subi %720, %777 : index | |
%extracted_551 = tensor.extract %243[%778] : tensor<8xi64> | |
linalg.yield %extracted_551 : i64 | |
} -> tensor<8xi64> | |
%c0_521 = arith.constant 0 : index | |
%c0_522 = arith.constant 0 : index | |
%724 = arith.cmpi slt, %695, %c0_522 : index | |
%725 = math.absi %695 : index | |
%726 = arith.muli %719, %725 : index | |
%dim_523 = tensor.dim %243, %c0_521 : tensor<8xi64> | |
%727 = arith.subi %dim_523, %726 : index | |
%728 = arith.select %724, %727, %703 : index | |
%729 = arith.select %724, %723, %243 : tensor<8xi64> | |
%extracted_slice_524 = tensor.extract_slice %729[%728] [%719] [%725] : tensor<8xi64> to tensor<?xi64> | |
%cast_525 = tensor.cast %extracted_slice_524 : tensor<?xi64> to tensor<1xi64> | |
%730 = torch_c.from_builtin_tensor %cast_525 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%731 = torch_c.to_builtin_tensor %730 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%c0_526 = arith.constant 0 : index | |
%extracted_527 = tensor.extract %731[%c0_526] : tensor<1xi64> | |
%732 = torch_c.from_i64 %extracted_527 | |
%733 = torch_c.to_i64 %732 | |
%734 = torch.aten.add.int %732, %int1 : !torch.int, !torch.int -> !torch.int | |
%735 = torch_c.to_i64 %734 | |
%c0_528 = arith.constant 0 : index | |
%c1_529 = arith.constant 1 : index | |
%c-1_530 = arith.constant -1 : index | |
%c0_531 = arith.constant 0 : index | |
%c8_532 = arith.constant 8 : index | |
%736 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_533 = arith.constant 8 : i64 | |
%737 = arith.addi %733, %c8_i64_533 : i64 | |
%c0_i64_534 = arith.constant 0 : i64 | |
%738 = arith.cmpi sge, %733, %c0_i64_534 : i64 | |
%739 = arith.select %738, %733, %737 : i64 | |
%c0_i64_535 = arith.constant 0 : i64 | |
%740 = arith.cmpi slt, %739, %c0_i64_535 : i64 | |
%741 = arith.select %740, %c0_i64_535, %739 : i64 | |
%742 = arith.cmpi sgt, %741, %c8_i64_533 : i64 | |
%743 = arith.select %742, %c8_i64_533, %741 : i64 | |
%744 = arith.index_cast %743 : i64 to index | |
%745 = arith.index_cast %735 : i64 to index | |
%746 = arith.cmpi slt, %745, %c0_528 : index | |
%747 = arith.addi %745, %c8_532 : index | |
%748 = arith.select %746, %747, %745 : index | |
%749 = arith.cmpi slt, %748, %c0_528 : index | |
%750 = arith.select %749, %c-1_530, %748 : index | |
%751 = arith.cmpi sgt, %750, %c8_532 : index | |
%752 = arith.select %751, %c8_532, %750 : index | |
%c0_536 = arith.constant 0 : index | |
%c8_537 = arith.constant 8 : index | |
%753 = arith.subi %752, %744 : index | |
%754 = arith.cmpi sge, %736, %c0_528 : index | |
%755 = arith.select %754, %c1_529, %c-1_530 : index | |
%756 = arith.addi %753, %736 : index | |
%757 = arith.subi %756, %755 : index | |
%758 = arith.floordivsi %757, %736 : index | |
%759 = arith.cmpi slt, %758, %c0_528 : index | |
%760 = arith.select %759, %c0_528, %758 : index | |
%c1_538 = arith.constant 1 : index | |
%c0_539 = arith.constant 0 : index | |
%c8_540 = arith.constant 8 : index | |
%761 = arith.subi %c8_540, %c1_538 : index | |
%c0_541 = arith.constant 0 : index | |
%c8_542 = arith.constant 8 : index | |
%762 = tensor.empty() : tensor<8xf32> | |
%cst_543 = arith.constant 0.000000e+00 : f32 | |
%763 = linalg.fill ins(%cst_543 : f32) outs(%762 : tensor<8xf32>) -> tensor<8xf32> | |
%764 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_504 : tensor<8xf32>) outs(%763 : tensor<8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%777 = linalg.index 0 : index | |
%778 = arith.subi %761, %777 : index | |
%extracted_551 = tensor.extract %cast_504[%778] : tensor<8xf32> | |
linalg.yield %extracted_551 : f32 | |
} -> tensor<8xf32> | |
%c0_544 = arith.constant 0 : index | |
%c0_545 = arith.constant 0 : index | |
%765 = arith.cmpi slt, %736, %c0_545 : index | |
%766 = math.absi %736 : index | |
%767 = arith.muli %760, %766 : index | |
%dim_546 = tensor.dim %cast_504, %c0_544 : tensor<8xf32> | |
%768 = arith.subi %dim_546, %767 : index | |
%769 = arith.select %765, %768, %744 : index | |
%770 = arith.select %765, %764, %cast_504 : tensor<8xf32> | |
%extracted_slice_547 = tensor.extract_slice %770[%769] [%760] [%766] : tensor<8xf32> to tensor<?xf32> | |
%cast_548 = tensor.cast %extracted_slice_547 : tensor<?xf32> to tensor<1xf32> | |
%771 = torch_c.from_builtin_tensor %cast_548 : tensor<1xf32> -> !torch.vtensor<[1],f32> | |
%772 = torch_c.to_builtin_tensor %771 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%c0_549 = arith.constant 0 : index | |
%extracted_550 = tensor.extract %772[%c0_549] : tensor<1xf32> | |
%773 = arith.extf %extracted_550 : f32 to f64 | |
%774 = torch_c.from_f64 %773 | |
%775 = torch.aten.gt.float %774, %21 : !torch.float, !torch.float -> !torch.bool | |
%776 = torch.prim.If %775 -> (!torch.vtensor<[8],si64>) { | |
%777 = torch.aten.add.int %684, %int1 : !torch.int, !torch.int -> !torch.int | |
%778 = torch_c.to_i64 %777 | |
%c0_551 = arith.constant 0 : index | |
%c1_552 = arith.constant 1 : index | |
%c-1_553 = arith.constant -1 : index | |
%c0_554 = arith.constant 0 : index | |
%c8_555 = arith.constant 8 : index | |
%779 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_556 = arith.constant 8 : i64 | |
%780 = arith.addi %685, %c8_i64_556 : i64 | |
%c0_i64_557 = arith.constant 0 : i64 | |
%781 = arith.cmpi sge, %685, %c0_i64_557 : i64 | |
%782 = arith.select %781, %685, %780 : i64 | |
%c0_i64_558 = arith.constant 0 : i64 | |
%783 = arith.cmpi slt, %782, %c0_i64_558 : i64 | |
%784 = arith.select %783, %c0_i64_558, %782 : i64 | |
%785 = arith.cmpi sgt, %784, %c8_i64_556 : i64 | |
%786 = arith.select %785, %c8_i64_556, %784 : i64 | |
%787 = arith.index_cast %786 : i64 to index | |
%788 = arith.index_cast %778 : i64 to index | |
%789 = arith.cmpi slt, %788, %c0_551 : index | |
%790 = arith.addi %788, %c8_555 : index | |
%791 = arith.select %789, %790, %788 : index | |
%792 = arith.cmpi slt, %791, %c0_551 : index | |
%793 = arith.select %792, %c-1_553, %791 : index | |
%794 = arith.cmpi sgt, %793, %c8_555 : index | |
%795 = arith.select %794, %c8_555, %793 : index | |
%c0_559 = arith.constant 0 : index | |
%c8_560 = arith.constant 8 : index | |
%796 = arith.subi %795, %787 : index | |
%797 = arith.cmpi sge, %779, %c0_551 : index | |
%798 = arith.select %797, %c1_552, %c-1_553 : index | |
%799 = arith.addi %796, %779 : index | |
%800 = arith.subi %799, %798 : index | |
%801 = arith.floordivsi %800, %779 : index | |
%802 = arith.cmpi slt, %801, %c0_551 : index | |
%803 = arith.select %802, %c0_551, %801 : index | |
%cast_561 = tensor.cast %9 : tensor<1xi64> to tensor<?xi64> | |
%inserted_slice_562 = tensor.insert_slice %cast_561 into %682[%787] [%803] [%779] : tensor<?xi64> into tensor<8xi64> | |
%cast_563 = tensor.cast %inserted_slice_562 : tensor<8xi64> to tensor<8xi64> | |
%804 = torch_c.from_builtin_tensor %cast_563 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
torch.prim.If.yield %804 : !torch.vtensor<[8],si64> | |
} else { | |
torch.prim.If.yield %arg16 : !torch.vtensor<[8],si64> | |
} | |
torch.prim.If.yield %776 : !torch.vtensor<[8],si64> | |
} else { | |
torch.prim.If.yield %arg16 : !torch.vtensor<[8],si64> | |
} | |
torch.prim.Loop.condition %true, iter(%687 : !torch.vtensor<[8],si64>) | |
} : (!torch.int, !torch.bool, !torch.vtensor<[8],si64>) -> !torch.vtensor<[8],si64> | |
torch.prim.If.yield %681, %517, %490 : !torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int | |
} else { | |
torch.prim.If.yield %arg12, %arg13, %arg14 : !torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int | |
} | |
torch.prim.Loop.condition %true, iter(%445#0, %445#1, %445#2 : !torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int) | |
} : (!torch.int, !torch.bool, !torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int) -> (!torch.vtensor<[8],si64>, !torch.vtensor<[8],si64>, !torch.int) | |
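// %247 appears to carry, for the current (batch, class) pair, the final keep mask, the kept box indices in score order, and the number of boxes kept.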
%248 = torch_c.to_i64 %247#2 | |
%249 = torch_c.to_builtin_tensor %247#1 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%c0_187 = arith.constant 0 : index | |
%c1_188 = arith.constant 1 : index | |
%c-1_189 = arith.constant -1 : index | |
%c0_190 = arith.constant 0 : index | |
%c8_191 = arith.constant 8 : index | |
%250 = arith.index_cast %c1_i64 : i64 to index | |
%c8_i64_192 = arith.constant 8 : i64 | |
%251 = arith.addi %c0_i64, %c8_i64_192 : i64 | |
%c0_i64_193 = arith.constant 0 : i64 | |
%252 = arith.cmpi sge, %c0_i64, %c0_i64_193 : i64 | |
%253 = arith.select %252, %c0_i64, %251 : i64 | |
%c0_i64_194 = arith.constant 0 : i64 | |
%254 = arith.cmpi slt, %253, %c0_i64_194 : i64 | |
%255 = arith.select %254, %c0_i64_194, %253 : i64 | |
%256 = arith.cmpi sgt, %255, %c8_i64_192 : i64 | |
%257 = arith.select %256, %c8_i64_192, %255 : i64 | |
%258 = arith.index_cast %257 : i64 to index | |
%259 = arith.index_cast %248 : i64 to index | |
%260 = arith.cmpi slt, %259, %c0_187 : index | |
%261 = arith.addi %259, %c8_191 : index | |
%262 = arith.select %260, %261, %259 : index | |
%263 = arith.cmpi slt, %262, %c0_187 : index | |
%264 = arith.select %263, %c-1_189, %262 : index | |
%265 = arith.cmpi sgt, %264, %c8_191 : index | |
%266 = arith.select %265, %c8_191, %264 : index | |
%c0_195 = arith.constant 0 : index | |
%c8_196 = arith.constant 8 : index | |
%267 = arith.subi %266, %258 : index | |
%268 = arith.cmpi sge, %250, %c0_187 : index | |
%269 = arith.select %268, %c1_188, %c-1_189 : index | |
%270 = arith.addi %267, %250 : index | |
%271 = arith.subi %270, %269 : index | |
%272 = arith.floordivsi %271, %250 : index | |
%273 = arith.cmpi slt, %272, %c0_187 : index | |
%274 = arith.select %273, %c0_187, %272 : index | |
%c1_197 = arith.constant 1 : index | |
%c0_198 = arith.constant 0 : index | |
%c8_199 = arith.constant 8 : index | |
%275 = arith.subi %c8_199, %c1_197 : index | |
%c0_200 = arith.constant 0 : index | |
%c8_201 = arith.constant 8 : index | |
%276 = tensor.empty() : tensor<8xi64> | |
%c0_i64_202 = arith.constant 0 : i64 | |
%277 = linalg.fill ins(%c0_i64_202 : i64) outs(%276 : tensor<8xi64>) -> tensor<8xi64> | |
%278 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%249 : tensor<8xi64>) outs(%277 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%395 = linalg.index 0 : index | |
%396 = arith.subi %275, %395 : index | |
%extracted_283 = tensor.extract %249[%396] : tensor<8xi64> | |
linalg.yield %extracted_283 : i64 | |
} -> tensor<8xi64> | |
%c0_203 = arith.constant 0 : index | |
%c0_204 = arith.constant 0 : index | |
%279 = arith.cmpi slt, %250, %c0_204 : index | |
%280 = math.absi %250 : index | |
%281 = arith.muli %274, %280 : index | |
%dim_205 = tensor.dim %249, %c0_203 : tensor<8xi64> | |
%282 = arith.subi %dim_205, %281 : index | |
%283 = arith.select %279, %282, %258 : index | |
%284 = arith.select %279, %278, %249 : tensor<8xi64> | |
%extracted_slice_206 = tensor.extract_slice %284[%283] [%274] [%280] : tensor<8xi64> to tensor<?xi64> | |
%cast_207 = tensor.cast %extracted_slice_206 : tensor<?xi64> to tensor<?xi64> | |
%285 = torch_c.from_builtin_tensor %cast_207 : tensor<?xi64> -> !torch.vtensor<[?],si64> | |
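// %285: the kept box indices for this class; the torch.prim.If below truncates them when their count exceeds %22, likely max_output_boxes_per_class.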
%c1_i64_208 = arith.constant 1 : i64 | |
%286 = arith.addi %c0_i64, %c1_i64_208 : i64 | |
%c0_i64_209 = arith.constant 0 : i64 | |
%287 = arith.cmpi sge, %c0_i64, %c0_i64_209 : i64 | |
%288 = arith.select %287, %c0_i64, %286 : i64 | |
%289 = arith.index_cast %288 : i64 to index | |
%dim_210 = tensor.dim %cast_207, %289 : tensor<?xi64> | |
%290 = arith.index_cast %dim_210 : index to i64 | |
%291 = torch_c.from_i64 %290 | |
%292 = torch.aten.gt.int %291, %22 : !torch.int, !torch.int -> !torch.bool | |
%293 = torch.prim.If %292 -> (!torch.vtensor<[?],si64>) { | |
%c0_283 = arith.constant 0 : index | |
%c1_284 = arith.constant 1 : index | |
%c-1_285 = arith.constant -1 : index | |
%c0_286 = arith.constant 0 : index | |
%dim_287 = tensor.dim %extracted_slice_206, %c0_286 : tensor<?xi64> | |
%395 = arith.index_cast %c1_i64 : i64 to index | |
%396 = arith.index_cast %dim_287 : index to i64 | |
%397 = arith.addi %c0_i64, %396 : i64 | |
%c0_i64_288 = arith.constant 0 : i64 | |
%398 = arith.cmpi sge, %c0_i64, %c0_i64_288 : i64 | |
%399 = arith.select %398, %c0_i64, %397 : i64 | |
%c0_i64_289 = arith.constant 0 : i64 | |
%400 = arith.cmpi slt, %399, %c0_i64_289 : i64 | |
%401 = arith.select %400, %c0_i64_289, %399 : i64 | |
%402 = arith.cmpi sgt, %401, %396 : i64 | |
%403 = arith.select %402, %396, %401 : i64 | |
%404 = arith.index_cast %403 : i64 to index | |
%405 = arith.index_cast %23 : i64 to index | |
%406 = arith.cmpi slt, %405, %c0_283 : index | |
%407 = arith.addi %405, %dim_287 : index | |
%408 = arith.select %406, %407, %405 : index | |
%409 = arith.cmpi slt, %408, %c0_283 : index | |
%410 = arith.select %409, %c-1_285, %408 : index | |
%411 = arith.cmpi sgt, %410, %dim_287 : index | |
%412 = arith.select %411, %dim_287, %410 : index | |
%c0_290 = arith.constant 0 : index | |
%dim_291 = tensor.dim %extracted_slice_206, %c0_290 : tensor<?xi64> | |
%413 = arith.subi %412, %404 : index | |
%414 = arith.cmpi sge, %395, %c0_283 : index | |
%415 = arith.select %414, %c1_284, %c-1_285 : index | |
%416 = arith.addi %413, %395 : index | |
%417 = arith.subi %416, %415 : index | |
%418 = arith.floordivsi %417, %395 : index | |
%419 = arith.cmpi slt, %418, %c0_283 : index | |
%420 = arith.select %419, %c0_283, %418 : index | |
%c1_292 = arith.constant 1 : index | |
%c0_293 = arith.constant 0 : index | |
%dim_294 = tensor.dim %extracted_slice_206, %c0_293 : tensor<?xi64> | |
%421 = arith.subi %dim_294, %c1_292 : index | |
%c0_295 = arith.constant 0 : index | |
%dim_296 = tensor.dim %extracted_slice_206, %c0_295 : tensor<?xi64> | |
%422 = tensor.empty(%dim_296) : tensor<?xi64> | |
%c0_i64_297 = arith.constant 0 : i64 | |
%423 = linalg.fill ins(%c0_i64_297 : i64) outs(%422 : tensor<?xi64>) -> tensor<?xi64> | |
%424 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_207 : tensor<?xi64>) outs(%423 : tensor<?xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%432 = linalg.index 0 : index | |
%433 = arith.subi %421, %432 : index | |
%extracted_303 = tensor.extract %cast_207[%433] : tensor<?xi64> | |
linalg.yield %extracted_303 : i64 | |
} -> tensor<?xi64> | |
%c0_298 = arith.constant 0 : index | |
%c0_299 = arith.constant 0 : index | |
%425 = arith.cmpi slt, %395, %c0_299 : index | |
%426 = math.absi %395 : index | |
%427 = arith.muli %420, %426 : index | |
%dim_300 = tensor.dim %cast_207, %c0_298 : tensor<?xi64> | |
%428 = arith.subi %dim_300, %427 : index | |
%429 = arith.select %425, %428, %404 : index | |
%430 = arith.select %425, %424, %cast_207 : tensor<?xi64> | |
%extracted_slice_301 = tensor.extract_slice %430[%429] [%420] [%426] : tensor<?xi64> to tensor<?xi64> | |
%cast_302 = tensor.cast %extracted_slice_301 : tensor<?xi64> to tensor<?xi64> | |
%431 = torch_c.from_builtin_tensor %cast_302 : tensor<?xi64> -> !torch.vtensor<[?],si64> | |
torch.prim.If.yield %431 : !torch.vtensor<[?],si64> | |
} else { | |
torch.prim.If.yield %285 : !torch.vtensor<[?],si64> | |
} | |
%294 = torch_c.to_builtin_tensor %293 : !torch.vtensor<[?],si64> -> tensor<?xi64> | |
%c0_211 = arith.constant 0 : index | |
%dim_212 = tensor.dim %294, %c0_211 : tensor<?xi64> | |
%c1_213 = arith.constant 1 : index | |
%expanded = tensor.expand_shape %294 [[0, 1]] output_shape [%dim_212, 1] : tensor<?xi64> into tensor<?x1xi64> | |
%295 = torch_c.from_builtin_tensor %expanded : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
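// Kept box indices reshaped to [?,1] so they can be concatenated column-wise with the batch and class columns.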
%c2_i64_214 = arith.constant 2 : i64 | |
%296 = arith.addi %c0_i64, %c2_i64_214 : i64 | |
%c0_i64_215 = arith.constant 0 : i64 | |
%297 = arith.cmpi sge, %c0_i64, %c0_i64_215 : i64 | |
%298 = arith.select %297, %c0_i64, %296 : i64 | |
%299 = arith.index_cast %298 : i64 to index | |
%dim_216 = tensor.dim %expanded, %299 : tensor<?x1xi64> | |
%300 = arith.index_cast %dim_216 : index to i64 | |
%301 = torch_c.from_i64 %300 | |
%302 = torch.prim.ListConstruct %301, %int1 : (!torch.int, !torch.int) -> !torch.list<int> | |
%303 = torch_c.to_i64 %301 | |
%c1_i64_217 = arith.constant 1 : i64 | |
%304 = arith.index_cast %303 : i64 to index | |
%c1_218 = arith.constant 1 : index | |
%305 = tensor.empty(%304) : tensor<?x1xi64> | |
%cast_219 = tensor.cast %305 : tensor<?x1xi64> to tensor<?x1xi64> | |
%c1_220 = arith.constant 1 : index | |
%c0_221 = arith.constant 0 : index | |
%dim_222 = tensor.dim %305, %c0_221 : tensor<?x1xi64> | |
%306 = tensor.empty(%dim_222) : tensor<?x1xi64> | |
%307 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_219 : tensor<?x1xi64>) outs(%306 : tensor<?x1xi64>) { | |
^bb0(%in: i64, %out: i64): | |
linalg.yield %31 : i64 | |
} -> tensor<?x1xi64> | |
%cast_223 = tensor.cast %307 : tensor<?x1xi64> to tensor<?x1xi64> | |
%308 = torch_c.from_builtin_tensor %cast_223 : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
%309 = torch_c.to_i64 %301 | |
%c1_i64_224 = arith.constant 1 : i64 | |
%310 = arith.index_cast %309 : i64 to index | |
%c1_225 = arith.constant 1 : index | |
%311 = tensor.empty(%310) : tensor<?x1xi64> | |
%cast_226 = tensor.cast %311 : tensor<?x1xi64> to tensor<?x1xi64> | |
%c1_227 = arith.constant 1 : index | |
%c0_228 = arith.constant 0 : index | |
%dim_229 = tensor.dim %311, %c0_228 : tensor<?x1xi64> | |
%312 = tensor.empty(%dim_229) : tensor<?x1xi64> | |
%313 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_226 : tensor<?x1xi64>) outs(%312 : tensor<?x1xi64>) { | |
^bb0(%in: i64, %out: i64): | |
linalg.yield %77 : i64 | |
} -> tensor<?x1xi64> | |
%cast_230 = tensor.cast %313 : tensor<?x1xi64> to tensor<?x1xi64> | |
%314 = torch_c.from_builtin_tensor %cast_230 : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
%315 = torch.prim.ListConstruct %314, %295 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[?,1],si64>) -> !torch.list<vtensor> | |
%316 = torch_c.to_builtin_tensor %314 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%317 = torch_c.to_builtin_tensor %295 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%concat = tensor.concat dim(1) %316, %317 : (tensor<?x1xi64>, tensor<?x1xi64>) -> tensor<?x2xi64> | |
%318 = torch_c.from_builtin_tensor %concat : tensor<?x2xi64> -> !torch.vtensor<[?,2],si64> | |
%319 = torch.prim.ListConstruct %308, %318 : (!torch.vtensor<[?,1],si64>, !torch.vtensor<[?,2],si64>) -> !torch.list<vtensor> | |
%320 = torch_c.to_builtin_tensor %308 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%321 = torch_c.to_builtin_tensor %318 : !torch.vtensor<[?,2],si64> -> tensor<?x2xi64> | |
%concat_231 = tensor.concat dim(1) %320, %321 : (tensor<?x1xi64>, tensor<?x2xi64>) -> tensor<?x3xi64> | |
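// Two tensor.concat ops along dim 1 assemble the [batch_index, class_index, box_index] triples (%31 and %77 are likely the current batch and class scalars): the concat-based construction this gist demonstrates.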
%322 = torch.aten.add.int %arg10, %301 : !torch.int, !torch.int -> !torch.int | |
%323 = torch_c.to_i64 %322 | |
%c0_232 = arith.constant 0 : index | |
%c1_233 = arith.constant 1 : index | |
%c-1_234 = arith.constant -1 : index | |
%c0_235 = arith.constant 0 : index | |
%c120 = arith.constant 120 : index | |
%c1_236 = arith.constant 1 : index | |
%c3_237 = arith.constant 3 : index | |
%324 = arith.index_cast %c1_i64 : i64 to index | |
%c120_i64 = arith.constant 120 : i64 | |
%325 = arith.addi %75, %c120_i64 : i64 | |
%c0_i64_238 = arith.constant 0 : i64 | |
%326 = arith.cmpi sge, %75, %c0_i64_238 : i64 | |
%327 = arith.select %326, %75, %325 : i64 | |
%c0_i64_239 = arith.constant 0 : i64 | |
%328 = arith.cmpi slt, %327, %c0_i64_239 : i64 | |
%329 = arith.select %328, %c0_i64_239, %327 : i64 | |
%330 = arith.cmpi sgt, %329, %c120_i64 : i64 | |
%331 = arith.select %330, %c120_i64, %329 : i64 | |
%332 = arith.index_cast %331 : i64 to index | |
%333 = arith.index_cast %323 : i64 to index | |
%334 = arith.cmpi slt, %333, %c0_232 : index | |
%335 = arith.addi %333, %c120 : index | |
%336 = arith.select %334, %335, %333 : index | |
%337 = arith.cmpi slt, %336, %c0_232 : index | |
%338 = arith.select %337, %c-1_234, %336 : index | |
%339 = arith.cmpi sgt, %338, %c120 : index | |
%340 = arith.select %339, %c120, %338 : index | |
%c0_240 = arith.constant 0 : index | |
%c120_241 = arith.constant 120 : index | |
%c1_242 = arith.constant 1 : index | |
%c3_243 = arith.constant 3 : index | |
%341 = arith.subi %340, %332 : index | |
%342 = arith.cmpi sge, %324, %c0_232 : index | |
%343 = arith.select %342, %c1_233, %c-1_234 : index | |
%344 = arith.addi %341, %324 : index | |
%345 = arith.subi %344, %343 : index | |
%346 = arith.floordivsi %345, %324 : index | |
%347 = arith.cmpi slt, %346, %c0_232 : index | |
%348 = arith.select %347, %c0_232, %346 : index | |
%c1_244 = arith.constant 1 : index | |
%c0_245 = arith.constant 0 : index | |
%c120_246 = arith.constant 120 : index | |
%c1_247 = arith.constant 1 : index | |
%c3_248 = arith.constant 3 : index | |
%349 = arith.subi %c120_246, %c1_244 : index | |
%c0_249 = arith.constant 0 : index | |
%c120_250 = arith.constant 120 : index | |
%c1_251 = arith.constant 1 : index | |
%c3_252 = arith.constant 3 : index | |
%350 = tensor.empty() : tensor<120x3xi64> | |
%c0_i64_253 = arith.constant 0 : i64 | |
%351 = linalg.fill ins(%c0_i64_253 : i64) outs(%350 : tensor<120x3xi64>) -> tensor<120x3xi64> | |
%352 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%76 : tensor<120x3xi64>) outs(%351 : tensor<120x3xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%395 = linalg.index 0 : index | |
%396 = linalg.index 1 : index | |
%397 = arith.subi %349, %395 : index | |
%extracted_283 = tensor.extract %76[%397, %396] : tensor<120x3xi64> | |
linalg.yield %extracted_283 : i64 | |
} -> tensor<120x3xi64> | |
%c0_254 = arith.constant 0 : index | |
%c0_255 = arith.constant 0 : index | |
%353 = arith.cmpi slt, %324, %c0_255 : index | |
%354 = math.absi %324 : index | |
%355 = arith.muli %348, %354 : index | |
%dim_256 = tensor.dim %76, %c0_254 : tensor<120x3xi64> | |
%356 = arith.subi %dim_256, %355 : index | |
%357 = arith.select %353, %356, %332 : index | |
%358 = arith.select %353, %352, %76 : tensor<120x3xi64> | |
%extracted_slice_257 = tensor.extract_slice %358[%357, %c0_232] [%348, %c3_243] [%354, %c1_233] : tensor<120x3xi64> to tensor<?x?xi64> | |
%cast_258 = tensor.cast %extracted_slice_257 : tensor<?x?xi64> to tensor<?x3xi64> | |
%c2_i64_259 = arith.constant 2 : i64 | |
%359 = arith.addi %c0_i64, %c2_i64_259 : i64 | |
%c0_i64_260 = arith.constant 0 : i64 | |
%360 = arith.cmpi sge, %c0_i64, %c0_i64_260 : i64 | |
%361 = arith.select %360, %c0_i64, %359 : i64 | |
%362 = arith.index_cast %361 : i64 to index | |
%dim_261 = tensor.dim %cast_258, %362 : tensor<?x3xi64> | |
%363 = arith.index_cast %dim_261 : index to i64 | |
%364 = torch_c.from_i64 %363 | |
%365 = torch.prim.ListConstruct %364, %int3 : (!torch.int, !torch.int) -> !torch.list<int> | |
%366 = torch_c.to_i64 %364 | |
%c3_i64_262 = arith.constant 3 : i64 | |
%c0_i64_263 = arith.constant 0 : i64 | |
%c0_264 = arith.constant 0 : index | |
%c1_265 = arith.constant 1 : index | |
%367 = arith.index_cast %366 : i64 to index | |
%368 = tensor.empty(%367) : tensor<?x3xi64> | |
%cast_266 = tensor.cast %concat_231 : tensor<?x3xi64> to tensor<?x3xi64> | |
%c0_267 = arith.constant 0 : index | |
%c1_268 = arith.constant 1 : index | |
%c-1_269 = arith.constant -1 : index | |
%c0_270 = arith.constant 0 : index | |
%c120_271 = arith.constant 120 : index | |
%c1_272 = arith.constant 1 : index | |
%c3_273 = arith.constant 3 : index | |
%369 = arith.index_cast %c1_i64 : i64 to index | |
%c120_i64_274 = arith.constant 120 : i64 | |
%370 = arith.addi %75, %c120_i64_274 : i64 | |
%c0_i64_275 = arith.constant 0 : i64 | |
%371 = arith.cmpi sge, %75, %c0_i64_275 : i64 | |
%372 = arith.select %371, %75, %370 : i64 | |
%c0_i64_276 = arith.constant 0 : i64 | |
%373 = arith.cmpi slt, %372, %c0_i64_276 : i64 | |
%374 = arith.select %373, %c0_i64_276, %372 : i64 | |
%375 = arith.cmpi sgt, %374, %c120_i64_274 : i64 | |
%376 = arith.select %375, %c120_i64_274, %374 : i64 | |
%377 = arith.index_cast %376 : i64 to index | |
%378 = arith.index_cast %323 : i64 to index | |
%379 = arith.cmpi slt, %378, %c0_267 : index | |
%380 = arith.addi %378, %c120_271 : index | |
%381 = arith.select %379, %380, %378 : index | |
%382 = arith.cmpi slt, %381, %c0_267 : index | |
%383 = arith.select %382, %c-1_269, %381 : index | |
%384 = arith.cmpi sgt, %383, %c120_271 : index | |
%385 = arith.select %384, %c120_271, %383 : index | |
%c0_277 = arith.constant 0 : index | |
%c120_278 = arith.constant 120 : index | |
%c1_279 = arith.constant 1 : index | |
%c3_280 = arith.constant 3 : index | |
%386 = arith.subi %385, %377 : index | |
%387 = arith.cmpi sge, %369, %c0_267 : index | |
%388 = arith.select %387, %c1_268, %c-1_269 : index | |
%389 = arith.addi %386, %369 : index | |
%390 = arith.subi %389, %388 : index | |
%391 = arith.floordivsi %390, %369 : index | |
%392 = arith.cmpi slt, %391, %c0_267 : index | |
%393 = arith.select %392, %c0_267, %391 : index | |
%cast_281 = tensor.cast %cast_266 : tensor<?x3xi64> to tensor<?x?xi64> | |
%inserted_slice = tensor.insert_slice %cast_281 into %76[%377, %c0_267] [%393, %c3_280] [%369, %c1_268] : tensor<?x?xi64> into tensor<120x3xi64> | |
%cast_282 = tensor.cast %inserted_slice : tensor<120x3xi64> to tensor<120x3xi64> | |
%394 = torch_c.from_builtin_tensor %cast_282 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
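// The new triples are written into the [120,3] result with tensor.insert_slice; %322 (%arg10 plus the number of new rows) advances the running row count.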
torch.prim.Loop.condition %true, iter(%394, %322 : !torch.vtensor<[120,3],si64>, !torch.int) | |
} : (!torch.int, !torch.bool, !torch.vtensor<[120,3],si64>, !torch.int) -> (!torch.vtensor<[120,3],si64>, !torch.int) | |
torch.prim.Loop.condition %true, iter(%74#0, %74#1 : !torch.vtensor<[120,3],si64>, !torch.int) | |
} : (!torch.int, !torch.bool, !torch.vtensor<[120,3],si64>, !torch.int) -> (!torch.vtensor<[120,3],si64>, !torch.int) | |
return %30#0 : !torch.vtensor<[120,3],si64> | |
} |
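// A second lowering of the same ONNX NonMaxSuppression function follows; this variant carries torch.assume_strict_symbolic_shapes and expresses the batch/class iteration with scf.for loops.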
func.func @main(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[120,3],si64> attributes {torch.assume_strict_symbolic_shapes, torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { | |
%0 = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[3,8,4],f32> -> tensor<3x8x4xf32> | |
%1 = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[3,5,8],f32> -> tensor<3x5x8xf32> | |
%2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%4 = torch_c.to_builtin_tensor %arg4 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%5 = torch.vtensor.literal(dense<1> : tensor<8xsi64>) : !torch.vtensor<[8],si64> | |
%6 = torch_c.to_builtin_tensor %5 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%7 = torch.vtensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.vtensor<[1],f32> | |
%8 = torch_c.to_builtin_tensor %7 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%9 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> | |
%10 = torch_c.to_builtin_tensor %9 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%int8 = torch.constant.int 8 | |
%c8_i64 = arith.constant 8 : i64 | |
%c2_i64 = arith.constant 2 : i64 | |
%int15 = torch.constant.int 15 | |
%int5 = torch.constant.int 5 | |
%c5_i64 = arith.constant 5 : i64 | |
%int0 = torch.constant.int 0 | |
%c0_i64 = arith.constant 0 : i64 | |
%c0_i64_0 = arith.constant 0 : i64 | |
%int1 = torch.constant.int 1 | |
%c1_i64 = arith.constant 1 : i64 | |
%int3 = torch.constant.int 3 | |
%c3_i64 = arith.constant 3 : i64 | |
%c4_i64 = arith.constant 4 : i64 | |
%true = torch.constant.bool true | |
%c0 = arith.constant 0 : index | |
%extracted = tensor.extract %4[%c0] : tensor<1xf32> | |
%11 = arith.extf %extracted : f32 to f64 | |
%12 = torch_c.from_f64 %11 | |
%cst = arith.constant 0x7F800000 : f32 | |
%c1 = arith.constant 1 : index | |
%c2 = arith.constant 2 : index | |
%13 = tensor.empty() : tensor<f32> | |
%14 = linalg.fill ins(%cst : f32) outs(%13 : tensor<f32>) -> tensor<f32> | |
%15 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> ()>], iterator_types = ["reduction", "reduction", "reduction"]} ins(%1 : tensor<3x5x8xf32>) outs(%14 : tensor<f32>) { | |
^bb0(%in: f32, %out: f32): | |
%34 = arith.minimumf %in, %out : f32 | |
linalg.yield %34 : f32 | |
} -> tensor<f32> | |
%16 = torch_c.from_builtin_tensor %15 : tensor<f32> -> !torch.vtensor<[],f32> | |
%17 = torch_c.to_builtin_tensor %16 : !torch.vtensor<[],f32> -> tensor<f32> | |
%extracted_1 = tensor.extract %17[] : tensor<f32> | |
%18 = arith.extf %extracted_1 : f32 to f64 | |
%19 = torch_c.from_f64 %18 | |
%20 = torch.aten.ge.float %19, %12 : !torch.float, !torch.float -> !torch.bool | |
torch.runtime.assert %20, "unimplemented: score_threshold should be <= min(scores)" | |
%extracted_2 = tensor.extract %3[%c0] : tensor<1xf32> | |
%21 = arith.extf %extracted_2 : f32 to f64 | |
%22 = torch_c.from_f64 %21 | |
%extracted_3 = tensor.extract %2[%c0] : tensor<1xi64> | |
%23 = torch_c.from_i64 %extracted_3 | |
%24 = torch_c.to_i64 %23 | |
%25 = torch.aten.mul.int %int15, %23 : !torch.int, !torch.int -> !torch.int | |
%26 = torch_c.to_i64 %25 | |
%27 = arith.index_cast %26 : i64 to index | |
%28 = tensor.empty(%27) : tensor<?x3xi64> | |
%cast = tensor.cast %28 : tensor<?x3xi64> to tensor<120x3xi64> | |
%29 = torch_c.from_builtin_tensor %cast : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%30 = torch_c.to_builtin_tensor %29 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
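// %30: the [120,3] output buffer (3 batches x 5 classes x up to 8 boxes per class, as baked into the static cast above), threaded through the loops below together with a running row count.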
%c0_4 = arith.constant 0 : index | |
%c1_5 = arith.constant 1 : index | |
%31 = arith.index_cast %c3_i64 : i64 to index | |
%32:2 = scf.for %arg5 = %c0_4 to %31 step %c1_5 iter_args(%arg6 = %30, %arg7 = %c0_i64) -> (tensor<120x3xi64>, i64) { | |
%34 = arith.index_cast %arg5 : index to i64 | |
%35 = torch_c.from_i64 %34 | |
%36 = torch_c.from_builtin_tensor %arg6 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%37 = torch_c.to_builtin_tensor %36 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
%38 = torch_c.from_i64 %arg7 | |
%39 = torch_c.to_i64 %38 | |
%40 = torch_c.to_i64 %35 | |
%41 = torch.aten.lt.int %35, %int0 : !torch.int, !torch.int -> !torch.bool | |
%42 = torch.aten.Int.bool %41 : !torch.bool -> !torch.int | |
%43 = torch.aten.mul.int %42, %int3 : !torch.int, !torch.int -> !torch.int | |
%44 = torch.aten.add.int %35, %43 : !torch.int, !torch.int -> !torch.int | |
%45 = torch_c.to_i64 %44 | |
%46 = torch.aten.add.int %44, %int1 : !torch.int, !torch.int -> !torch.int | |
%47 = torch_c.to_i64 %46 | |
%c-1 = arith.constant -1 : index | |
%c3 = arith.constant 3 : index | |
%48 = arith.index_cast %c1_i64 : i64 to index | |
%c3_i64_6 = arith.constant 3 : i64 | |
%49 = arith.addi %45, %c3_i64_6 : i64 | |
%50 = arith.cmpi sge, %45, %c0_i64_0 : i64 | |
%51 = arith.select %50, %45, %49 : i64 | |
%52 = arith.cmpi slt, %51, %c0_i64_0 : i64 | |
%53 = arith.select %52, %c0_i64_0, %51 : i64 | |
%54 = arith.cmpi sgt, %53, %c3_i64_6 : i64 | |
%55 = arith.select %54, %c3_i64_6, %53 : i64 | |
%56 = arith.index_cast %55 : i64 to index | |
%57 = arith.index_cast %47 : i64 to index | |
%58 = arith.cmpi slt, %57, %c0 : index | |
%59 = arith.addi %57, %c3 : index | |
%60 = arith.select %58, %59, %57 : index | |
%61 = arith.cmpi slt, %60, %c0 : index | |
%62 = arith.select %61, %c-1, %60 : index | |
%63 = arith.cmpi sgt, %62, %c3 : index | |
%64 = arith.select %63, %c3, %62 : index | |
%c8 = arith.constant 8 : index | |
%c4 = arith.constant 4 : index | |
%65 = arith.subi %64, %56 : index | |
%66 = arith.cmpi sge, %48, %c0 : index | |
%67 = arith.select %66, %c1, %c-1 : index | |
%68 = arith.addi %65, %48 : index | |
%69 = arith.subi %68, %67 : index | |
%70 = arith.floordivsi %69, %48 : index | |
%71 = arith.cmpi slt, %70, %c0 : index | |
%72 = arith.select %71, %c0, %70 : index | |
%73 = arith.subi %c3, %c1 : index | |
%74 = tensor.empty() : tensor<3x8x4xf32> | |
%cst_7 = arith.constant 0.000000e+00 : f32 | |
%75 = linalg.fill ins(%cst_7 : f32) outs(%74 : tensor<3x8x4xf32>) -> tensor<3x8x4xf32> | |
%76 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%0 : tensor<3x8x4xf32>) outs(%75 : tensor<3x8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%89 = linalg.index 0 : index | |
%90 = linalg.index 1 : index | |
%91 = linalg.index 2 : index | |
%92 = arith.subi %73, %89 : index | |
%extracted_12 = tensor.extract %0[%92, %90, %91] : tensor<3x8x4xf32> | |
linalg.yield %extracted_12 : f32 | |
} -> tensor<3x8x4xf32> | |
%77 = arith.cmpi slt, %48, %c0 : index | |
%78 = math.absi %48 : index | |
%79 = arith.muli %72, %78 : index | |
%c3_8 = arith.constant 3 : index | |
%80 = arith.subi %c3_8, %79 : index | |
%81 = arith.select %77, %80, %56 : index | |
%82 = arith.select %77, %76, %0 : tensor<3x8x4xf32> | |
%extracted_slice = tensor.extract_slice %82[%81, %c0, %c0] [%72, %c8, %c4] [%78, %c1, %c1] : tensor<3x8x4xf32> to tensor<?x?x?xf32> | |
%cast_9 = tensor.cast %extracted_slice : tensor<?x?x?xf32> to tensor<1x8x4xf32> | |
%collapsed = tensor.collapse_shape %cast_9 [[0, 1], [2]] : tensor<1x8x4xf32> into tensor<8x4xf32> | |
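// Boxes of the current batch collapsed to [8,4]; the flip/select sequence above implements a general (possibly negative-step) slice along the batch dimension.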
%c0_10 = arith.constant 0 : index | |
%c1_11 = arith.constant 1 : index | |
%83 = arith.index_cast %c5_i64 : i64 to index | |
%84:2 = scf.for %arg8 = %c0_10 to %83 step %c1_11 iter_args(%arg9 = %37, %arg10 = %39) -> (tensor<120x3xi64>, i64) { | |
%89 = arith.index_cast %arg8 : index to i64 | |
%90 = torch_c.from_i64 %89 | |
%91 = torch_c.from_builtin_tensor %arg9 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%92 = torch_c.from_i64 %arg10 | |
%93 = torch_c.to_i64 %92 | |
%94 = torch_c.to_builtin_tensor %91 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
%95 = torch_c.to_i64 %90 | |
%96 = torch.aten.lt.int %35, %int0 : !torch.int, !torch.int -> !torch.bool | |
%97 = torch.aten.Int.bool %96 : !torch.bool -> !torch.int | |
%98 = torch.aten.mul.int %97, %int3 : !torch.int, !torch.int -> !torch.int | |
%99 = torch.aten.add.int %35, %98 : !torch.int, !torch.int -> !torch.int | |
%100 = torch_c.to_i64 %99 | |
%101 = torch.aten.add.int %99, %int1 : !torch.int, !torch.int -> !torch.int | |
%102 = torch_c.to_i64 %101 | |
%103 = arith.addi %100, %c3_i64_6 : i64 | |
%104 = arith.cmpi sge, %100, %c0_i64_0 : i64 | |
%105 = arith.select %104, %100, %103 : i64 | |
%106 = arith.cmpi slt, %105, %c0_i64_0 : i64 | |
%107 = arith.select %106, %c0_i64_0, %105 : i64 | |
%108 = arith.cmpi sgt, %107, %c3_i64_6 : i64 | |
%109 = arith.select %108, %c3_i64_6, %107 : i64 | |
%110 = arith.index_cast %109 : i64 to index | |
%111 = arith.index_cast %102 : i64 to index | |
%112 = arith.cmpi slt, %111, %c0 : index | |
%113 = arith.addi %111, %c3 : index | |
%114 = arith.select %112, %113, %111 : index | |
%115 = arith.cmpi slt, %114, %c0 : index | |
%116 = arith.select %115, %c-1, %114 : index | |
%117 = arith.cmpi sgt, %116, %c3 : index | |
%118 = arith.select %117, %c3, %116 : index | |
%c5 = arith.constant 5 : index | |
%119 = arith.subi %118, %110 : index | |
%120 = arith.addi %119, %48 : index | |
%121 = arith.subi %120, %67 : index | |
%122 = arith.floordivsi %121, %48 : index | |
%123 = arith.cmpi slt, %122, %c0 : index | |
%124 = arith.select %123, %c0, %122 : index | |
%125 = tensor.empty() : tensor<3x5x8xf32> | |
%126 = linalg.fill ins(%cst_7 : f32) outs(%125 : tensor<3x5x8xf32>) -> tensor<3x5x8xf32> | |
%127 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d1, d2)>, affine_map<(d0, d1, d2) -> (d0, d1, d2)>], iterator_types = ["parallel", "parallel", "parallel"]} ins(%1 : tensor<3x5x8xf32>) outs(%126 : tensor<3x5x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%342 = linalg.index 0 : index | |
%343 = linalg.index 1 : index | |
%344 = linalg.index 2 : index | |
%345 = arith.subi %73, %342 : index | |
%extracted_43 = tensor.extract %1[%345, %343, %344] : tensor<3x5x8xf32> | |
linalg.yield %extracted_43 : f32 | |
} -> tensor<3x5x8xf32> | |
%128 = arith.muli %124, %78 : index | |
%c3_12 = arith.constant 3 : index | |
%129 = arith.subi %c3_12, %128 : index | |
%130 = arith.select %77, %129, %110 : index | |
%131 = arith.select %77, %127, %1 : tensor<3x5x8xf32> | |
%extracted_slice_13 = tensor.extract_slice %131[%130, %c0, %c0] [%124, %c5, %c8] [%78, %c1, %c1] : tensor<3x5x8xf32> to tensor<?x?x?xf32> | |
%cast_14 = tensor.cast %extracted_slice_13 : tensor<?x?x?xf32> to tensor<1x5x8xf32> | |
%collapsed_15 = tensor.collapse_shape %cast_14 [[0, 1], [2]] : tensor<1x5x8xf32> into tensor<5x8xf32> | |
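// Scores of the current batch collapsed to [5,8] (classes x boxes).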
%132 = torch.aten.lt.int %90, %int0 : !torch.int, !torch.int -> !torch.bool | |
%133 = torch.aten.Int.bool %132 : !torch.bool -> !torch.int | |
%134 = torch.aten.mul.int %133, %int5 : !torch.int, !torch.int -> !torch.int | |
%135 = torch.aten.add.int %90, %134 : !torch.int, !torch.int -> !torch.int | |
%136 = torch_c.to_i64 %135 | |
%137 = torch.aten.add.int %135, %int1 : !torch.int, !torch.int -> !torch.int | |
%138 = torch_c.to_i64 %137 | |
%c5_i64_16 = arith.constant 5 : i64 | |
%139 = arith.addi %136, %c5_i64_16 : i64 | |
%140 = arith.cmpi sge, %136, %c0_i64_0 : i64 | |
%141 = arith.select %140, %136, %139 : i64 | |
%142 = arith.cmpi slt, %141, %c0_i64_0 : i64 | |
%143 = arith.select %142, %c0_i64_0, %141 : i64 | |
%144 = arith.cmpi sgt, %143, %c5_i64_16 : i64 | |
%145 = arith.select %144, %c5_i64_16, %143 : i64 | |
%146 = arith.index_cast %145 : i64 to index | |
%147 = arith.index_cast %138 : i64 to index | |
%148 = arith.cmpi slt, %147, %c0 : index | |
%149 = arith.addi %147, %c5 : index | |
%150 = arith.select %148, %149, %147 : index | |
%151 = arith.cmpi slt, %150, %c0 : index | |
%152 = arith.select %151, %c-1, %150 : index | |
%153 = arith.cmpi sgt, %152, %c5 : index | |
%154 = arith.select %153, %c5, %152 : index | |
%155 = arith.subi %154, %146 : index | |
%156 = arith.addi %155, %48 : index | |
%157 = arith.subi %156, %67 : index | |
%158 = arith.floordivsi %157, %48 : index | |
%159 = arith.cmpi slt, %158, %c0 : index | |
%160 = arith.select %159, %c0, %158 : index | |
%161 = arith.subi %c5, %c1 : index | |
%162 = tensor.empty() : tensor<5x8xf32> | |
%163 = linalg.fill ins(%cst_7 : f32) outs(%162 : tensor<5x8xf32>) -> tensor<5x8xf32> | |
%164 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed_15 : tensor<5x8xf32>) outs(%163 : tensor<5x8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%342 = linalg.index 0 : index | |
%343 = linalg.index 1 : index | |
%344 = arith.subi %161, %342 : index | |
%extracted_43 = tensor.extract %collapsed_15[%344, %343] : tensor<5x8xf32> | |
linalg.yield %extracted_43 : f32 | |
} -> tensor<5x8xf32> | |
%165 = arith.muli %160, %78 : index | |
%c5_17 = arith.constant 5 : index | |
%166 = arith.subi %c5_17, %165 : index | |
%167 = arith.select %77, %166, %146 : index | |
%168 = arith.select %77, %164, %collapsed_15 : tensor<5x8xf32> | |
%extracted_slice_18 = tensor.extract_slice %168[%167, %c0] [%160, %c8] [%78, %c1] : tensor<5x8xf32> to tensor<?x?xf32> | |
%cast_19 = tensor.cast %extracted_slice_18 : tensor<?x?xf32> to tensor<1x8xf32> | |
%collapsed_20 = tensor.collapse_shape %cast_19 [[0, 1]] : tensor<1x8xf32> into tensor<8xf32> | |
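// Scores of the current class collapsed to [8]: one value per box.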
%169 = torch_c.from_builtin_tensor %collapsed_20 : tensor<8xf32> -> !torch.vtensor<[8],f32> | |
%170 = torch_c.to_builtin_tensor %169 : !torch.vtensor<[8],f32> -> tensor<8xf32> | |
%171 = arith.addi %c0_i64_0, %c4_i64 : i64 | |
%172 = arith.cmpi sge, %c0_i64_0, %c0_i64_0 : i64 | |
%173 = arith.select %172, %c0_i64_0, %171 : i64 | |
%174 = arith.cmpi slt, %173, %c0_i64_0 : i64 | |
%175 = arith.select %174, %c0_i64_0, %173 : i64 | |
%176 = arith.cmpi sgt, %175, %c4_i64 : i64 | |
%177 = arith.select %176, %c4_i64, %175 : i64 | |
%178 = arith.index_cast %177 : i64 to index | |
%179 = arith.index_cast %c2_i64 : i64 to index | |
%180 = arith.cmpi slt, %179, %c0 : index | |
%181 = arith.addi %179, %c4 : index | |
%182 = arith.select %180, %181, %179 : index | |
%183 = arith.cmpi slt, %182, %c0 : index | |
%184 = arith.select %183, %c-1, %182 : index | |
%185 = arith.cmpi sgt, %184, %c4 : index | |
%186 = arith.select %185, %c4, %184 : index | |
%187 = arith.subi %186, %178 : index | |
%188 = arith.addi %187, %48 : index | |
%189 = arith.subi %188, %67 : index | |
%190 = arith.floordivsi %189, %48 : index | |
%191 = arith.cmpi slt, %190, %c0 : index | |
%192 = arith.select %191, %c0, %190 : index | |
%193 = arith.subi %c4, %c1 : index | |
%194 = tensor.empty() : tensor<8x4xf32> | |
%195 = linalg.fill ins(%cst_7 : f32) outs(%194 : tensor<8x4xf32>) -> tensor<8x4xf32> | |
%196 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8x4xf32>) outs(%195 : tensor<8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%342 = linalg.index 0 : index | |
%343 = linalg.index 1 : index | |
%344 = arith.subi %193, %343 : index | |
%extracted_43 = tensor.extract %collapsed[%342, %344] : tensor<8x4xf32> | |
linalg.yield %extracted_43 : f32 | |
} -> tensor<8x4xf32> | |
%197 = arith.muli %192, %78 : index | |
%c4_21 = arith.constant 4 : index | |
%198 = arith.subi %c4_21, %197 : index | |
%199 = arith.select %77, %198, %178 : index | |
%200 = arith.select %77, %196, %collapsed : tensor<8x4xf32> | |
%extracted_slice_22 = tensor.extract_slice %200[%c0, %199] [%c8, %192] [%c1, %78] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_23 = tensor.cast %extracted_slice_22 : tensor<?x?xf32> to tensor<8x2xf32> | |
%201 = arith.addi %c2_i64, %c4_i64 : i64 | |
%202 = arith.cmpi sge, %c2_i64, %c0_i64_0 : i64 | |
%203 = arith.select %202, %c2_i64, %201 : i64 | |
%204 = arith.cmpi slt, %203, %c0_i64_0 : i64 | |
%205 = arith.select %204, %c0_i64_0, %203 : i64 | |
%206 = arith.cmpi sgt, %205, %c4_i64 : i64 | |
%207 = arith.select %206, %c4_i64, %205 : i64 | |
%208 = arith.index_cast %207 : i64 to index | |
%209 = arith.index_cast %c4_i64 : i64 to index | |
%210 = arith.cmpi slt, %209, %c0 : index | |
%211 = arith.addi %209, %c4 : index | |
%212 = arith.select %210, %211, %209 : index | |
%213 = arith.cmpi slt, %212, %c0 : index | |
%214 = arith.select %213, %c-1, %212 : index | |
%215 = arith.cmpi sgt, %214, %c4 : index | |
%216 = arith.select %215, %c4, %214 : index | |
%217 = arith.subi %216, %208 : index | |
%218 = arith.addi %217, %48 : index | |
%219 = arith.subi %218, %67 : index | |
%220 = arith.floordivsi %219, %48 : index | |
%221 = arith.cmpi slt, %220, %c0 : index | |
%222 = arith.select %221, %c0, %220 : index | |
%223 = arith.muli %222, %78 : index | |
%224 = arith.subi %c4_21, %223 : index | |
%225 = arith.select %77, %224, %208 : index | |
%extracted_slice_24 = tensor.extract_slice %200[%c0, %225] [%c8, %222] [%c1, %78] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_25 = tensor.cast %extracted_slice_24 : tensor<?x?xf32> to tensor<8x2xf32> | |
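// Split each box's four coordinates into two [8,2] halves: columns [0,2) in %cast_23 and columns [2,4) in %cast_25.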
%226 = tensor.empty() : tensor<8x2xf32> | |
%227 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_25, %cast_23 : tensor<8x2xf32>, tensor<8x2xf32>) outs(%226 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_43: f32, %out: f32): | |
%342 = arith.sitofp %c1_i64 : i64 to f32 | |
%343 = arith.mulf %in_43, %342 : f32 | |
%344 = arith.subf %in, %343 : f32 | |
linalg.yield %344 : f32 | |
} -> tensor<8x2xf32> | |
%cst_26 = arith.constant 1.000000e+00 : f32 | |
%c8_27 = arith.constant 8 : index | |
%228 = tensor.empty(%c8_27) : tensor<?xf32> | |
%229 = linalg.fill ins(%cst_26 : f32) outs(%228 : tensor<?xf32>) -> tensor<?xf32> | |
%230 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%227 : tensor<8x2xf32>) outs(%229 : tensor<?xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%342 = arith.mulf %in, %out : f32 | |
linalg.yield %342 : f32 | |
} -> tensor<?xf32> | |
%cast_28 = tensor.cast %230 : tensor<?xf32> to tensor<8xf32> | |
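// Per-box areas: subtract the first coordinate half from the second (%227), then take the product along dim 1 (%230).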
%231 = tensor.empty() : tensor<8xi64> | |
%232 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} outs(%231 : tensor<8xi64>) { | |
^bb0(%out: i64): | |
%342 = linalg.index 0 : index | |
%343 = arith.index_cast %342 : index to i64 | |
linalg.yield %343 : i64 | |
} -> tensor<8xi64> | |
%233:2 = iree_linalg_ext.sort dimension(0) outs(%170, %232 : tensor<8xf32>, tensor<8xi64>) { | |
^bb0(%arg11: f32, %arg12: f32, %arg13: i64, %arg14: i64): | |
%342 = arith.cmpf oge, %arg11, %arg12 : f32 | |
iree_linalg_ext.yield %342 : i1 | |
} -> tensor<8xf32>, tensor<8xi64> | |
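// iree_linalg_ext.sort orders this class's scores descending, carrying the iota index tensor %232 along, so %233#1 gives the score-sorted visiting order for greedy NMS.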
%234 = torch_c.from_builtin_tensor %233#1 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%235 = torch_c.to_builtin_tensor %234 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%236 = torch_c.from_builtin_tensor %231 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%237 = torch_c.to_builtin_tensor %236 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%c0_29 = arith.constant 0 : index | |
%c1_30 = arith.constant 1 : index | |
%238 = arith.index_cast %c8_i64 : i64 to index | |
%239:3 = scf.for %arg11 = %c0_29 to %238 step %c1_30 iter_args(%arg12 = %6, %arg13 = %237, %arg14 = %c0_i64) -> (tensor<8xi64>, tensor<8xi64>, i64) { | |
%342 = arith.index_cast %arg11 : index to i64 | |
%343 = torch_c.from_i64 %342 | |
%344 = torch_c.from_builtin_tensor %arg12 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%345 = torch_c.to_builtin_tensor %344 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%346 = torch_c.from_builtin_tensor %arg13 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%347 = torch_c.to_builtin_tensor %346 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%348 = torch_c.from_i64 %arg14 | |
%349 = torch_c.to_i64 %348 | |
%350 = torch_c.to_i64 %343 | |
%351 = torch_c.to_i64 %348 | |
%352 = torch_c.to_builtin_tensor %346 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%353 = torch_c.to_builtin_tensor %344 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%354 = torch.aten.lt.int %343, %int0 : !torch.int, !torch.int -> !torch.bool | |
%355 = torch.aten.Int.bool %354 : !torch.bool -> !torch.int | |
%356 = torch.aten.mul.int %355, %int8 : !torch.int, !torch.int -> !torch.int | |
%357 = torch.aten.add.int %343, %356 : !torch.int, !torch.int -> !torch.int | |
%358 = torch_c.to_i64 %357 | |
%359 = torch.aten.add.int %357, %int1 : !torch.int, !torch.int -> !torch.int | |
%360 = torch_c.to_i64 %359 | |
%c8_i64_43 = arith.constant 8 : i64 | |
%361 = arith.addi %358, %c8_i64_43 : i64 | |
%362 = arith.cmpi sge, %358, %c0_i64_0 : i64 | |
%363 = arith.select %362, %358, %361 : i64 | |
%364 = arith.cmpi slt, %363, %c0_i64_0 : i64 | |
%365 = arith.select %364, %c0_i64_0, %363 : i64 | |
%366 = arith.cmpi sgt, %365, %c8_i64_43 : i64 | |
%367 = arith.select %366, %c8_i64_43, %365 : i64 | |
%368 = arith.index_cast %367 : i64 to index | |
%369 = arith.index_cast %360 : i64 to index | |
%370 = arith.cmpi slt, %369, %c0 : index | |
%371 = arith.addi %369, %c8 : index | |
%372 = arith.select %370, %371, %369 : index | |
%373 = arith.cmpi slt, %372, %c0 : index | |
%374 = arith.select %373, %c-1, %372 : index | |
%375 = arith.cmpi sgt, %374, %c8 : index | |
%376 = arith.select %375, %c8, %374 : index | |
%377 = arith.subi %376, %368 : index | |
%378 = arith.addi %377, %48 : index | |
%379 = arith.subi %378, %67 : index | |
%380 = arith.floordivsi %379, %48 : index | |
%381 = arith.cmpi slt, %380, %c0 : index | |
%382 = arith.select %381, %c0, %380 : index | |
%383 = arith.subi %c8, %c1 : index | |
%384 = linalg.fill ins(%c0_i64_0 : i64) outs(%231 : tensor<8xi64>) -> tensor<8xi64> | |
%385 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%353 : tensor<8xi64>) outs(%384 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%402 = linalg.index 0 : index | |
%403 = arith.subi %383, %402 : index | |
%extracted_48 = tensor.extract %353[%403] : tensor<8xi64> | |
linalg.yield %extracted_48 : i64 | |
} -> tensor<8xi64> | |
%386 = arith.muli %382, %78 : index | |
%c8_44 = arith.constant 8 : index | |
%387 = arith.subi %c8_44, %386 : index | |
%388 = arith.select %77, %387, %368 : index | |
%389 = arith.select %77, %385, %353 : tensor<8xi64> | |
%extracted_slice_45 = tensor.extract_slice %389[%388] [%382] [%78] : tensor<8xi64> to tensor<?xi64> | |
%cast_46 = tensor.cast %extracted_slice_45 : tensor<?xi64> to tensor<1xi64> | |
%390 = torch_c.from_builtin_tensor %cast_46 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%391 = torch_c.to_builtin_tensor %390 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%extracted_47 = tensor.extract %391[%c0] : tensor<1xi64> | |
%392 = torch_c.from_i64 %extracted_47 | |
%393 = torch.aten.Bool.int %392 : !torch.int -> !torch.bool | |
%394 = torch_c.to_i1 %393 | |
%395:3 = scf.if %394 -> (tensor<8xi64>, tensor<8xi64>, i64) { | |
%402 = torch.aten.lt.int %343, %int0 : !torch.int, !torch.int -> !torch.bool | |
%403 = torch.aten.Int.bool %402 : !torch.bool -> !torch.int | |
%404 = torch.aten.mul.int %403, %int8 : !torch.int, !torch.int -> !torch.int | |
%405 = torch.aten.add.int %343, %404 : !torch.int, !torch.int -> !torch.int | |
%406 = torch_c.to_i64 %405 | |
%407 = torch.aten.add.int %405, %int1 : !torch.int, !torch.int -> !torch.int | |
%408 = torch_c.to_i64 %407 | |
%409 = arith.addi %406, %c8_i64_43 : i64 | |
%410 = arith.cmpi sge, %406, %c0_i64_0 : i64 | |
%411 = arith.select %410, %406, %409 : i64 | |
%412 = arith.cmpi slt, %411, %c0_i64_0 : i64 | |
%413 = arith.select %412, %c0_i64_0, %411 : i64 | |
%414 = arith.cmpi sgt, %413, %c8_i64_43 : i64 | |
%415 = arith.select %414, %c8_i64_43, %413 : i64 | |
%416 = arith.index_cast %415 : i64 to index | |
%417 = arith.index_cast %408 : i64 to index | |
%418 = arith.cmpi slt, %417, %c0 : index | |
%419 = arith.addi %417, %c8 : index | |
%420 = arith.select %418, %419, %417 : index | |
%421 = arith.cmpi slt, %420, %c0 : index | |
%422 = arith.select %421, %c-1, %420 : index | |
%423 = arith.cmpi sgt, %422, %c8 : index | |
%424 = arith.select %423, %c8, %422 : index | |
%425 = arith.subi %424, %416 : index | |
%426 = arith.addi %425, %48 : index | |
%427 = arith.subi %426, %67 : index | |
%428 = arith.floordivsi %427, %48 : index | |
%429 = arith.cmpi slt, %428, %c0 : index | |
%430 = arith.select %429, %c0, %428 : index | |
%431 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%235 : tensor<8xi64>) outs(%384 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%547 = linalg.index 0 : index | |
%548 = arith.subi %383, %547 : index | |
%extracted_69 = tensor.extract %235[%548] : tensor<8xi64> | |
linalg.yield %extracted_69 : i64 | |
} -> tensor<8xi64> | |
%432 = arith.muli %430, %78 : index | |
%c8_48 = arith.constant 8 : index | |
%433 = arith.subi %c8_48, %432 : index | |
%434 = arith.select %77, %433, %416 : index | |
%435 = arith.select %77, %431, %235 : tensor<8xi64> | |
%extracted_slice_49 = tensor.extract_slice %435[%434] [%430] [%78] : tensor<8xi64> to tensor<?xi64> | |
%cast_50 = tensor.cast %extracted_slice_49 : tensor<?xi64> to tensor<1xi64> | |
%436 = torch_c.from_builtin_tensor %cast_50 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%437 = torch_c.to_builtin_tensor %436 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%438 = torch.aten.add.int %348, %int1 : !torch.int, !torch.int -> !torch.int | |
%439 = torch_c.to_i64 %438 | |
%440 = torch_c.to_i64 %438 | |
%441 = arith.addi %351, %c8_i64_43 : i64 | |
%442 = arith.cmpi sge, %351, %c0_i64_0 : i64 | |
%443 = arith.select %442, %351, %441 : i64 | |
%444 = arith.cmpi slt, %443, %c0_i64_0 : i64 | |
%445 = arith.select %444, %c0_i64_0, %443 : i64 | |
%446 = arith.cmpi sgt, %445, %c8_i64_43 : i64 | |
%447 = arith.select %446, %c8_i64_43, %445 : i64 | |
%448 = arith.index_cast %447 : i64 to index | |
%449 = arith.index_cast %440 : i64 to index | |
%450 = arith.cmpi slt, %449, %c0 : index | |
%451 = arith.addi %449, %c8 : index | |
%452 = arith.select %450, %451, %449 : index | |
%453 = arith.cmpi slt, %452, %c0 : index | |
%454 = arith.select %453, %c-1, %452 : index | |
%455 = arith.cmpi sgt, %454, %c8 : index | |
%456 = arith.select %455, %c8, %454 : index | |
%457 = arith.subi %456, %448 : index | |
%458 = arith.addi %457, %48 : index | |
%459 = arith.subi %458, %67 : index | |
%460 = arith.floordivsi %459, %48 : index | |
%461 = arith.cmpi slt, %460, %c0 : index | |
%462 = arith.select %461, %c0, %460 : index | |
%cast_51 = tensor.cast %cast_50 : tensor<1xi64> to tensor<?xi64> | |
%inserted_slice_52 = tensor.insert_slice %cast_51 into %352[%448] [%462] [%48] : tensor<?xi64> into tensor<8xi64> | |
%463 = torch_c.from_builtin_tensor %inserted_slice_52 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%464 = torch_c.to_builtin_tensor %463 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%extracted_53 = tensor.extract %437[%c0] : tensor<1xi64> | |
%465 = torch_c.from_i64 %extracted_53 | |
%466 = torch_c.to_i64 %465 | |
%467 = torch.aten.add.int %465, %int1 : !torch.int, !torch.int -> !torch.int | |
%468 = torch_c.to_i64 %467 | |
%469 = arith.addi %466, %c8_i64_43 : i64 | |
%470 = arith.cmpi sge, %466, %c0_i64_0 : i64 | |
%471 = arith.select %470, %466, %469 : i64 | |
%472 = arith.cmpi slt, %471, %c0_i64_0 : i64 | |
%473 = arith.select %472, %c0_i64_0, %471 : i64 | |
%474 = arith.cmpi sgt, %473, %c8_i64_43 : i64 | |
%475 = arith.select %474, %c8_i64_43, %473 : i64 | |
%476 = arith.index_cast %475 : i64 to index | |
%477 = arith.index_cast %468 : i64 to index | |
%478 = arith.cmpi slt, %477, %c0 : index | |
%479 = arith.addi %477, %c8 : index | |
%480 = arith.select %478, %479, %477 : index | |
%481 = arith.cmpi slt, %480, %c0 : index | |
%482 = arith.select %481, %c-1, %480 : index | |
%483 = arith.cmpi sgt, %482, %c8 : index | |
%484 = arith.select %483, %c8, %482 : index | |
%485 = arith.subi %484, %476 : index | |
%486 = arith.addi %485, %48 : index | |
%487 = arith.subi %486, %67 : index | |
%488 = arith.floordivsi %487, %48 : index | |
%489 = arith.cmpi slt, %488, %c0 : index | |
%490 = arith.select %489, %c0, %488 : index | |
%491 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%collapsed : tensor<8x4xf32>) outs(%195 : tensor<8x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%547 = linalg.index 0 : index | |
%548 = linalg.index 1 : index | |
%549 = arith.subi %383, %547 : index | |
%extracted_69 = tensor.extract %collapsed[%549, %548] : tensor<8x4xf32> | |
linalg.yield %extracted_69 : f32 | |
} -> tensor<8x4xf32> | |
%492 = arith.muli %490, %78 : index | |
%c8_54 = arith.constant 8 : index | |
%493 = arith.subi %c8_54, %492 : index | |
%494 = arith.select %77, %493, %476 : index | |
%495 = arith.select %77, %491, %collapsed : tensor<8x4xf32> | |
%extracted_slice_55 = tensor.extract_slice %495[%494, %c0] [%490, %c4] [%78, %c1] : tensor<8x4xf32> to tensor<?x?xf32> | |
%cast_56 = tensor.cast %extracted_slice_55 : tensor<?x?xf32> to tensor<1x4xf32> | |
%496 = tensor.empty() : tensor<1x4xf32> | |
%497 = linalg.fill ins(%cst_7 : f32) outs(%496 : tensor<1x4xf32>) -> tensor<1x4xf32> | |
%498 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_56 : tensor<1x4xf32>) outs(%497 : tensor<1x4xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%547 = linalg.index 0 : index | |
%548 = linalg.index 1 : index | |
%549 = arith.subi %193, %548 : index | |
%extracted_69 = tensor.extract %cast_56[%547, %549] : tensor<1x4xf32> | |
linalg.yield %extracted_69 : f32 | |
} -> tensor<1x4xf32> | |
%c4_57 = arith.constant 4 : index | |
%499 = arith.subi %c4_57, %197 : index | |
%500 = arith.select %77, %499, %178 : index | |
%501 = arith.select %77, %498, %cast_56 : tensor<1x4xf32> | |
%extracted_slice_58 = tensor.extract_slice %501[%c0, %500] [%c1, %192] [%c1, %78] : tensor<1x4xf32> to tensor<?x?xf32> | |
%cast_59 = tensor.cast %extracted_slice_58 : tensor<?x?xf32> to tensor<1x2xf32> | |
%502 = arith.subi %c4_57, %223 : index | |
%503 = arith.select %77, %502, %208 : index | |
%extracted_slice_60 = tensor.extract_slice %501[%c0, %503] [%c1, %222] [%c1, %78] : tensor<1x4xf32> to tensor<?x?xf32> | |
%cast_61 = tensor.cast %extracted_slice_60 : tensor<?x?xf32> to tensor<1x2xf32> | |
%504 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_23, %cast_59 : tensor<8x2xf32>, tensor<1x2xf32>) outs(%226 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.cmpf ogt, %in, %in_69 : f32 | |
%548 = arith.select %547, %in, %in_69 : f32 | |
linalg.yield %548 : f32 | |
} -> tensor<8x2xf32> | |
%505 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_25, %cast_61 : tensor<8x2xf32>, tensor<1x2xf32>) outs(%226 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.cmpf olt, %in, %in_69 : f32 | |
%548 = arith.select %547, %in, %in_69 : f32 | |
linalg.yield %548 : f32 | |
} -> tensor<8x2xf32> | |
%506 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%505, %504 : tensor<8x2xf32>, tensor<8x2xf32>) outs(%226 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.sitofp %c1_i64 : i64 to f32 | |
%548 = arith.mulf %in_69, %547 : f32 | |
%549 = arith.subf %in, %548 : f32 | |
linalg.yield %549 : f32 | |
} -> tensor<8x2xf32> | |
%507 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%506, %8 : tensor<8x2xf32>, tensor<1xf32>) outs(%226 : tensor<8x2xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.cmpf ogt, %in, %in_69 : f32 | |
%548 = arith.select %547, %in, %in_69 : f32 | |
linalg.yield %548 : f32 | |
} -> tensor<8x2xf32> | |
%c8_62 = arith.constant 8 : index | |
%508 = tensor.empty(%c8_62) : tensor<?xf32> | |
%509 = linalg.fill ins(%cst_26 : f32) outs(%508 : tensor<?xf32>) -> tensor<?xf32> | |
%510 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["parallel", "reduction"]} ins(%507 : tensor<8x2xf32>) outs(%509 : tensor<?xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%547 = arith.mulf %in, %out : f32 | |
linalg.yield %547 : f32 | |
} -> tensor<?xf32> | |
%cast_63 = tensor.cast %510 : tensor<?xf32> to tensor<8xf32> | |
%511 = torch.aten.add.int %343, %int1 : !torch.int, !torch.int -> !torch.int | |
%512 = torch_c.to_i64 %511 | |
%513 = arith.addi %350, %c8_i64_43 : i64 | |
%514 = arith.cmpi sge, %350, %c0_i64_0 : i64 | |
%515 = arith.select %514, %350, %513 : i64 | |
%516 = arith.cmpi slt, %515, %c0_i64_0 : i64 | |
%517 = arith.select %516, %c0_i64_0, %515 : i64 | |
%518 = arith.cmpi sgt, %517, %c8_i64_43 : i64 | |
%519 = arith.select %518, %c8_i64_43, %517 : i64 | |
%520 = arith.index_cast %519 : i64 to index | |
%521 = arith.index_cast %512 : i64 to index | |
%522 = arith.cmpi slt, %521, %c0 : index | |
%523 = arith.addi %521, %c8 : index | |
%524 = arith.select %522, %523, %521 : index | |
%525 = arith.cmpi slt, %524, %c0 : index | |
%526 = arith.select %525, %c-1, %524 : index | |
%527 = arith.cmpi sgt, %526, %c8 : index | |
%528 = arith.select %527, %c8, %526 : index | |
%529 = arith.subi %528, %520 : index | |
%530 = arith.addi %529, %48 : index | |
%531 = arith.subi %530, %67 : index | |
%532 = arith.floordivsi %531, %48 : index | |
%533 = arith.cmpi slt, %532, %c0 : index | |
%534 = arith.select %533, %c0, %532 : index | |
%535 = tensor.empty() : tensor<8xf32> | |
%536 = linalg.fill ins(%cst_7 : f32) outs(%535 : tensor<8xf32>) -> tensor<8xf32> | |
%537 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_28 : tensor<8xf32>) outs(%536 : tensor<8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%547 = linalg.index 0 : index | |
%548 = arith.subi %383, %547 : index | |
%extracted_69 = tensor.extract %cast_28[%548] : tensor<8xf32> | |
linalg.yield %extracted_69 : f32 | |
} -> tensor<8xf32> | |
%538 = arith.muli %534, %78 : index | |
%c8_64 = arith.constant 8 : index | |
%539 = arith.subi %c8_64, %538 : index | |
%540 = arith.select %77, %539, %520 : index | |
%541 = arith.select %77, %537, %cast_28 : tensor<8xf32> | |
%extracted_slice_65 = tensor.extract_slice %541[%540] [%534] [%78] : tensor<8xf32> to tensor<?xf32> | |
%cast_66 = tensor.cast %extracted_slice_65 : tensor<?xf32> to tensor<1xf32> | |
%542 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_28, %cast_66 : tensor<8xf32>, tensor<1xf32>) outs(%535 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.sitofp %c1_i64 : i64 to f32 | |
%548 = arith.mulf %in_69, %547 : f32 | |
%549 = arith.addf %in, %548 : f32 | |
linalg.yield %549 : f32 | |
} -> tensor<8xf32> | |
%543 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%542, %cast_63 : tensor<8xf32>, tensor<8xf32>) outs(%535 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.sitofp %c1_i64 : i64 to f32 | |
%548 = arith.mulf %in_69, %547 : f32 | |
%549 = arith.subf %in, %548 : f32 | |
linalg.yield %549 : f32 | |
} -> tensor<8xf32> | |
%544 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%cast_63, %543 : tensor<8xf32>, tensor<8xf32>) outs(%535 : tensor<8xf32>) { | |
^bb0(%in: f32, %in_69: f32, %out: f32): | |
%547 = arith.divf %in, %in_69 : f32 | |
linalg.yield %547 : f32 | |
} -> tensor<8xf32> | |
%c0_67 = arith.constant 0 : index | |
%c1_68 = arith.constant 1 : index | |
%545 = arith.index_cast %c8_i64 : i64 to index | |
%546 = scf.for %arg15 = %c0_67 to %545 step %c1_68 iter_args(%arg16 = %345) -> (tensor<8xi64>) { | |
%547 = arith.index_cast %arg15 : index to i64 | |
%548 = torch_c.from_i64 %547 | |
%549 = torch_c.from_builtin_tensor %arg16 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%550 = torch_c.to_builtin_tensor %549 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%551 = torch_c.to_builtin_tensor %549 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%552 = torch.aten.add.int %548, %343 : !torch.int, !torch.int -> !torch.int | |
%553 = torch.aten.add.int %552, %int1 : !torch.int, !torch.int -> !torch.int | |
%554 = torch_c.to_i64 %553 | |
%555 = torch.aten.lt.int %553, %int8 : !torch.int, !torch.int -> !torch.bool | |
%556 = torch_c.to_i1 %555 | |
%557 = scf.if %556 -> (tensor<8xi64>) { | |
%560 = torch.aten.lt.int %553, %int0 : !torch.int, !torch.int -> !torch.bool | |
%561 = torch.aten.Int.bool %560 : !torch.bool -> !torch.int | |
%562 = torch.aten.mul.int %561, %int8 : !torch.int, !torch.int -> !torch.int | |
%563 = torch.aten.add.int %553, %562 : !torch.int, !torch.int -> !torch.int | |
%564 = torch_c.to_i64 %563 | |
%565 = torch.aten.add.int %563, %int1 : !torch.int, !torch.int -> !torch.int | |
%566 = torch_c.to_i64 %565 | |
%567 = arith.addi %564, %c8_i64_43 : i64 | |
%568 = arith.cmpi sge, %564, %c0_i64_0 : i64 | |
%569 = arith.select %568, %564, %567 : i64 | |
%570 = arith.cmpi slt, %569, %c0_i64_0 : i64 | |
%571 = arith.select %570, %c0_i64_0, %569 : i64 | |
%572 = arith.cmpi sgt, %571, %c8_i64_43 : i64 | |
%573 = arith.select %572, %c8_i64_43, %571 : i64 | |
%574 = arith.index_cast %573 : i64 to index | |
%575 = arith.index_cast %566 : i64 to index | |
%576 = arith.cmpi slt, %575, %c0 : index | |
%577 = arith.addi %575, %c8 : index | |
%578 = arith.select %576, %577, %575 : index | |
%579 = arith.cmpi slt, %578, %c0 : index | |
%580 = arith.select %579, %c-1, %578 : index | |
%581 = arith.cmpi sgt, %580, %c8 : index | |
%582 = arith.select %581, %c8, %580 : index | |
%583 = arith.subi %582, %574 : index | |
%584 = arith.addi %583, %48 : index | |
%585 = arith.subi %584, %67 : index | |
%586 = arith.floordivsi %585, %48 : index | |
%587 = arith.cmpi slt, %586, %c0 : index | |
%588 = arith.select %587, %c0, %586 : index | |
%589 = arith.muli %588, %78 : index | |
%590 = arith.subi %c8_48, %589 : index | |
%591 = arith.select %77, %590, %574 : index | |
%extracted_slice_69 = tensor.extract_slice %435[%591] [%588] [%78] : tensor<8xi64> to tensor<?xi64> | |
%cast_70 = tensor.cast %extracted_slice_69 : tensor<?xi64> to tensor<1xi64> | |
%592 = torch_c.from_builtin_tensor %cast_70 : tensor<1xi64> -> !torch.vtensor<[1],si64> | |
%593 = torch_c.to_builtin_tensor %592 : !torch.vtensor<[1],si64> -> tensor<1xi64> | |
%extracted_71 = tensor.extract %593[%c0] : tensor<1xi64> | |
%594 = torch_c.from_i64 %extracted_71 | |
%595 = torch_c.to_i64 %594 | |
%596 = torch.aten.add.int %594, %int1 : !torch.int, !torch.int -> !torch.int | |
%597 = torch_c.to_i64 %596 | |
%598 = arith.addi %595, %c8_i64_43 : i64 | |
%599 = arith.cmpi sge, %595, %c0_i64_0 : i64 | |
%600 = arith.select %599, %595, %598 : i64 | |
%601 = arith.cmpi slt, %600, %c0_i64_0 : i64 | |
%602 = arith.select %601, %c0_i64_0, %600 : i64 | |
%603 = arith.cmpi sgt, %602, %c8_i64_43 : i64 | |
%604 = arith.select %603, %c8_i64_43, %602 : i64 | |
%605 = arith.index_cast %604 : i64 to index | |
%606 = arith.index_cast %597 : i64 to index | |
%607 = arith.cmpi slt, %606, %c0 : index | |
%608 = arith.addi %606, %c8 : index | |
%609 = arith.select %607, %608, %606 : index | |
%610 = arith.cmpi slt, %609, %c0 : index | |
%611 = arith.select %610, %c-1, %609 : index | |
%612 = arith.cmpi sgt, %611, %c8 : index | |
%613 = arith.select %612, %c8, %611 : index | |
%614 = arith.subi %613, %605 : index | |
%615 = arith.addi %614, %48 : index | |
%616 = arith.subi %615, %67 : index | |
%617 = arith.floordivsi %616, %48 : index | |
%618 = arith.cmpi slt, %617, %c0 : index | |
%619 = arith.select %618, %c0, %617 : index | |
%620 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%544 : tensor<8xf32>) outs(%536 : tensor<8xf32>) { | |
^bb0(%in: f32, %out: f32): | |
%632 = linalg.index 0 : index | |
%633 = arith.subi %383, %632 : index | |
%extracted_76 = tensor.extract %544[%633] : tensor<8xf32> | |
linalg.yield %extracted_76 : f32 | |
} -> tensor<8xf32> | |
%621 = arith.muli %619, %78 : index | |
%c8_72 = arith.constant 8 : index | |
%622 = arith.subi %c8_72, %621 : index | |
%623 = arith.select %77, %622, %605 : index | |
%624 = arith.select %77, %620, %544 : tensor<8xf32> | |
%extracted_slice_73 = tensor.extract_slice %624[%623] [%619] [%78] : tensor<8xf32> to tensor<?xf32> | |
%cast_74 = tensor.cast %extracted_slice_73 : tensor<?xf32> to tensor<1xf32> | |
%625 = torch_c.from_builtin_tensor %cast_74 : tensor<1xf32> -> !torch.vtensor<[1],f32> | |
%626 = torch_c.to_builtin_tensor %625 : !torch.vtensor<[1],f32> -> tensor<1xf32> | |
%extracted_75 = tensor.extract %626[%c0] : tensor<1xf32> | |
%627 = arith.extf %extracted_75 : f32 to f64 | |
%628 = torch_c.from_f64 %627 | |
%629 = torch.aten.gt.float %628, %22 : !torch.float, !torch.float -> !torch.bool | |
%630 = torch_c.to_i1 %629 | |
%631 = scf.if %630 -> (tensor<8xi64>) { | |
%632 = torch.aten.add.int %553, %int1 : !torch.int, !torch.int -> !torch.int | |
%633 = torch_c.to_i64 %632 | |
%634 = arith.addi %554, %c8_i64_43 : i64 | |
%635 = arith.cmpi sge, %554, %c0_i64_0 : i64 | |
%636 = arith.select %635, %554, %634 : i64 | |
%637 = arith.cmpi slt, %636, %c0_i64_0 : i64 | |
%638 = arith.select %637, %c0_i64_0, %636 : i64 | |
%639 = arith.cmpi sgt, %638, %c8_i64_43 : i64 | |
%640 = arith.select %639, %c8_i64_43, %638 : i64 | |
%641 = arith.index_cast %640 : i64 to index | |
%642 = arith.index_cast %633 : i64 to index | |
%643 = arith.cmpi slt, %642, %c0 : index | |
%644 = arith.addi %642, %c8 : index | |
%645 = arith.select %643, %644, %642 : index | |
%646 = arith.cmpi slt, %645, %c0 : index | |
%647 = arith.select %646, %c-1, %645 : index | |
%648 = arith.cmpi sgt, %647, %c8 : index | |
%649 = arith.select %648, %c8, %647 : index | |
%650 = arith.subi %649, %641 : index | |
%651 = arith.addi %650, %48 : index | |
%652 = arith.subi %651, %67 : index | |
%653 = arith.floordivsi %652, %48 : index | |
%654 = arith.cmpi slt, %653, %c0 : index | |
%655 = arith.select %654, %c0, %653 : index | |
%cast_76 = tensor.cast %10 : tensor<1xi64> to tensor<?xi64> | |
%inserted_slice_77 = tensor.insert_slice %cast_76 into %551[%641] [%655] [%48] : tensor<?xi64> into tensor<8xi64> | |
%656 = torch_c.from_builtin_tensor %inserted_slice_77 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%657 = torch_c.to_builtin_tensor %656 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
scf.yield %657 : tensor<8xi64> | |
} else { | |
scf.yield %550 : tensor<8xi64> | |
} | |
scf.yield %631 : tensor<8xi64> | |
} else { | |
scf.yield %550 : tensor<8xi64> | |
} | |
%558 = torch_c.from_builtin_tensor %557 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%559 = torch_c.to_builtin_tensor %558 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
scf.yield %559 : tensor<8xi64> | |
} | |
scf.yield %546, %464, %439 : tensor<8xi64>, tensor<8xi64>, i64 | |
} else { | |
scf.yield %345, %347, %349 : tensor<8xi64>, tensor<8xi64>, i64 | |
} | |
%396 = torch_c.from_i64 %395#2 | |
%397 = torch_c.from_builtin_tensor %395#1 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%398 = torch_c.from_builtin_tensor %395#0 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%399 = torch_c.to_builtin_tensor %398 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%400 = torch_c.to_builtin_tensor %397 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%401 = torch_c.to_i64 %396 | |
scf.yield %399, %400, %401 : tensor<8xi64>, tensor<8xi64>, i64 | |
} | |
%240 = torch_c.from_i64 %239#2 | |
%241 = torch_c.from_builtin_tensor %239#1 : tensor<8xi64> -> !torch.vtensor<[8],si64> | |
%242 = torch_c.to_i64 %240 | |
%243 = torch_c.to_builtin_tensor %241 : !torch.vtensor<[8],si64> -> tensor<8xi64> | |
%c8_i64_31 = arith.constant 8 : i64 | |
%244 = arith.addi %c0_i64_0, %c8_i64_31 : i64 | |
%245 = arith.select %172, %c0_i64_0, %244 : i64 | |
%246 = arith.cmpi slt, %245, %c0_i64_0 : i64 | |
%247 = arith.select %246, %c0_i64_0, %245 : i64 | |
%248 = arith.cmpi sgt, %247, %c8_i64_31 : i64 | |
%249 = arith.select %248, %c8_i64_31, %247 : i64 | |
%250 = arith.index_cast %249 : i64 to index | |
%251 = arith.index_cast %242 : i64 to index | |
%252 = arith.cmpi slt, %251, %c0 : index | |
%253 = arith.addi %251, %c8 : index | |
%254 = arith.select %252, %253, %251 : index | |
%255 = arith.cmpi slt, %254, %c0 : index | |
%256 = arith.select %255, %c-1, %254 : index | |
%257 = arith.cmpi sgt, %256, %c8 : index | |
%258 = arith.select %257, %c8, %256 : index | |
%259 = arith.subi %258, %250 : index | |
%260 = arith.addi %259, %48 : index | |
%261 = arith.subi %260, %67 : index | |
%262 = arith.floordivsi %261, %48 : index | |
%263 = arith.cmpi slt, %262, %c0 : index | |
%264 = arith.select %263, %c0, %262 : index | |
%265 = arith.subi %c8, %c1 : index | |
%266 = linalg.fill ins(%c0_i64_0 : i64) outs(%231 : tensor<8xi64>) -> tensor<8xi64> | |
%267 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%243 : tensor<8xi64>) outs(%266 : tensor<8xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%342 = linalg.index 0 : index | |
%343 = arith.subi %265, %342 : index | |
%extracted_43 = tensor.extract %243[%343] : tensor<8xi64> | |
linalg.yield %extracted_43 : i64 | |
} -> tensor<8xi64> | |
%268 = arith.muli %264, %78 : index | |
%c8_32 = arith.constant 8 : index | |
%269 = arith.subi %c8_32, %268 : index | |
%270 = arith.select %77, %269, %250 : index | |
%271 = arith.select %77, %267, %243 : tensor<8xi64> | |
%extracted_slice_33 = tensor.extract_slice %271[%270] [%264] [%78] : tensor<8xi64> to tensor<?xi64> | |
%272 = torch_c.from_builtin_tensor %extracted_slice_33 : tensor<?xi64> -> !torch.vtensor<[?],si64> | |
%273 = torch_c.to_builtin_tensor %272 : !torch.vtensor<[?],si64> -> tensor<?xi64> | |
%274 = arith.addi %c0_i64_0, %c1_i64 : i64 | |
%275 = arith.select %172, %c0_i64_0, %274 : i64 | |
%276 = arith.index_cast %275 : i64 to index | |
%dim = tensor.dim %extracted_slice_33, %276 : tensor<?xi64> | |
%277 = arith.index_cast %dim : index to i64 | |
%278 = torch_c.from_i64 %277 | |
%279 = torch.aten.gt.int %278, %23 : !torch.int, !torch.int -> !torch.bool | |
%280 = torch_c.to_i1 %279 | |
%281 = scf.if %280 -> (tensor<?xi64>) { | |
%342 = arith.index_cast %264 : index to i64 | |
%343 = arith.addi %c0_i64_0, %342 : i64 | |
%344 = arith.select %172, %c0_i64_0, %343 : i64 | |
%345 = arith.cmpi slt, %344, %c0_i64_0 : i64 | |
%346 = arith.select %345, %c0_i64_0, %344 : i64 | |
%347 = arith.cmpi sgt, %346, %342 : i64 | |
%348 = arith.select %347, %342, %346 : i64 | |
%349 = arith.index_cast %348 : i64 to index | |
%350 = arith.index_cast %24 : i64 to index | |
%351 = arith.cmpi slt, %350, %c0 : index | |
%352 = arith.addi %350, %264 : index | |
%353 = arith.select %351, %352, %350 : index | |
%354 = arith.cmpi slt, %353, %c0 : index | |
%355 = arith.select %354, %c-1, %353 : index | |
%356 = arith.cmpi sgt, %355, %264 : index | |
%357 = arith.select %356, %264, %355 : index | |
%358 = arith.subi %357, %349 : index | |
%359 = arith.addi %358, %48 : index | |
%360 = arith.subi %359, %67 : index | |
%361 = arith.floordivsi %360, %48 : index | |
%362 = arith.cmpi slt, %361, %c0 : index | |
%363 = arith.select %362, %c0, %361 : index | |
%364 = arith.subi %264, %c1 : index | |
%365 = tensor.empty(%264) : tensor<?xi64> | |
%366 = linalg.fill ins(%c0_i64_0 : i64) outs(%365 : tensor<?xi64>) -> tensor<?xi64> | |
%367 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%extracted_slice_33 : tensor<?xi64>) outs(%366 : tensor<?xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%374 = linalg.index 0 : index | |
%375 = arith.subi %364, %374 : index | |
%extracted_44 = tensor.extract %extracted_slice_33[%375] : tensor<?xi64> | |
linalg.yield %extracted_44 : i64 | |
} -> tensor<?xi64> | |
%368 = arith.muli %363, %78 : index | |
%369 = arith.subi %264, %368 : index | |
%370 = arith.select %77, %369, %349 : index | |
%371 = arith.select %77, %367, %extracted_slice_33 : tensor<?xi64> | |
%extracted_slice_43 = tensor.extract_slice %371[%370] [%363] [%78] : tensor<?xi64> to tensor<?xi64> | |
%372 = torch_c.from_builtin_tensor %extracted_slice_43 : tensor<?xi64> -> !torch.vtensor<[?],si64> | |
%373 = torch_c.to_builtin_tensor %372 : !torch.vtensor<[?],si64> -> tensor<?xi64> | |
scf.yield %373 : tensor<?xi64> | |
} else { | |
scf.yield %273 : tensor<?xi64> | |
} | |
%282 = torch_c.from_builtin_tensor %281 : tensor<?xi64> -> !torch.vtensor<[?],si64> | |
%283 = torch_c.to_builtin_tensor %282 : !torch.vtensor<[?],si64> -> tensor<?xi64> | |
%dim_34 = tensor.dim %283, %c0 : tensor<?xi64> | |
%expanded = tensor.expand_shape %283 [[0, 1]] output_shape [%dim_34, 1] : tensor<?xi64> into tensor<?x1xi64> | |
%284 = torch_c.from_builtin_tensor %expanded : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
%285 = arith.addi %c0_i64_0, %c2_i64 : i64 | |
%286 = arith.select %172, %c0_i64_0, %285 : i64 | |
%287 = arith.index_cast %286 : i64 to index | |
%dim_35 = tensor.dim %expanded, %287 : tensor<?x1xi64> | |
%288 = arith.index_cast %dim_35 : index to i64 | |
%289 = torch_c.from_i64 %288 | |
%290 = torch_c.to_i64 %289 | |
%291 = arith.index_cast %290 : i64 to index | |
%292 = tensor.empty(%291) : tensor<?x1xi64> | |
%dim_36 = tensor.dim %292, %c0 : tensor<?x1xi64> | |
%293 = tensor.empty(%dim_36) : tensor<?x1xi64> | |
%294 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%292 : tensor<?x1xi64>) outs(%293 : tensor<?x1xi64>) { | |
^bb0(%in: i64, %out: i64): | |
linalg.yield %40 : i64 | |
} -> tensor<?x1xi64> | |
%295 = torch_c.from_builtin_tensor %294 : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
%296 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, 0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%292 : tensor<?x1xi64>) outs(%293 : tensor<?x1xi64>) { | |
^bb0(%in: i64, %out: i64): | |
linalg.yield %95 : i64 | |
} -> tensor<?x1xi64> | |
%297 = torch_c.from_builtin_tensor %296 : tensor<?x1xi64> -> !torch.vtensor<[?,1],si64> | |
%298 = torch_c.to_builtin_tensor %297 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%299 = torch_c.to_builtin_tensor %284 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%concat = tensor.concat dim(1) %298, %299 : (tensor<?x1xi64>, tensor<?x1xi64>) -> tensor<?x2xi64> | |
%300 = torch_c.from_builtin_tensor %concat : tensor<?x2xi64> -> !torch.vtensor<[?,2],si64> | |
%301 = torch_c.to_builtin_tensor %295 : !torch.vtensor<[?,1],si64> -> tensor<?x1xi64> | |
%302 = torch_c.to_builtin_tensor %300 : !torch.vtensor<[?,2],si64> -> tensor<?x2xi64> | |
%concat_37 = tensor.concat dim(1) %301, %302 : (tensor<?x1xi64>, tensor<?x2xi64>) -> tensor<?x3xi64> | |
%303 = torch.aten.add.int %92, %289 : !torch.int, !torch.int -> !torch.int | |
%304 = torch_c.to_i64 %303 | |
%c120 = arith.constant 120 : index | |
%c120_i64 = arith.constant 120 : i64 | |
%305 = arith.addi %93, %c120_i64 : i64 | |
%306 = arith.cmpi sge, %93, %c0_i64_0 : i64 | |
%307 = arith.select %306, %93, %305 : i64 | |
%308 = arith.cmpi slt, %307, %c0_i64_0 : i64 | |
%309 = arith.select %308, %c0_i64_0, %307 : i64 | |
%310 = arith.cmpi sgt, %309, %c120_i64 : i64 | |
%311 = arith.select %310, %c120_i64, %309 : i64 | |
%312 = arith.index_cast %311 : i64 to index | |
%313 = arith.index_cast %304 : i64 to index | |
%314 = arith.cmpi slt, %313, %c0 : index | |
%315 = arith.addi %313, %c120 : index | |
%316 = arith.select %314, %315, %313 : index | |
%317 = arith.cmpi slt, %316, %c0 : index | |
%318 = arith.select %317, %c-1, %316 : index | |
%319 = arith.cmpi sgt, %318, %c120 : index | |
%320 = arith.select %319, %c120, %318 : index | |
%321 = arith.subi %320, %312 : index | |
%322 = arith.addi %321, %48 : index | |
%323 = arith.subi %322, %67 : index | |
%324 = arith.floordivsi %323, %48 : index | |
%325 = arith.cmpi slt, %324, %c0 : index | |
%326 = arith.select %325, %c0, %324 : index | |
%327 = arith.subi %c120, %c1 : index | |
%328 = tensor.empty() : tensor<120x3xi64> | |
%329 = linalg.fill ins(%c0_i64_0 : i64) outs(%328 : tensor<120x3xi64>) -> tensor<120x3xi64> | |
%330 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%94 : tensor<120x3xi64>) outs(%329 : tensor<120x3xi64>) { | |
^bb0(%in: i64, %out: i64): | |
%342 = linalg.index 0 : index | |
%343 = linalg.index 1 : index | |
%344 = arith.subi %327, %342 : index | |
%extracted_43 = tensor.extract %94[%344, %343] : tensor<120x3xi64> | |
linalg.yield %extracted_43 : i64 | |
} -> tensor<120x3xi64> | |
%331 = arith.muli %326, %78 : index | |
%c120_38 = arith.constant 120 : index | |
%332 = arith.subi %c120_38, %331 : index | |
%333 = arith.select %77, %332, %312 : index | |
%334 = arith.select %77, %330, %94 : tensor<120x3xi64> | |
%extracted_slice_39 = tensor.extract_slice %334[%333, %c0] [%326, %c3] [%78, %c1] : tensor<120x3xi64> to tensor<?x?xi64> | |
%cast_40 = tensor.cast %extracted_slice_39 : tensor<?x?xi64> to tensor<?x3xi64> | |
%dim_41 = tensor.dim %cast_40, %287 : tensor<?x3xi64> | |
%335 = arith.index_cast %dim_41 : index to i64 | |
%336 = torch_c.from_i64 %335 | |
%337 = torch_c.to_i64 %336 | |
%338 = arith.index_cast %337 : i64 to index | |
%cast_42 = tensor.cast %concat_37 : tensor<?x3xi64> to tensor<?x?xi64> | |
%inserted_slice = tensor.insert_slice %cast_42 into %94[%312, %c0] [%326, %c3] [%48, %c1] : tensor<?x?xi64> into tensor<120x3xi64> | |
%339 = torch_c.from_builtin_tensor %inserted_slice : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%340 = torch_c.to_builtin_tensor %339 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
%341 = torch_c.to_i64 %303 | |
scf.yield %340, %341 : tensor<120x3xi64>, i64 | |
} | |
%85 = torch_c.from_i64 %84#1 | |
%86 = torch_c.from_builtin_tensor %84#0 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
%87 = torch_c.to_builtin_tensor %86 : !torch.vtensor<[120,3],si64> -> tensor<120x3xi64> | |
%88 = torch_c.to_i64 %85 | |
scf.yield %87, %88 : tensor<120x3xi64>, i64 | |
} | |
%33 = torch_c.from_builtin_tensor %32#0 : tensor<120x3xi64> -> !torch.vtensor<[120,3],si64> | |
return %33 : !torch.vtensor<[120,3],si64> | |
} |
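For readers skimming the lowered IR above: the score sort, the pairwise max/min linalg.generic ops, the division that follows them, the comparison against a threshold, and the tensor.concat that assembles [?,3] rows amount to a per-batch, per-class greedy NMS whose selected indices are written as [batch_index, class_index, box_index] triples into the 120x3 result tensor. The NumPy block below is a hedged sketch of that computation, not code from this gist; the helper names (iou, nms_single_class, batched_nms) are made up for illustration, corner-format boxes are assumed (center_point_box = 0 in the ONNX op), and score_threshold filtering is omitted for brevity.

import numpy as np

def iou(box, boxes):
    # Boxes are [y1, x1, y2, x2] corners (ONNX center_point_box = 0).
    y1 = np.maximum(box[0], boxes[:, 0])
    x1 = np.maximum(box[1], boxes[:, 1])
    y2 = np.minimum(box[2], boxes[:, 2])
    x2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(y2 - y1, 0, None) * np.clip(x2 - x1, 0, None)
    area = lambda b: (b[..., 2] - b[..., 0]) * (b[..., 3] - b[..., 1])
    return inter / (area(box) + area(boxes) - inter)

def nms_single_class(boxes, scores, iou_threshold, max_out):
    # Greedy NMS on one class: take boxes in descending score order,
    # dropping any remaining box whose IoU with the kept box exceeds the threshold.
    order = np.argsort(-scores)
    keep = []
    while order.size and len(keep) < max_out:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        order = rest[iou(boxes[i], boxes[rest]) <= iou_threshold]
    return keep

def batched_nms(boxes, scores, max_out, iou_threshold):
    # boxes: [batch, num_boxes, 4], scores: [batch, num_classes, num_boxes].
    # Returns concatenated rows of [batch_index, class_index, box_index],
    # mirroring the [N,3] si64 layout of the result tensor above.
    rows = []
    for b in range(scores.shape[0]):
        for c in range(scores.shape[1]):
            for i in nms_single_class(boxes[b], scores[b, c], iou_threshold, max_out):
                rows.append([b, c, i])
    return np.array(rows, dtype=np.int64).reshape(-1, 3)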
module { | |
func.func @main(%arg0: !torch.vtensor<[3,8,4],f32>, %arg1: !torch.vtensor<[3,5,8],f32>, %arg2: !torch.vtensor<[1],si64>, %arg3: !torch.vtensor<[1],f32>, %arg4: !torch.vtensor<[1],f32>) -> !torch.vtensor<[45,3],si64> attributes {torch.onnx_meta.ir_version = 10 : si64, torch.onnx_meta.opset_version = 21 : si64, torch.onnx_meta.producer_name = "", torch.onnx_meta.producer_version = ""} { | |
%none = torch.constant.none | |
%0 = torch.operator "onnx.NonMaxSuppression"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.center_point_box = 0 : si64} : (!torch.vtensor<[3,8,4],f32>, !torch.vtensor<[3,5,8],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[45,3],si64> | |
return %0 : !torch.vtensor<[45,3],si64> | |
} | |
} |
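The small module above is presumably the un-lowered import of the same graph: the ONNX NonMaxSuppression op before conversion, whose [45,3] result (3 batches x 5 classes x at most 3 boxes per class) carries the same [batch_index, class_index, box_index] rows. Continuing the hedged NumPy sketch above, a hypothetical smoke test with the shapes from the function signature might look like this:

rng = np.random.default_rng(0)
boxes = rng.random((3, 8, 4)).astype(np.float32)
boxes[..., 2:] += boxes[..., :2]           # make [y2, x2] >= [y1, x1]
scores = rng.random((3, 5, 8)).astype(np.float32)
selected = batched_nms(boxes, scores, max_out=3, iou_threshold=0.5)
print(selected.shape)                      # at most (45, 3) rows of [batch, class, box]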