[TableGen] Validate the shift amount for !srl, !shl, and !sra operators. #132492

Open
wants to merge 1 commit into
base: main
topperc (Collaborator) commented Mar 21, 2025

The C shift operators have undefined behavior for out-of-bounds shift amounts, so we should check for this.
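
For context on the undefined behavior mentioned here: in C and C++, shifting a 64-bit integer by a negative amount, or by 64 or more, is undefined. A minimal standalone sketch of the kind of guard being proposed follows. The name `checkedShl` is hypothetical and used only for illustration; it is not part of TableGen, which reports the error through `PrintFatalError` in the diff below.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical helper (an assumption for illustration, not TableGen code).
// Returns std::nullopt when the shift amount is outside [0, 63], the only
// range in which shifting a 64-bit value is defined in C and C++.
std::optional<int64_t> checkedShl(int64_t LHS, int64_t RHS) {
  if (RHS < 0 || RHS >= 64)
    return std::nullopt;
  // Shift in the unsigned domain so that bits moving into or past the sign
  // bit do not raise signed-overflow concerns.
  return static_cast<int64_t>(static_cast<uint64_t>(LHS) << RHS);
}
```

A caller would treat `std::nullopt` as the error path, which is the role the fatal diagnostic plays in the actual patch.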

topperc requested a review from jurahul on Mar 21, 2025, 23:38
llvmbot (Member) commented Mar 21, 2025

@llvm/pr-subscribers-tablegen

Author: Craig Topper (topperc)

Changes

The C shift operators have undefined behavior for out-of-bounds shift amounts, so we should check for this.


Full diff: https://github.com/llvm/llvm-project/pull/132492.diff

1 Files Affected:

  • (modified) llvm/lib/TableGen/Record.cpp (+8-2)
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index c5b9b670b6f42..655c4078697f3 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1537,7 +1537,13 @@ const Init *BinOpInit::Fold(const Record *CurRec) const {
     if (LHSi && RHSi) {
       int64_t LHSv = LHSi->getValue(), RHSv = RHSi->getValue();
       int64_t Result;
-      switch (getOpcode()) {
+
+      unsigned Opc = getOpcode();
+      if ((Opc == SHL || Opc == SRA || Opc == SRL) && (RHSv < 0 || RHSv >= 64))
+        PrintFatalError(CurRec->getLoc(),
+                        "Illegal operation: out of bounds shift");
+
+      switch (Opc) {
       default: llvm_unreachable("Bad opcode!");
       case ADD: Result = LHSv + RHSv; break;
       case SUB: Result = LHSv - RHSv; break;
@@ -1556,7 +1562,7 @@ const Init *BinOpInit::Fold(const Record *CurRec) const {
       case OR:  Result = LHSv | RHSv; break;
       case XOR: Result = LHSv ^ RHSv; break;
       case SHL: Result = (uint64_t)LHSv << (uint64_t)RHSv; break;
-      case SRA: Result = LHSv >> RHSv; break;
+      case SRA: Result = LHSv >> (uint64_t)RHSv; break;
       case SRL: Result = (uint64_t)LHSv >> (uint64_t)RHSv; break;
       }
       return IntInit::get(getRecordKeeper(), Result);

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff b858ba0f6597c66e5c276ca9e2564ca27e7e28e7 4277124e8acb8efed73dee9bc32a1744c807dcdf --extensions cpp -- llvm/lib/TableGen/Record.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/TableGen/Record.cpp b/llvm/lib/TableGen/Record.cpp
index 655c407869..78f9f2368d 100644
--- a/llvm/lib/TableGen/Record.cpp
+++ b/llvm/lib/TableGen/Record.cpp
@@ -1562,7 +1562,9 @@ const Init *BinOpInit::Fold(const Record *CurRec) const {
       case OR:  Result = LHSv | RHSv; break;
       case XOR: Result = LHSv ^ RHSv; break;
       case SHL: Result = (uint64_t)LHSv << (uint64_t)RHSv; break;
-      case SRA: Result = LHSv >> (uint64_t)RHSv; break;
+      case SRA:
+        Result = LHSv >> (uint64_t)RHSv;
+        break;
       case SRL: Result = (uint64_t)LHSv >> (uint64_t)RHSv; break;
       }
       return IntInit::get(getRecordKeeper(), Result);

jurahul (Contributor) commented Mar 22, 2025

Can you also add a unit test?


      unsigned Opc = getOpcode();
      if ((Opc == SHL || Opc == SRA || Opc == SRL) && (RHSv < 0 || RHSv >= 64))
        PrintFatalError(CurRec->getLoc(),
Collaborator

While the duplication is mildly ugly, this would seem more natural if sunk into the case blocks below.
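
One way to read this suggestion is sketched below as a standalone program, not as a patch to `Record.cpp`. The enum, `foldBinOp`, `checkShiftAmount`, and `fatal` are hypothetical stand-ins (the real code uses `BinOpInit::Fold` and `PrintFatalError`); only the placement of the check inside each shift case is the point.

```cpp
#include <cstdint>
#include <stdexcept>

// Hypothetical subset of the opcodes, for illustration only.
enum BinaryOp { ADD, SHL, SRA, SRL };

// Stand-in for PrintFatalError: throws instead of terminating the tool.
[[noreturn]] static void fatal(const char *Msg) {
  throw std::runtime_error(Msg);
}

// Shared range check, called from each shift case.
static void checkShiftAmount(int64_t RHSv) {
  if (RHSv < 0 || RHSv >= 64)
    fatal("Illegal operation: out of bounds shift");
}

int64_t foldBinOp(BinaryOp Opc, int64_t LHSv, int64_t RHSv) {
  int64_t Result = 0;
  switch (Opc) {
  case ADD:
    // Add in the unsigned domain to sidestep signed overflow in this sketch.
    Result = (int64_t)((uint64_t)LHSv + (uint64_t)RHSv);
    break;
  case SHL:
    checkShiftAmount(RHSv);
    Result = (int64_t)((uint64_t)LHSv << RHSv);
    break;
  case SRA:
    checkShiftAmount(RHSv);
    Result = LHSv >> RHSv; // arithmetic shift on mainstream compilers
    break;
  case SRL:
    checkShiftAmount(RHSv);
    Result = (int64_t)((uint64_t)LHSv >> RHSv);
    break;
  }
  return Result;
}
```

Factoring the comparison into a small helper keeps the duplication to one call per case, at the cost of repeating that call three times.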

@@ -1556,7 +1562,7 @@ const Init *BinOpInit::Fold(const Record *CurRec) const {
       case OR:  Result = LHSv | RHSv; break;
       case XOR: Result = LHSv ^ RHSv; break;
       case SHL: Result = (uint64_t)LHSv << (uint64_t)RHSv; break;
-      case SRA: Result = LHSv >> RHSv; break;
+      case SRA: Result = LHSv >> (uint64_t)RHSv; break;
Collaborator

This change should be a no-op given the newly added bounds check, right?
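
The reasoning behind this question can be checked with a small standalone sketch. In C++ the result type of a shift is the promoted type of the left operand, so casting the right operand does not change the type of the result, and for shift amounts in [0, 63] the cast does not change the value either. The function name `castIsNoOp` is hypothetical.

```cpp
#include <cstdint>
#include <type_traits>

// The right operand's type does not influence the result type of a shift.
static_assert(
    std::is_same<decltype(int64_t{} >> uint64_t{}), int64_t>::value,
    "shift result has the type of the (promoted) left operand");

// Returns true if both spellings of the SRA case agree for every shift
// amount that passes the bounds check (0 <= RHSv < 64).
bool castIsNoOp(int64_t LHSv) {
  for (int64_t RHSv = 0; RHSv < 64; ++RHSv) {
    int64_t Plain = LHSv >> RHSv;
    int64_t Casted = LHSv >> (uint64_t)RHSv;
    if (Plain != Casted)
      return false;
  }
  return true;
}
```

Outside that range the cast would not help either: a negative amount converted to `uint64_t` becomes a very large value, which is still an out-of-range shift.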
