From a68ae0cbc1d923b69ef690f32b651ecf583ff27f Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sat, 5 Apr 2025 03:10:19 -0400 Subject: [PATCH 1/2] working dupv and dupvonly for fwd mode --- .../rustc_ast/src/expand/autodiff_attrs.rs | 40 ++++++++++++------- compiler/rustc_builtin_macros/src/autodiff.rs | 23 ++++++++--- compiler/rustc_codegen_llvm/src/builder.rs | 2 +- .../src/builder/autodiff.rs | 38 +++++++++++++++--- .../src/partitioning/autodiff.rs | 32 ++++++++++++++- 5 files changed, 107 insertions(+), 28 deletions(-) diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 13a7c5a180576..2f918faaf752b 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -50,8 +50,16 @@ pub enum DiffActivity { /// with it. Dual, /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. It expects the shadow argument to be `width` times larger than the original + /// input/output. + Dualv, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument /// with it. Drop the code which updates the original input/output for maximum performance. DualOnly, + /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument + /// with it. Drop the code which updates the original input/output for maximum performance. + /// It expects the shadow argument to be `width` times larger than the original input/output. + DualvOnly, /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. Duplicated, /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument. @@ -59,7 +67,15 @@ pub enum DiffActivity { DuplicatedOnly, /// All Integers must be Const, but these are used to mark the integer which represents the /// length of a slice/vec. This is used for safety checks on slices. - FakeActivitySize, + /// The integer (if given) specifies the size of the slice element in bytes. + FakeActivitySize(Option), +} + +impl DiffActivity { + pub fn is_dual_or_const(&self) -> bool { + use DiffActivity::*; + matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const) + } } /// We generate one of these structs for each `#[autodiff(...)]` attribute. #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] @@ -131,11 +147,7 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { match mode { DiffMode::Error => false, DiffMode::Source => false, - DiffMode::Forward => { - activity == DiffActivity::Dual - || activity == DiffActivity::DualOnly - || activity == DiffActivity::Const - } + DiffMode::Forward => activity.is_dual_or_const(), DiffMode::Reverse => { activity == DiffActivity::Const || activity == DiffActivity::Active @@ -153,10 +165,8 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool { pub fn valid_ty_for_activity(ty: &P, activity: DiffActivity) -> bool { use DiffActivity::*; // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it. - if matches!(activity, Const) { - return true; - } - if matches!(activity, Dual | DualOnly) { + // Dual variants also support all types. + if activity.is_dual_or_const() { return true; } // FIXME(ZuseZ4) We should make this more robust to also @@ -172,9 +182,7 @@ pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool { return match mode { DiffMode::Error => false, DiffMode::Source => false, - DiffMode::Forward => { - matches!(activity, Dual | DualOnly | Const) - } + DiffMode::Forward => activity.is_dual_or_const(), DiffMode::Reverse => { matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const) } @@ -189,10 +197,12 @@ impl Display for DiffActivity { DiffActivity::Active => write!(f, "Active"), DiffActivity::ActiveOnly => write!(f, "ActiveOnly"), DiffActivity::Dual => write!(f, "Dual"), + DiffActivity::Dualv => write!(f, "Dualv"), DiffActivity::DualOnly => write!(f, "DualOnly"), + DiffActivity::DualvOnly => write!(f, "DualvOnly"), DiffActivity::Duplicated => write!(f, "Duplicated"), DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"), - DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"), + DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s), } } } @@ -220,7 +230,9 @@ impl FromStr for DiffActivity { "ActiveOnly" => Ok(DiffActivity::ActiveOnly), "Const" => Ok(DiffActivity::Const), "Dual" => Ok(DiffActivity::Dual), + "Dualv" => Ok(DiffActivity::Dualv), "DualOnly" => Ok(DiffActivity::DualOnly), + "DualvOnly" => Ok(DiffActivity::DualvOnly), "Duplicated" => Ok(DiffActivity::Duplicated), "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly), _ => Err(()), diff --git a/compiler/rustc_builtin_macros/src/autodiff.rs b/compiler/rustc_builtin_macros/src/autodiff.rs index 351413dea493c..44ee78ebe6883 100644 --- a/compiler/rustc_builtin_macros/src/autodiff.rs +++ b/compiler/rustc_builtin_macros/src/autodiff.rs @@ -799,8 +799,19 @@ mod llvm_enzyme { d_inputs.push(shadow_arg.clone()); } } - DiffActivity::Dual | DiffActivity::DualOnly => { - for i in 0..x.width { + DiffActivity::Dual + | DiffActivity::DualOnly + | DiffActivity::Dualv + | DiffActivity::DualvOnly => { + // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause + // Enzyme to not expect N arguments, but one argument (which is instead larger). + let iterations = + if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) { + 1 + } else { + x.width + }; + for i in 0..iterations { let mut shadow_arg = arg.clone(); let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind { ident.name @@ -823,7 +834,7 @@ mod llvm_enzyme { DiffActivity::Const => { // Nothing to do here. } - DiffActivity::None | DiffActivity::FakeActivitySize => { + DiffActivity::None | DiffActivity::FakeActivitySize(_) => { panic!("Should not happen"); } } @@ -887,8 +898,8 @@ mod llvm_enzyme { } }; - if let DiffActivity::Dual = x.ret_activity { - let kind = if x.width == 1 { + if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) { + let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) { // Dual can only be used for f32/f64 ret. // In that case we return now a tuple with two floats. TyKind::Tup(thin_vec![ty.clone(), ty.clone()]) @@ -903,7 +914,7 @@ mod llvm_enzyme { let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None }); d_decl.output = FnRetTy::Ty(ty); } - if let DiffActivity::DualOnly = x.ret_activity { + if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) { // No need to change the return type, // we will just return the shadow in place of the primal return. // However, if we have a width > 1, then we don't return -> T, but -> [T; width] diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 35134e9f5a050..27f7f95f100e2 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -123,7 +123,7 @@ impl<'a, 'll, CX: Borrow>> GenericBuilder<'a, 'll, CX> { /// Empty string, to be used where LLVM expects an instruction name, indicating /// that the instruction is to be left unnamed (i.e. numbered, in textual IR). // FIXME(eddyb) pass `&CStr` directly to FFI once it's a thin pointer. -const UNNAMED: *const c_char = c"".as_ptr(); +pub(crate) const UNNAMED: *const c_char = c"".as_ptr(); impl<'ll, CX: Borrow>> BackendTypes for GenericBuilder<'_, 'll, CX> { type Value = as BackendTypes>::Value; diff --git a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs index 5e7ef27143b14..e7c071d05aa8f 100644 --- a/compiler/rustc_codegen_llvm/src/builder/autodiff.rs +++ b/compiler/rustc_codegen_llvm/src/builder/autodiff.rs @@ -10,7 +10,7 @@ use rustc_middle::bug; use tracing::{debug, trace}; use crate::back::write::llvm_err; -use crate::builder::SBuilder; +use crate::builder::{SBuilder, UNNAMED}; use crate::context::SimpleCx; use crate::declare::declare_simple_fn; use crate::errors::{AutoDiffWithoutEnable, LlvmError}; @@ -51,6 +51,7 @@ fn has_sret(fnc: &Value) -> bool { // using iterators and peek()? fn match_args_from_caller_to_enzyme<'ll>( cx: &SimpleCx<'ll>, + builder: &SBuilder<'ll, 'll>, width: u32, args: &mut Vec<&'ll llvm::Value>, inputs: &[DiffActivity], @@ -78,7 +79,9 @@ fn match_args_from_caller_to_enzyme<'ll>( let enzyme_const = cx.create_metadata("enzyme_const".to_string()).unwrap(); let enzyme_out = cx.create_metadata("enzyme_out".to_string()).unwrap(); let enzyme_dup = cx.create_metadata("enzyme_dup".to_string()).unwrap(); + let enzyme_dupv = cx.create_metadata("enzyme_dupv".to_string()).unwrap(); let enzyme_dupnoneed = cx.create_metadata("enzyme_dupnoneed".to_string()).unwrap(); + let enzyme_dupnoneedv = cx.create_metadata("enzyme_dupnoneedv".to_string()).unwrap(); while activity_pos < inputs.len() { let diff_activity = inputs[activity_pos as usize]; @@ -90,13 +93,34 @@ fn match_args_from_caller_to_enzyme<'ll>( DiffActivity::Active => (enzyme_out, false), DiffActivity::ActiveOnly => (enzyme_out, false), DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::Dualv => (enzyme_dupv, true), DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::DualvOnly => (enzyme_dupnoneedv, true), DiffActivity::Duplicated => (enzyme_dup, true), DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), - DiffActivity::FakeActivitySize => (enzyme_const, false), + DiffActivity::FakeActivitySize(_) => (enzyme_const, false), }; let outer_arg = outer_args[outer_pos]; args.push(cx.get_metadata_value(activity)); + if matches!(diff_activity, DiffActivity::Dualv) { + let next_outer_arg = outer_args[outer_pos + 1]; + let elem_bytes_size: u64 = match inputs[activity_pos + 1] { + DiffActivity::FakeActivitySize(Some(s)) => s.into(), + _ => bug!("incorrect Dualv handling recognized."), + }; + // stride: sizeof(T) * n_elems. + // n_elems is the next integer. + // Now we multiply `4 * next_outer_arg` to get the stride. + let mul = unsafe { + llvm::LLVMBuildMul( + builder.llbuilder, + cx.get_const_i64(elem_bytes_size), + next_outer_arg, + UNNAMED, + ) + }; + args.push(mul); + } args.push(outer_arg); if duplicated { // We know that duplicated args by construction have a following argument, @@ -114,7 +138,7 @@ fn match_args_from_caller_to_enzyme<'ll>( } else { let next_activity = inputs[activity_pos + 1]; // We analyze the MIR types and add this dummy activity if we visit a slice. - next_activity == DiffActivity::FakeActivitySize + matches!(next_activity, DiffActivity::FakeActivitySize(_)) } }; if slice { @@ -125,7 +149,10 @@ fn match_args_from_caller_to_enzyme<'ll>( // int2 >= int1, which means the shadow vector is large enough to store the gradient. assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer); - for i in 0..(width as usize) { + let iterations = + if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize }; + + for i in 0..iterations { let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)]; let next_outer_ty2 = cx.val_ty(next_outer_arg2); assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer); @@ -136,7 +163,7 @@ fn match_args_from_caller_to_enzyme<'ll>( } args.push(cx.get_metadata_value(enzyme_const)); args.push(next_outer_arg); - outer_pos += 2 + 2 * width as usize; + outer_pos += 2 + 2 * iterations; activity_pos += 2; } else { // A duplicated pointer will have the following two outer_fn arguments: @@ -360,6 +387,7 @@ fn generate_enzyme_call<'ll>( let outer_args: Vec<&llvm::Value> = get_params(outer_fn); match_args_from_caller_to_enzyme( &cx, + &builder, attrs.width, &mut args, &attrs.input_activity, diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index ebe0b258c1b6a..22d593b80b895 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -2,7 +2,7 @@ use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::bug; use rustc_middle::mir::mono::MonoItem; -use rustc_middle::ty::{self, Instance, Ty, TyCtxt}; +use rustc_middle::ty::{self, Instance, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use tracing::{debug, trace}; @@ -22,23 +22,51 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec for (i, ty) in sig.inputs().iter().enumerate() { if let Some(inner_ty) = ty.builtin_deref(true) { if inner_ty.is_slice() { + // Now we need to figure out the size of each slice element in memory to allow + // safety checks and usability improvements in the backend. + let sty = match inner_ty.builtin_index() { + Some(sty) => sty, + None => { + panic!("slice element type unknown"); + } + }; + let pci = PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: sty, + }; + + let layout = tcx.layout_of(pci); + let elem_size = match layout { + Ok(layout) => layout.size, + Err(_) => { + bug!("autodiff failed to compute slice element size"); + } + }; + let elem_size: u32 = elem_size.bytes() as u32; + // We know that the length will be passed as extra arg. if !da.is_empty() { // We are looking at a slice. The length of that slice will become an // extra integer on llvm level. Integers are always const. // However, if the slice get's duplicated, we want to know to later check the // size. So we mark the new size argument as FakeActivitySize. + // There is one FakeActivitySize per slice, so for convenience we store the + // slice element size in bytes in it. We will use the size in the backend. let activity = match da[i] { DiffActivity::DualOnly | DiffActivity::Dual + | DiffActivity::Dualv | DiffActivity::DuplicatedOnly - | DiffActivity::Duplicated => DiffActivity::FakeActivitySize, + | DiffActivity::Duplicated => { + DiffActivity::FakeActivitySize(Some(elem_size)) + } DiffActivity::Const => DiffActivity::Const, _ => bug!("unexpected activity for ptr/ref"), }; new_activities.push(activity); new_positions.push(i + 1); } + continue; } } From d7c0c328275fdcd5b6bba7f9ed7c97528103aca8 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Tue, 8 Apr 2025 20:59:17 -0400 Subject: [PATCH 2/2] passing test for dualv --- tests/codegen/autodiffv2.rs | 113 ++++++++++++++++++++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 tests/codegen/autodiffv2.rs diff --git a/tests/codegen/autodiffv2.rs b/tests/codegen/autodiffv2.rs new file mode 100644 index 0000000000000..a40d19d3be3a8 --- /dev/null +++ b/tests/codegen/autodiffv2.rs @@ -0,0 +1,113 @@ +//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme +// +// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many +// breakages. One benefit is that we match the IR generated by Enzyme only after running it +// through LLVM's O3 pipeline, which will remove most of the noise. +// However, our integration test could also be affected by changes in how rustc lowers MIR into +// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should +// reduce this test to only match the first lines and the ret instructions. +// +// The function tested here has 4 inputs and 5 outputs, so we could either call forward-mode +// autodiff 4 times, or reverse mode 5 times. Since a forward-mode call is usually faster than +// reverse mode, we prefer it here. This file also tests a new optimization (batch mode), which +// allows us to call forward-mode autodiff only once, and get all 5 outputs in a single call. +// +// We support 2 different batch modes. `d_square2` has the same interface as scalar forward-mode, +// but each shadow argument is `width` times larger (thus 16 and 20 elements here). +// `d_square3` instead takes `width` (4) shadow arguments, which are all the same size as the +// original function arguments. +// +// FIXME(autodiff): We currently can't test `d_square1` and `d_square3` in the same file, since they +// generate the same dummy functions which get merged by LLVM, breaking pieces of our pipeline which +// try to rewrite the dummy functions later. We should consider to change to pure declarations both +// in our frontend and in the llvm backend to avoid these issues. + +#![feature(autodiff)] + +use std::autodiff::autodiff; + +#[no_mangle] +//#[autodiff(d_square1, Forward, Dual, Dual)] +#[autodiff(d_square2, Forward, 4, Dualv, Dualv)] +#[autodiff(d_square3, Forward, 4, Dual, Dual)] +fn square(x: &[f32], y: &mut [f32]) { + assert!(x.len() >= 4); + assert!(y.len() >= 5); + y[0] = 4.3 * x[0] + 1.2 * x[1] + 3.4 * x[2] + 2.1 * x[3]; + y[1] = 2.3 * x[0] + 4.5 * x[1] + 1.7 * x[2] + 6.4 * x[3]; + y[2] = 1.1 * x[0] + 3.3 * x[1] + 2.5 * x[2] + 4.7 * x[3]; + y[3] = 5.2 * x[0] + 1.4 * x[1] + 2.6 * x[2] + 3.8 * x[3]; + y[4] = 1.0 * x[0] + 2.0 * x[1] + 3.0 * x[2] + 4.0 * x[3]; +} + +fn main() { + let x1 = std::hint::black_box(vec![0.0, 1.0, 2.0, 3.0]); + + let dx1 = std::hint::black_box(vec![1.0; 12]); + + let z1 = std::hint::black_box(vec![1.0, 0.0, 0.0, 0.0]); + let z2 = std::hint::black_box(vec![0.0, 1.0, 0.0, 0.0]); + let z3 = std::hint::black_box(vec![0.0, 0.0, 1.0, 0.0]); + let z4 = std::hint::black_box(vec![0.0, 0.0, 0.0, 1.0]); + + let z5 = std::hint::black_box(vec![ + 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0, + ]); + + let mut y1 = std::hint::black_box(vec![0.0; 5]); + let mut y2 = std::hint::black_box(vec![0.0; 5]); + let mut y3 = std::hint::black_box(vec![0.0; 5]); + let mut y4 = std::hint::black_box(vec![0.0; 5]); + + let mut y5 = std::hint::black_box(vec![0.0; 5]); + + let mut y6 = std::hint::black_box(vec![0.0; 5]); + + let mut dy1_1 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_2 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_3 = std::hint::black_box(vec![0.0; 5]); + let mut dy1_4 = std::hint::black_box(vec![0.0; 5]); + + let mut dy2 = std::hint::black_box(vec![0.0; 20]); + + let mut dy3_1 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_2 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_3 = std::hint::black_box(vec![0.0; 5]); + let mut dy3_4 = std::hint::black_box(vec![0.0; 5]); + + // scalar. + //d_square1(&x1, &z1, &mut y1, &mut dy1_1); + //d_square1(&x1, &z2, &mut y2, &mut dy1_2); + //d_square1(&x1, &z3, &mut y3, &mut dy1_3); + //d_square1(&x1, &z4, &mut y4, &mut dy1_4); + + // assert y1 == y2 == y3 == y4 + //for i in 0..5 { + // assert_eq!(y1[i], y2[i]); + // assert_eq!(y1[i], y3[i]); + // assert_eq!(y1[i], y4[i]); + //} + + // batch mode A) + d_square2(&x1, &z5, &mut y5, &mut dy2); + + // assert y1 == y2 == y3 == y4 == y5 + //for i in 0..5 { + // assert_eq!(y1[i], y5[i]); + //} + + // batch mode B) + d_square3(&x1, &z1, &z2, &z3, &z4, &mut y6, &mut dy3_1, &mut dy3_2, &mut dy3_3, &mut dy3_4); + for i in 0..5 { + assert_eq!(y5[i], y6[i]); + } + + for i in 0..5 { + assert_eq!(dy2[0..5][i], dy3_1[i]); + assert_eq!(dy2[5..10][i], dy3_2[i]); + assert_eq!(dy2[10..15][i], dy3_3[i]); + assert_eq!(dy2[15..20][i], dy3_4[i]); + } +}