Skip to content

[HLSL][RootSignature] Add mandatory parameters for RootConstants #138002

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 9, 2025

Conversation

inbelic
Copy link
Contributor

@inbelic inbelic commented Apr 30, 2025

  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576

@llvmbot llvmbot added clang Clang issues not falling into any other category clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Apr 30, 2025
@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-clang

Author: Finn Plummer (inbelic)

Changes
  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576


Full diff: https://github.com/llvm/llvm-project/pull/138002.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+8-2)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+65-3)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+12-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+4-1)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index efa735ea03d94..0f05b05ed4df6 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -77,8 +77,14 @@ class RootSignatureParser {
   parseDescriptorTableClause();
 
   /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order and only exactly once. `ParsedClauseParams` denotes the current
-  /// state of parsed params
+  /// order and only exactly once. The following methods define a
+  /// `Parsed.*Params` struct to denote the current state of parsed params
+  struct ParsedConstantParams {
+    std::optional<llvm::hlsl::rootsig::Register> Reg;
+    std::optional<uint32_t> Num32BitConstants;
+  };
+  std::optional<ParsedConstantParams> parseRootConstantParams();
+
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> NumDescriptors;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 48d3e38b0519d..2ce8e6e5cca98 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
 
   RootConstants Constants;
 
+  auto Params = parseRootConstantParams();
+  if (!Params.has_value())
+    return std::nullopt;
+
+  // Check mandatory parameters were provided
+  if (!Params->Num32BitConstants.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::kw_num32BitConstants;
+    return std::nullopt;
+  }
+
+  Constants.Num32BitConstants = Params->Num32BitConstants.value();
+
+  if (!Params->Reg.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::bReg;
+    return std::nullopt;
+  }
+
+  Constants.Reg = Params->Reg.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_RootConstants))
@@ -187,14 +208,55 @@ RootSignatureParser::parseDescriptorTableClause() {
   return Clause;
 }
 
+// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+// order and only exactly once. The following methods will parse through as
+// many arguments as possible reporting an error if a duplicate is seen.
+std::optional<RootSignatureParser::ParsedConstantParams>
+RootSignatureParser::parseRootConstantParams() {
+  assert(CurToken.TokKind == TokenKind::pu_l_paren &&
+         "Expects to only be invoked starting at given token");
+
+  ParsedConstantParams Params;
+  do {
+    // `num32BitConstants` `=` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      if (Params.Num32BitConstants.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Num32BitConstants = parseUIntParam();
+      if (!Num32BitConstants.has_value())
+        return std::nullopt;
+      Params.Num32BitConstants = Num32BitConstants;
+    }
+
+    // `b` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      if (Params.Reg.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+      auto Reg = parseRegister();
+      if (!Reg.has_value())
+        return std::nullopt;
+      Params.Reg = Reg;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+  return Params;
+}
+
 std::optional<RootSignatureParser::ParsedClauseParams>
 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  // Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  // order and only exactly once. Parse through as many arguments as possible
-  // reporting an error if a duplicate is seen.
   ParsedClauseParams Params;
   do {
     // ( `b` | `t` | `u` | `s`) POS_INT
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 0a7d8ac86cc5f..336868b579866 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -254,7 +254,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
 
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
   const llvm::StringLiteral Source = R"cc(
-    RootConstants()
+    RootConstants(num32BitConstants = 1, b0),
+    RootConstants(b42, num32BitConstants = 4294967295)
   )cc";
 
   TrivialModuleLoader ModLoader;
@@ -270,10 +271,19 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
 
   ASSERT_FALSE(Parser.parse());
 
-  ASSERT_EQ(Elements.size(), 1u);
+  ASSERT_EQ(Elements.size(), 2u);
 
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 0u);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 42u);
 
   ASSERT_TRUE(Consumer->isSatisfied());
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 05735fa75b318..a3f98a9f1944f 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -55,7 +55,10 @@ struct Register {
 };
 
 // Models the parameter values of root constants
