diff --git a/backends/arm/operators/op_bmm.py b/backends/arm/operators/op_bmm.py index 6dc0ec8002d..69c8283bc8c 100644 --- a/backends/arm/operators/op_bmm.py +++ b/backends/arm/operators/op_bmm.py @@ -37,18 +37,24 @@ def define_node( inputs: List[TosaArg], output: TosaArg, ) -> None: + if inputs[0].dtype != inputs[1].dtype or inputs[0].dtype != output.dtype: + raise TypeError( + f"All IO needs to have the same data type, got: " + f"{inputs[0].dtype=}, {inputs[1].dtype=} and {output.dtype=}" + ) - assert inputs[0].dtype == inputs[1].dtype, "Both inputs must be of same type" - assert inputs[0].dtype in [ - ts.DType.INT8, - ts.DType.FP32, - ], "Only int8 and float32 supported" - # aten.bmm maps directly to MATMUL # NOTE: For now, only INT8 & FP32 is supported + supported_dtypes = [ts.DType.INT8, ts.DType.FP32] + for input in inputs: + if input.dtype not in supported_dtypes: + raise TypeError( + f'IO data type needs to be {supported_dtypes}, got "{input.dtype}"' + ) + + # aten.bmm maps directly to MATMUL # For INT8, we need to get the zero points and add an intermediate tensor # for a later rescale. - if inputs[0].dtype == ts.DType.INT8: input_qparams = get_input_qparams(node) input0_zp = input_qparams[0].zp