-
Notifications
You must be signed in to change notification settings - Fork 13.5k
[MLIR][NVVM] Add dot.accumulate.4way
OP
#139043
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-llvm Author: Srinivasa Ravi (Wolfram70) Changes: This change adds the `dp4a` Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a Full diff: https://github.com/llvm/llvm-project/pull/139043.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..85b3e80711018 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3444,6 +3444,54 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// NVVM dp4a Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Dp4aOp : NVVM_Op<"dp4a"> {
+ let summary = "Four-way byte dot product-accumulate instruction.";
+ let description = [{
+ Performs a four-way byte dot-product which is accumulated in a 32-bit
+ result.
+ Operands `a` and `b` can be passed either as packed 32-bit inputs holding
+ 4 byte-inputs for the dot product, or as vectors of 4 i8 elements.
+ The `a_signed` and `b_signed` unit attributes specify whether the
+ individual byte inputs in operands `a` and `b`, respectively, are signed
+ or unsigned.
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
+ treated as holding a signed integer if either `a` or `b` is signed.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$a,
+ AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$b,
+ I32:$c,
+ DefaultValuedAttr<UnitAttr, "false">:$a_signed,
+ DefaultValuedAttr<UnitAttr, "false">:$b_signed
+ );
+
+ let results = (outs I32:$res);
+
+ let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a) `,` type($b)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(bool a_signed, bool b_signed);
+ }];
+
+ string llvmBuilder = [{
+ auto id = NVVM::Dp4aOp::getIntrinsicID($a_signed, $b_signed);
+ llvm::Value* argA = $a;
+ llvm::Value* argB = $b;
+ if (!op.getA().getType().isInteger(32))
+ argA = builder.CreateBitCast(argA, llvm::Type::getInt32Ty(builder.getContext()));
+ if (!op.getB().getType().isInteger(32))
+ argB = builder.CreateBitCast(argB, llvm::Type::getInt32Ty(builder.getContext()));
+ $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..a4100d7ce3bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1590,6 +1590,14 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+#define GET_DP4A_ID(a_sign, is_b_signed) \
+ is_b_signed ? llvm::Intrinsic::nvvm_idp4a_##a_sign##_s \
+ : llvm::Intrinsic::nvvm_idp4a_##a_sign##_u
+
+llvm::Intrinsic::ID Dp4aOp::getIntrinsicID(bool a_signed, bool b_signed) {
+ return a_signed ? GET_DP4A_ID(s, b_signed) : GET_DP4A_ID(u, b_signed);
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index d3915492c38a0..53ef034821611 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -578,6 +578,19 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
return
}
+// CHECK-LABEL: @dp4a
+func.func @dp4a_packed(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : i32, i32
+ %0 = nvvm.dp4a %a, %b, %c: i32, i32
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dp4a %a_vec, %b_vec, %c: vector<4xi8>, vector<4xi8>
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : i32, i32
+ %2 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dp4a %a_vec, %b_vec, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
+ return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3a0713f2feee8..4a116f6db37e5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -844,3 +844,39 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_packed
+llvm.func @nvvm_dp4a_packed(%a: i32, %b: i32, %c: i32) {
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.dp4a %a, %b, %c: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.dp4a %a, %b, %c {a_signed}: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.dp4a %a, %b, %c {b_signed}: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_vec
+llvm.func @nvvm_dp4a_vec(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %0 = nvvm.dp4a %a, %b, %c: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %1 = nvvm.dp4a %a, %b, %c {a_signed}: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %2 = nvvm.dp4a %a, %b, %c {b_signed}: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %3 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
+ llvm.return
+}
|
@llvm/pr-subscribers-mlir Author: Srinivasa Ravi (Wolfram70) Changes: This change adds the `dp4a` Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a Full diff: https://github.com/llvm/llvm-project/pull/139043.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 6540273b216e3..85b3e80711018 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -3444,6 +3444,54 @@ def NVVM_Tcgen05StOp : NVVM_Op<"tcgen05.st"> {
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// NVVM dp4a Op
+//===----------------------------------------------------------------------===//
+
+def NVVM_Dp4aOp : NVVM_Op<"dp4a"> {
+ let summary = "Four-way byte dot product-accumulate instruction.";
+ let description = [{
+ Performs a four-way byte dot-product which is accumulated in a 32-bit
+ result.
+ Operand `a` and `b` can be passed either as packed 32-bit inputs holding
+ 4 byte-inputs for the dot product, or as vectors of 4 i8 elements.
+ The `asigned` and `bsigned` unit attributes specify whether the
+ individual byte inputs in operands `a` and `b` are signed or unsigned
+ respectively.
+ Operand `c` is a 32-bit integer to which the result is accumulated. It is
+ treated as holding a signed integer if any of `a` or `b` are signed.
+
+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a)
+ }];
+
+ let arguments = (ins
+ AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$a,
+ AnyTypeOf<[I32, VectorOfLengthAndType<[4], [I8]>]>:$b,
+ I32:$c,
+ DefaultValuedAttr<UnitAttr, "false">:$a_signed,
+ DefaultValuedAttr<UnitAttr, "false">:$b_signed
+ );
+
+ let results = (outs I32:$res);
+
+ let assemblyFormat = "$a `,` $b `,` $c attr-dict `:` type($a) `,` type($b)";
+
+ let extraClassDeclaration = [{
+ static llvm::Intrinsic::ID getIntrinsicID(bool a_signed, bool b_signed);
+ }];
+
+ string llvmBuilder = [{
+ auto id = NVVM::Dp4aOp::getIntrinsicID($a_signed, $b_signed);
+ llvm::Value* argA = $a;
+ llvm::Value* argB = $b;
+ if (!op.getA().getType().isInteger(32))
+ argA = builder.CreateBitCast(argA, llvm::Type::getInt32Ty(builder.getContext()));
+ if (!op.getB().getType().isInteger(32))
+ argB = builder.CreateBitCast(argB, llvm::Type::getInt32Ty(builder.getContext()));
+ $res = createIntrinsicCall(builder, id, {argA, argB, $c});
+ }];
+}
+
//===----------------------------------------------------------------------===//
// NVVM target attribute.
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index 3c3731a63e268..a4100d7ce3bef 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -1590,6 +1590,14 @@ static void nvvmInferResultRanges(Operation *op, Value result,
}
}
+#define GET_DP4A_ID(a_sign, is_b_signed) \
+ is_b_signed ? llvm::Intrinsic::nvvm_idp4a_##a_sign##_s \
+ : llvm::Intrinsic::nvvm_idp4a_##a_sign##_u
+
+llvm::Intrinsic::ID Dp4aOp::getIntrinsicID(bool a_signed, bool b_signed) {
+ return a_signed ? GET_DP4A_ID(s, b_signed) : GET_DP4A_ID(u, b_signed);
+}
+
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir
index d3915492c38a0..53ef034821611 100644
--- a/mlir/test/Dialect/LLVMIR/nvvm.mlir
+++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir
@@ -578,6 +578,19 @@ func.func @st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size: i64)
return
}
+// CHECK-LABEL: @dp4a
+func.func @dp4a_packed(%a: i32, %a_vec: vector<4xi8>, %b: i32, %b_vec: vector<4xi8>, %c: i32) {
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : i32, i32
+ %0 = nvvm.dp4a %a, %b, %c: i32, i32
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} : vector<4xi8>, vector<4xi8>
+ %1 = nvvm.dp4a %a_vec, %b_vec, %c: vector<4xi8>, vector<4xi8>
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : i32, i32
+ %2 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
+ // CHECK: nvvm.dp4a %{{.*}}, %{{.*}}, %{{.*}} {a_signed, b_signed} : vector<4xi8>, vector<4xi8>
+ %3 = nvvm.dp4a %a_vec, %b_vec, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
+ return
+}
+
// -----
// Just check these don't emit errors.
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 3a0713f2feee8..4a116f6db37e5 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -844,3 +844,39 @@ llvm.func @nvvm_st_bulk(%addr_gen: !llvm.ptr, %addr_shared: !llvm.ptr<3>, %size:
nvvm.st.bulk %addr_shared, size = %size, init = 0: !llvm.ptr<3>
llvm.return
}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_packed
+llvm.func @nvvm_dp4a_packed(%a: i32, %b: i32, %c: i32) {
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %0 = nvvm.dp4a %a, %b, %c: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %1 = nvvm.dp4a %a, %b, %c {a_signed}: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %2 = nvvm.dp4a %a, %b, %c {b_signed}: i32, i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
+ %3 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: i32, i32
+ llvm.return
+}
+
+// -----
+// CHECK-LABEL: @nvvm_dp4a_vec
+llvm.func @nvvm_dp4a_vec(%a: vector<4xi8>, %b: vector<4xi8>, %c: i32) {
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %0 = nvvm.dp4a %a, %b, %c: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.u(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %1 = nvvm.dp4a %a, %b, %c {a_signed}: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.u.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %2 = nvvm.dp4a %a, %b, %c {b_signed}: vector<4xi8>, vector<4xi8>
+ // CHECK: %[[a_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: %[[b_cast:.*]] = bitcast <4 x i8> %{{.*}} to i32
+ // CHECK: call i32 @llvm.nvvm.idp4a.s.s(i32 %[[a_cast]], i32 %[[b_cast]], i32 %{{.*}})
+ %3 = nvvm.dp4a %a, %b, %c {a_signed, b_signed}: vector<4xi8>, vector<4xi8>
+ llvm.return
+}
|
bddeeba
to
9bd957d
Compare
Mostly LGTM with a few minor changes requested |
9bd957d
to
2b259fc
Compare
✅ With the latest revision this PR passed the C/C++ code formatter. |
8174f84
to
fdaad0e
Compare
fdaad0e
to
afaddd5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor things left. I approve it, you can handle them and land it.
Can we also change the PR title so it matches with the new OP name?
afaddd5
to
698733e
Compare
dot.accumulate.4way
OP
e76be50
to
c8914d1
Compare
This change adds the `dp4a` Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. For more information, see PTX ISA: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a
c8914d1
to
70cf353
Compare
This change adds the
dot.accumulate.4way
Op to the NVVM dialect to perform a four-way byte dot product-accumulate operation. PTX Spec Reference: https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-dp4a