-struct RootConstants {};
+struct RootConstants {
+  uint32_t Num32BitConstants;
+  Register Reg;
+};
 
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {

@llvmbot
Copy link
Member

llvmbot commented Apr 30, 2025

@llvm/pr-subscribers-hlsl

Author: Finn Plummer (inbelic)

Changes
  • defines the parseRootConstantParams function and adds handling for the mandatory arguments of num32BitConstants and bReg

  • adds corresponding unit tests

Part two of implementing #126576


Full diff: https://github.com/llvm/llvm-project/pull/138002.diff

4 Files Affected:

  • (modified) clang/include/clang/Parse/ParseHLSLRootSignature.h (+8-2)
  • (modified) clang/lib/Parse/ParseHLSLRootSignature.cpp (+65-3)
  • (modified) clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp (+12-2)
  • (modified) llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h (+4-1)
diff --git a/clang/include/clang/Parse/ParseHLSLRootSignature.h b/clang/include/clang/Parse/ParseHLSLRootSignature.h
index efa735ea03d94..0f05b05ed4df6 100644
--- a/clang/include/clang/Parse/ParseHLSLRootSignature.h
+++ b/clang/include/clang/Parse/ParseHLSLRootSignature.h
@@ -77,8 +77,14 @@ class RootSignatureParser {
   parseDescriptorTableClause();
 
   /// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  /// order and only exactly once. `ParsedClauseParams` denotes the current
-  /// state of parsed params
+  /// order and only exactly once. The following methods define a
+  /// `Parsed.*Params` struct to denote the current state of parsed params
+  struct ParsedConstantParams {
+    std::optional<llvm::hlsl::rootsig::Register> Reg;
+    std::optional<uint32_t> Num32BitConstants;
+  };
+  std::optional<ParsedConstantParams> parseRootConstantParams();
+
   struct ParsedClauseParams {
     std::optional<llvm::hlsl::rootsig::Register> Reg;
     std::optional<uint32_t> NumDescriptors;
diff --git a/clang/lib/Parse/ParseHLSLRootSignature.cpp b/clang/lib/Parse/ParseHLSLRootSignature.cpp
index 48d3e38b0519d..2ce8e6e5cca98 100644
--- a/clang/lib/Parse/ParseHLSLRootSignature.cpp
+++ b/clang/lib/Parse/ParseHLSLRootSignature.cpp
@@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {
 
   RootConstants Constants;
 
+  auto Params = parseRootConstantParams();
+  if (!Params.has_value())
+    return std::nullopt;
+
+  // Check mandatory parameters were provided
+  if (!Params->Num32BitConstants.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::kw_num32BitConstants;
+    return std::nullopt;
+  }
+
+  Constants.Num32BitConstants = Params->Num32BitConstants.value();
+
+  if (!Params->Reg.has_value()) {
+    getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_missing_param)
+        << TokenKind::bReg;
+    return std::nullopt;
+  }
+
+  Constants.Reg = Params->Reg.value();
+
   if (consumeExpectedToken(TokenKind::pu_r_paren,
                            diag::err_hlsl_unexpected_end_of_params,
                            /*param of=*/TokenKind::kw_RootConstants))
@@ -187,14 +208,55 @@ RootSignatureParser::parseDescriptorTableClause() {
   return Clause;
 }
 
+// Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
+// order and only exactly once. The following methods will parse through as
+// many arguments as possible reporting an error if a duplicate is seen.
+std::optional<RootSignatureParser::ParsedConstantParams>
+RootSignatureParser::parseRootConstantParams() {
+  assert(CurToken.TokKind == TokenKind::pu_l_paren &&
+         "Expects to only be invoked starting at given token");
+
+  ParsedConstantParams Params;
+  do {
+    // `num32BitConstants` `=` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::kw_num32BitConstants)) {
+      if (Params.Num32BitConstants.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+
+      if (consumeExpectedToken(TokenKind::pu_equal))
+        return std::nullopt;
+
+      auto Num32BitConstants = parseUIntParam();
+      if (!Num32BitConstants.has_value())
+        return std::nullopt;
+      Params.Num32BitConstants = Num32BitConstants;
+    }
+
+    // `b` POS_INT
+    if (tryConsumeExpectedToken(TokenKind::bReg)) {
+      if (Params.Reg.has_value()) {
+        getDiags().Report(CurToken.TokLoc, diag::err_hlsl_rootsig_repeat_param)
+            << CurToken.TokKind;
+        return std::nullopt;
+      }
+      auto Reg = parseRegister();
+      if (!Reg.has_value())
+        return std::nullopt;
+      Params.Reg = Reg;
+    }
+  } while (tryConsumeExpectedToken(TokenKind::pu_comma));
+
+  return Params;
+}
+
 std::optional<RootSignatureParser::ParsedClauseParams>
 RootSignatureParser::parseDescriptorTableClauseParams(TokenKind RegType) {
   assert(CurToken.TokKind == TokenKind::pu_l_paren &&
          "Expects to only be invoked starting at given token");
 
-  // Parameter arguments (eg. `bReg`, `space`, ...) can be specified in any
-  // order and only exactly once. Parse through as many arguments as possible
-  // reporting an error if a duplicate is seen.
   ParsedClauseParams Params;
   do {
     // ( `b` | `t` | `u` | `s`) POS_INT
diff --git a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
index 0a7d8ac86cc5f..336868b579866 100644
--- a/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
+++ b/clang/unittests/Parse/ParseHLSLRootSignatureTest.cpp
@@ -254,7 +254,8 @@ TEST_F(ParseHLSLRootSignatureTest, ValidSamplerFlagsTest) {
 
 TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
   const llvm::StringLiteral Source = R"cc(
-    RootConstants()
+    RootConstants(num32BitConstants = 1, b0),
+    RootConstants(b42, num32BitConstants = 4294967295)
   )cc";
 
   TrivialModuleLoader ModLoader;
@@ -270,10 +271,19 @@ TEST_F(ParseHLSLRootSignatureTest, ValidParseRootConsantsTest) {
 
   ASSERT_FALSE(Parser.parse());
 
-  ASSERT_EQ(Elements.size(), 1u);
+  ASSERT_EQ(Elements.size(), 2u);
 
   RootElement Elem = Elements[0];
   ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 1u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 0u);
+
+  Elem = Elements[1];
+  ASSERT_TRUE(std::holds_alternative<RootConstants>(Elem));
+  ASSERT_EQ(std::get<RootConstants>(Elem).Num32BitConstants, 4294967295u);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.ViewType, RegisterType::BReg);
+  ASSERT_EQ(std::get<RootConstants>(Elem).Reg.Number, 42u);
 
   ASSERT_TRUE(Consumer->isSatisfied());
 }
diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
index 05735fa75b318..a3f98a9f1944f 100644
--- a/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
+++ b/llvm/include/llvm/Frontend/HLSL/HLSLRootSignature.h
@@ -55,7 +55,10 @@ struct Register {
 };
 
 // Models the parameter values of root constants
-struct RootConstants {};
+struct RootConstants {
+  uint32_t Num32BitConstants;
+  Register Reg;
+};
 
 // Models the end of a descriptor table and stores its visibility
 struct DescriptorTable {

@inbelic inbelic force-pushed the inbelic/rs-mand-root-const branch from f801180 to 15857bf Compare April 30, 2025 18:10
Copy link
Contributor

@joaosaffran joaosaffran left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add tests to verify error scenarios as well?

@@ -57,6 +57,27 @@ std::optional<RootConstants> RootSignatureParser::parseRootConstants() {

RootConstants Constants;

auto Params = parseRootConstantParams();
if (!Params.has_value())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
if (!Params.has_value())
if (Params)

return std::nullopt;
}

Constants.Num32BitConstants = Params->Num32BitConstants.value();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: You can dereference an optional to get the value

Suggested change
Constants.Num32BitConstants = Params->Num32BitConstants.value();
Constants.Num32BitConstants = *Params->Num32BitConstants;

inbelic and others added 4 commits May 9, 2025 16:26
- defines the `parseRootConstantParams` function and adds handling for
the mandatory arguments of `num32BitConstants` and `bReg`

- adds corresponding unit tests

Part two of implementing
Co-authored-by: Ashley Coleman <[email protected]>
@inbelic inbelic changed the base branch from users/inbelic/pr-137999 to main May 9, 2025 16:26
@inbelic inbelic force-pushed the inbelic/rs-mand-root-const branch from 200be1f to ccb45bd Compare May 9, 2025 16:28
@inbelic inbelic merged commit 5494349 into llvm:main May 9, 2025
7 of 11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants