autodiff on generic functions fails #140032

Open
ZuseZ4 opened this issue Apr 19, 2025 · 3 comments · May be fixed by #140049
Labels
C-bug Category: This is a bug. E-medium Call for participation: Medium difficulty. Experience needed to fix: Intermediate. F-autodiff `#![feature(autodiff)]`

Comments

@ZuseZ4
Member

ZuseZ4 commented Apr 19, 2025

I tried this code:

#![feature(autodiff)]

use std::autodiff::autodiff;

    #[autodiff(d_square, Reverse, Duplicated, Active)]
    fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
        *x * *x
    }

    fn main() {
        let xf32: f32 = 3.0;
        let xf64: f64 = 3.0;
        let outputf32 = square::<f32>(&xf32);
        let outputf64 = square::<f64>(&xf64);
        assert_eq!(9.0, outputf32);
        assert_eq!(9.0, outputf64);

        let mut df_dxf32: f32 = 0.0;
        let mut df_dxf64: f64 = 0.0;
        let output_f32 = d_square::<f32>(&xf32, &mut df_dxf32, 1.0);
        let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
        assert_eq!(outputf32, output_f32);
        assert_eq!(outputf64, output_f64);
        assert_eq!(6.0, df_dxf32);
        assert_eq!(6.0, df_dxf64);
    }

I expected to see this happen: the code compiles and all assertions pass.
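For reference, the expected numbers follow from d(x * x)/dx = 2x: at x = 3 with seed dret = 1.0 the gradient is 6.0, and the primal result is 9.0. Below is a hand-written, stable-Rust sketch of what d_square should compute for this signature. It assumes the usual Enzyme reverse-mode conventions: the Duplicated shadow dx accumulates the gradient scaled by dret, and the Active return adds dret as a trailing argument while the primal result is still returned. The name d_square_by_hand is hypothetical and is not something the macro generates.

```rust
use std::ops::{Add, AddAssign, Mul};

// Hypothetical hand-written stand-in for the d_square that
// #[autodiff(d_square, Reverse, Duplicated, Active)] is expected to produce.
// Assumption: the shadow `dx` accumulates (+=) the gradient times the seed `dret`,
// and the primal value x * x is returned unchanged.
fn d_square_by_hand<T>(x: &T, dx: &mut T, dret: T) -> T
where
    T: Mul<Output = T> + Add<Output = T> + AddAssign + Copy,
{
    // d(x * x)/dx = x + x = 2x, scaled by the incoming adjoint.
    *dx += (*x + *x) * dret;
    *x * *x
}

fn main() {
    let x32: f32 = 3.0;
    let mut dx32: f32 = 0.0;
    assert_eq!(d_square_by_hand(&x32, &mut dx32, 1.0), 9.0);
    assert_eq!(dx32, 6.0);

    let x64: f64 = 3.0;
    let mut dx64: f64 = 0.0;
    assert_eq!(d_square_by_hand(&x64, &mut dx64, 1.0), 9.0);
    assert_eq!(dx64, 6.0);
}
```

The hand-written version needs Add and AddAssign bounds that square itself does not have. The generated d_square should not need them, presumably because Enzyme differentiates the monomorphized LLVM IR rather than working from the Rust trait bounds, which is consistent with the later observation in this thread that copying square's own bounds is enough.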

Instead, this happened:

error[E0412]: cannot find type `T` in this scope
 --> src/main.rs:6:56
  |
6 |     fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
  |                                                        ^ not found in this scope
  |
help: you might be missing a type parameter
  |
1 | <T>#![feature(autodiff)]
  | +++

Meta

rustc --version --verbose:

Built from source.
Backtrace

<backtrace>

Solution: TBA

cc @haenoe

@ZuseZ4 ZuseZ4 added C-bug Category: This is a bug. E-medium Call for participation: Medium difficulty. Experience needed to fix: Intermediate. F-autodiff `#![feature(autodiff)]` labels Apr 19, 2025
@rustbot rustbot added the needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. label Apr 19, 2025
@ZuseZ4 ZuseZ4 removed the needs-triage This issue may need triage. Remove it if it has been sufficiently triaged. label Apr 19, 2025
@ZuseZ4
Member Author

ZuseZ4 commented Apr 19, 2025

Unfortunately, I can't find the old implementation at the moment, but it was proc-macro based, which can't be used in rustc. Instead, we should teach our rustc_builtin autodiff macro to handle generics. Right now, if you run cargo +enzyme expand on the code above, you get this output:

#[rustc_autodiff]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box(square(x))
}

The source function square is fine, but d_square has to change: we need to copy the generic parameters and their bounds. We might also want to spell out the generic arguments in the body (the ::<T> turbofish), but I'm not sure that's ever needed, so we could skip it at first.

#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box::<T>(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box::<T>(square(x))
}
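Whether the ::<T> in the body is ever needed can be checked in plain stable Rust, without autodiff. This sketch assumes the generated body keeps the statement form black_box(primal(args)); seen in the expansion; the helper one and the placeholder_* functions are hypothetical. It suggests the turbofish only matters when a generic parameter of the primal cannot be inferred from its arguments: T in square is inferred from x: &T, whereas a T that appears only in the return type is left unconstrained by a discarded statement.

```rust
use std::hint::black_box;
use std::ops::Mul;

fn square<T: Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}

// Hypothetical primal whose generic parameter appears only in the return type.
fn one<T: From<u8>>() -> T {
    T::from(1u8)
}

// Same statement shape as the expanded d_square body, generics copied from `square`.
fn placeholder_square<T: Mul<Output = T> + Copy>(x: &T) -> T {
    // T is inferred from `x: &T`, so no turbofish is needed here.
    black_box(square(x));
    black_box(square(x))
}

// Same shape for `one`, generics copied from `one`.
fn placeholder_one<T: From<u8>>() -> T {
    // `black_box(one());` would not compile ("type annotations needed"):
    // the discarded result leaves T unconstrained, so the turbofish is required.
    black_box(one::<T>());
    black_box(one::<T>())
}

fn main() {
    assert_eq!(placeholder_square(&3.0f64), 9.0);
    let v: f32 = placeholder_one();
    assert_eq!(v, 1.0);
}
```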

@haenoe
Contributor

haenoe commented Apr 19, 2025

@rustbot claim

@ZuseZ4
Member Author

ZuseZ4 commented Apr 19, 2025

The two-fold design, with the autodiff macro in the frontend and the rustc_autodiff attribute in the backend, means that we can also sidestep the macro for quick experiments. I just verified that copying the generic bounds is really all that's needed, so it should be an easy fix! The code below runs: I took the cargo expand output, allowed some attributes via feature gates, simplified the inline asm (which I guess has always been buggy), and then manually copied the generic bounds:

#![feature(rustc_attrs)]
#![feature(panic_internals)]
#![feature(prelude_import)]
#![feature(autodiff)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use std::autodiff::autodiff;
use std::arch::asm;

#[rustc_autodiff]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(nomem));
    };
    ::core::hint::black_box(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box(square(x))
}
fn main() {
    let xf32: f32 = 3.0;
    let xf64: f64 = 3.0;
    let outputf32 = square::<f32>(&xf32);
    let outputf64 = square::<f64>(&xf64);
    match (&9.0, &outputf32) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
    match (&9.0, &outputf64) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
    let mut df_dxf32: f32 = 0.0;
    let mut df_dxf64: f64 = 0.0;
    let output_f32 = d_square::<f32>(&xf32, &mut df_dxf32, 1.0);
    let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
    match (&outputf32, &output_f32) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
    match (&outputf64, &output_f64) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
    match (&6.0, &df_dxf32) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
    match (&6.0, &df_dxf64) {
        (left_val, right_val) => {
            if !(*left_val == *right_val) {
                let kind = ::core::panicking::AssertKind::Eq;
                ::core::panicking::assert_failed(
                    kind,
                    &*left_val,
                    &*right_val,
                    ::core::option::Option::None,
                );
            }
        }
    };
}
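Since copying the generic bounds is all the fix needs, the signature rewrite can be illustrated with a toy model. This is a hypothetical, string-based sketch: Activity, PrimalSig and reverse_sig are invented names, the real builtin macro operates on rustc's AST, and only the Duplicated and Active cases used in this issue are modelled. The shadow-argument naming (dx_0) and the trailing dret parameter follow the cargo expand output earlier in the thread.

```rust
// Hypothetical toy model of the d_square signature rewrite; the real macro
// manipulates rustc's AST, not strings.
#[derive(Clone, Copy, Debug)]
enum Activity {
    Duplicated,
    Active,
}

struct PrimalSig<'a> {
    generics: &'a str,                // e.g. "<T: std::ops::Mul<Output = T> + Copy>"
    params: &'a [(&'a str, &'a str)], // (parameter name, type)
    ret: &'a str,
}

fn reverse_sig(primal: &PrimalSig<'_>, d_name: &str, arg_acts: &[Activity], ret_act: Activity) -> String {
    let mut params = Vec::new();
    for (i, (&(name, ty), act)) in primal.params.iter().zip(arg_acts).enumerate() {
        params.push(format!("{name}: {ty}"));
        if let Activity::Duplicated = act {
            // Shadow argument: a mutable reference to the same pointee type.
            let inner = ty
                .strip_prefix("&mut ")
                .or_else(|| ty.strip_prefix('&'))
                .unwrap_or(ty)
                .trim();
            params.push(format!("d{name}_{i}: &mut {inner}"));
        }
    }
    if let Activity::Active = ret_act {
        // Active return: the caller passes the seed for the output adjoint.
        params.push(format!("dret: {}", primal.ret));
    }
    // The fix discussed in this issue: reuse the primal's generic parameter list verbatim.
    format!("fn {d_name}{}({}) -> {}", primal.generics, params.join(", "), primal.ret)
}

fn main() {
    let square = PrimalSig {
        generics: "<T: std::ops::Mul<Output = T> + Copy>",
        params: &[("x", "&T")],
        ret: "T",
    };
    let sig = reverse_sig(&square, "d_square", &[Activity::Duplicated], Activity::Active);
    // Same signature as the hand-fixed d_square proposed in the thread.
    assert_eq!(
        sig,
        "fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T"
    );
    println!("{sig}");
}
```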
