autodiff on generic functions fails #140032
Unfortunately, I can't find the old implementation at the moment, but it was proc-macro based, which can't be used in rustc. Instead, we should teach our `rustc_builtin` autodiff macro to handle generics. Right now, running the macro on a generic function produces this expansion:

```rust
#[rustc_autodiff]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box(square(x))
}
```

The source function `square` is fine, but `d_square` has to change: it uses `T` without declaring it, so we need to copy the generic parameters and their bounds from `square`. We might also want to spell out the type parameter in the body (the `black_box` calls):

```rust
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(pure, nomem));
    };
    ::core::hint::black_box::<T>(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box::<T>(square(x))
}
```
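The fix boils down to reusing the source function's generic parameter list when building the derivative's signature. Below is a minimal sketch of that idea. It uses a hypothetical string-based `FnSig` model rather than rustc's real AST types, and it assumes every argument is a `Duplicated` shared reference and the return value is `Active`. The `d{name}_{i}` shadow naming is also an assumption, chosen because it reproduces `dx_0`:

```rust
/// Hypothetical, simplified model of a function signature. rustc works on
/// AST nodes rather than strings; this only illustrates the transformation.
struct FnSig {
    name: String,
    generics: String,              // generic parameters together with their bounds
    params: Vec<(String, String)>, // (parameter name, parameter type)
    ret: String,
}

impl FnSig {
    /// Renders the signature as Rust source (without a body).
    fn render(&self) -> String {
        let params: Vec<String> = self
            .params
            .iter()
            .map(|(n, t)| format!("{n}: {t}"))
            .collect();
        format!("fn {}{}({}) -> {}", self.name, self.generics, params.join(", "), self.ret)
    }
}

/// Builds a reverse-mode derivative signature, assuming every argument is
/// `Duplicated` (and a shared reference `&U`) and the return is `Active`.
fn reverse_sig(src: &FnSig, d_name: &str) -> FnSig {
    let mut params = Vec::new();
    for (i, (name, ty)) in src.params.iter().enumerate() {
        // Keep the primal argument as-is.
        params.push((name.clone(), ty.clone()));
        // Shadow argument: `&U` becomes `&mut U` (hypothetical `d{name}_{i}` naming).
        let inner = ty.strip_prefix('&').unwrap_or(ty.as_str());
        params.push((format!("d{name}_{i}"), format!("&mut {inner}")));
    }
    // Active return: the incoming adjoint `dret` has the return type.
    params.push(("dret".to_string(), src.ret.clone()));
    FnSig {
        name: d_name.to_string(),
        // The step missing from the current expansion: copy generics and bounds.
        generics: src.generics.clone(),
        params,
        ret: src.ret.clone(),
    }
}

fn main() {
    let square = FnSig {
        name: "square".into(),
        generics: "<T: std::ops::Mul<Output = T> + Copy>".into(),
        params: vec![("x".into(), "&T".into())],
        ret: "T".into(),
    };
    // Prints: fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T
    println!("{}", reverse_sig(&square, "d_square").render());
}
```

The rendered signature matches the hand-corrected `d_square` signature. In rustc itself, the same step would clone the source item's generics into the generated item rather than concatenating strings.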
@rustbot claim
The two-fold design, with the `autodiff` macro in the frontend and the `rustc_autodiff` attribute in the backend, means that we can also just sidestep the macro for quick experiments. I just verified that you really only need to copy the generic bounds, so it should be an easy fix! The code below, written directly against `rustc_autodiff` with the bounds copied onto `d_square` by hand, runs:

```rust
#![feature(rustc_attrs)]
#![feature(panic_internals)]
#![feature(prelude_import)]
#![feature(autodiff)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
use std::autodiff::autodiff;
use std::arch::asm;
#[rustc_autodiff]
#[inline(never)]
fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}
#[rustc_autodiff(Reverse, 1, Duplicated, Active)]
#[inline(never)]
fn d_square<T: std::ops::Mul<Output = T> + Copy>(x: &T, dx_0: &mut T, dret: T) -> T {
    unsafe {
        asm!("NOP", options(nomem));
    };
    ::core::hint::black_box(square(x));
    ::core::hint::black_box((dx_0, dret));
    ::core::hint::black_box(square(x))
}
fn main() {
    let xf32: f32 = 3.0;
    let xf64: f64 = 3.0;
    let outputf32 = square::<f32>(&xf32);
    let outputf64 = square::<f64>(&xf64);
    assert_eq!(9.0, outputf32);
    assert_eq!(9.0, outputf64);

    let mut df_dxf32: f32 = 0.0;
    let mut df_dxf64: f64 = 0.0;
    let output_f32 = d_square::<f32>(&xf32, &mut df_dxf32, 1.0);
    let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
    // The derivative returns the primal value...
    assert_eq!(outputf32, output_f32);
    assert_eq!(outputf64, output_f64);
    // ...and writes d(x^2)/dx = 2x = 6 at x = 3 into the shadow.
    assert_eq!(6.0, df_dxf32);
    assert_eq!(6.0, df_dxf64);
}
```
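The values asserted in `main` (9.0 for the primal, 6.0 for the gradient) follow from d(x^2)/dx = 2x at x = 3. The hand-written sketch below has the same signature shape as `d_square` and compiles on stable Rust. It shows what the generated function is expected to compute, assuming Enzyme's reverse-mode convention of adding `dret * df/dx` into the shadow rather than overwriting it. `d_square_manual` is a hypothetical name, and it needs an extra `Add` bound that the generated version does not:

```rust
use std::ops::{Add, Mul};

/// Primal function from the issue.
fn square<T: Mul<Output = T> + Copy>(x: &T) -> T {
    *x * *x
}

/// Hypothetical hand-written reverse-mode derivative mirroring `d_square`:
/// returns the primal value and accumulates `dret * 2x` into the shadow `dx_0`.
fn d_square_manual<T: Mul<Output = T> + Add<Output = T> + Copy>(
    x: &T,
    dx_0: &mut T,
    dret: T,
) -> T {
    // Accumulate (not overwrite) the adjoint, as assumed for Enzyme's shadows.
    *dx_0 = *dx_0 + dret * (*x + *x);
    square(x)
}

fn main() {
    // The same generic function works for both float widths once bounds are declared.
    let x32: f32 = 3.0;
    let mut dx32: f32 = 0.0;
    let y32 = d_square_manual::<f32>(&x32, &mut dx32, 1.0);

    let x64: f64 = 3.0;
    let mut dx64: f64 = 0.0;
    let y64 = d_square_manual::<f64>(&x64, &mut dx64, 1.0);

    assert_eq!((y32, dx32), (9.0, 6.0));
    assert_eq!((y64, dx64), (9.0, 6.0));
    println!("f32: y = {y32}, dy/dx = {dx32}; f64: y = {y64}, dy/dx = {dx64}");
}
```

Because the shadow is accumulated, calling it with a non-zero `dx_0` or a `dret` other than 1.0 scales and adds to the existing gradient, which is why the issue's test initializes `df_dxf32` and `df_dxf64` to 0.0.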
I tried this code:
I expected to see this happen: works.
Instead, this happened:
Meta
rustc --version --verbose
Backtrace
Solution: TBA
cc @haenoe