Skip to content

Rollup of 4 pull requests #140350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove noinline attribute and add alwaysinline after AD pass
  • Loading branch information
Shourya742 committed Apr 25, 2025
commit 48d05aac9b9aa488702b5519fd16121336468593
30 changes: 29 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ use crate::back::write::{
use crate::errors::{
DynamicLinkingWithLTO, LlvmError, LtoBitcodeFromRlib, LtoDisallowed, LtoDylib, LtoProcMacro,
};
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, build_string};
use crate::{LlvmCodegenBackend, ModuleLlvm};
use crate::{LlvmCodegenBackend, ModuleLlvm, SimpleCx, attributes};

/// We keep track of the computed LTO cache keys from the previous
/// session to determine which CGUs we can reuse.
Expand Down Expand Up @@ -666,6 +667,33 @@ pub(crate) fn run_pass_manager(
}

if cfg!(llvm_enzyme) && enable_ad && !thin {
let cx =
SimpleCx::new(module.module_llvm.llmod(), &module.module_llvm.llcx, cgcx.pointer_size);

for function in cx.get_functions() {
let enzyme_marker = CString::new("enzyme_marker").unwrap();
let marker_ptr = enzyme_marker.as_ptr();

if attributes::has_string_attr(function, marker_ptr) {
// Sanity check: Ensure 'noinline' is present before replacing it.
assert!(
!attributes::has_attr(function, Function, llvm::AttributeKind::NoInline),
"Expected __enzyme function to have 'noinline' before adding 'alwaysinline'"
);

attributes::remove_from_llfn(function, Function, llvm::AttributeKind::NoInline);
attributes::remove_string_attr_from_llfn(function, marker_ptr);

assert!(
!attributes::has_string_attr(function, marker_ptr),
"Expected function to not have 'enzyme_marker'"
);

let always_inline = llvm::AttributeKind::AlwaysInline.create_attr(cx.llcx);
attributes::apply_to_llfn(function, Function, &[always_inline]);
}
}

let opt_stage = llvm::OptStage::FatLTO;
let stage = write::AutodiffStage::PostAD;
if !config.autodiff.contains(&config::AutoDiff::NoPostopt) {
Expand Down
10 changes: 10 additions & 0 deletions compiler/rustc_codegen_llvm/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,16 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
llvm::LLVMMDStringInContext2(self.llcx(), name.as_ptr() as *const c_char, name.len())
})
}

pub(crate) fn get_functions(&self) -> Vec<&'ll Value> {
let mut functions = vec![];
let mut func = unsafe { llvm::LLVMGetFirstFunction(self.llmod()) };
while let Some(f) = func {
functions.push(f);
func = unsafe { llvm::LLVMGetNextFunction(f) }
}
functions
}
}

impl<'ll, 'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'ll, 'tcx> {
Expand Down
26 changes: 26 additions & 0 deletions compiler/rustc_codegen_llvm/src/llvm/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,32 @@ pub(crate) fn AddFunctionAttributes<'ll>(
}
}

pub(crate) fn HasAttributeAtIndex<'ll>(
llfn: &'ll Value,
idx: AttributePlace,
kind: AttributeKind,
) -> bool {
unsafe { LLVMRustHasAttributeAtIndex(llfn, idx.as_uint(), kind) }
}

pub(crate) fn HasStringAttribute<'ll>(llfn: &'ll Value, name: *const i8) -> bool {
unsafe { LLVMRustHasFnAttribute(llfn, name) }
}

pub(crate) fn RemoveStringAttrFromFn<'ll>(llfn: &'ll Value, name: *const i8) {
unsafe { LLVMRustRemoveFnAttribute(llfn, name) }
}

pub(crate) fn RemoveRustEnumAttributeAtIndex(
llfn: &Value,
place: AttributePlace,
kind: AttributeKind,
) {
unsafe {
LLVMRustRemoveEnumAttributeAtIndex(llfn, place.as_uint(), kind);
}
}

pub(crate) fn AddCallSiteAttributes<'ll>(
callsite: &'ll Value,
idx: AttributePlace,
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,10 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
(**self).borrow().llcx
}

pub(crate) fn llmod(&self) -> &'ll llvm::Module {
(**self).borrow().llmod
}

pub(crate) fn isize_ty(&self) -> &'ll Type {
(**self).borrow().isize_ty
}
Expand Down
23 changes: 23 additions & 0 deletions tests/codegen/autodiff/inline.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat -Zautodiff=NoPostopt
//@ no-prefer-dynamic
//@ needs-enzyme

#![feature(autodiff)]

use std::autodiff::autodiff;

#[autodiff(d_square, Reverse, Duplicated, Active)]
fn square(x: &f64) -> f64 {
x * x
}

// CHECK: ; inline::d_square
// CHECK-NEXT: ; Function Attrs: alwaysinline
// CHECK-NOT: noinline
// CHECK-NEXT: define internal fastcc void @_ZN6inline8d_square17h021c74e92c259cdeE
fn main() {
let x = std::hint::black_box(3.0);
let mut dx1 = std::hint::black_box(1.0);
let _ = d_square(&x, &mut dx1, 1.0);
assert_eq!(dx1, 6.0);
}
Loading