diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07ff9f3..608a1dc 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,13 +33,14 @@ jobs: unit-tests: name: ${{ matrix.basename }} - ${{ matrix.os }} - ${{ matrix.llvm[0] }} - ${{ matrix.features }} runs-on: ${{ matrix.os }} - continue-on-error: ${{ matrix.experimental }} + continue-on-error: ${{ matrix.experimental || matrix.llvm[0] != '17.0' }} strategy: matrix: llvm: - ["15.0", "15-0", "150"] - ["16.0", "16-0", "160"] - ["17.0", "17-0", "170"] + - ["18.1", "18-1", "181"] - ["19.1", "19-1", "191", "19"] - ["20.1", "20-1", "201", "20"] - ["21.1", "21-1", "211", "21"] @@ -62,26 +63,6 @@ jobs: - Tests include: - - toolchain: stable - os: ubuntu-latest - llvm: ["18.1", "18-1", "181"] - features: "rayon cranelift external_f64" - tests: true - clippy: false - rustdoc: false - experimental: true - basename: Tests - - - toolchain: stable - os: macos-latest - llvm: ["18.1", "18-1", "181"] - features: "rayon cranelift external_f64" - tests: true - clippy: false - rustdoc: false - experimental: true - basename: Tests - - toolchain: stable os: windows-latest llvm: "" diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index 7dc9d3c..d71565c 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -615,6 +615,12 @@ impl<'s> DiscreteModel<'s> { Self::check_match(reset, &ret.state, span, &mut env); } } + if ret.reset.is_some() && ret.stop.is_none() { + env.errs_mut().push(ValidationError::new( + "reset requires stop to also be defined".to_string(), + span_reset.flatten(), + )); + } let map_dep = |deps: &Vec| -> Vec<(usize, usize)> { deps.iter() @@ -1722,6 +1728,32 @@ mod tests { ); } + #[test] + fn test_reset_requires_stop() { + let text = " + u_i { + y = 1, + } + F_i { + y, + } + reset_i { + 2 * y, + } + "; + let model_ds = parse_ds_string(text).unwrap(); + let model = DiscreteModel::build("$name", &model_ds); + assert!( + model.is_err(), + "reset_i should require stop_i to also be defined" + ); + let errs = model.unwrap_err(); + assert!( + errs.has_error_contains("reset requires stop to also be defined"), + "expected missing stop validation error" + ); + } + #[test] fn test_no_out() { let text = " diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index af4c4cf..18c4b32 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -433,6 +433,150 @@ impl Compiler { }); } + #[allow(clippy::too_many_arguments)] + pub fn calc_stop_grad( + &self, + t: T, + yy: &[T], + dyy: &[T], + data: &[T], + ddata: &mut [T], + stop: &[T], + dstop: &mut [T], + ) { + if self.number_of_stop == 0 { + panic!("Model does not have a stop function"); + } + self.check_state_len(yy, "yy"); + self.check_state_len(dyy, "dyy"); + self.check_stop_len(stop, "stop"); + self.check_stop_len(dstop, "dstop"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self.jit_grad_functions.stop_grad)( + t, + yy.as_ptr(), + dyy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut T, + stop.as_ptr(), + dstop.as_ptr() as *mut T, + i, + dim, + ) + }); + } + + #[allow(clippy::too_many_arguments)] + pub fn calc_stop_rgrad( + &self, + t: T, + yy: &[T], + dyy: &mut [T], + data: &[T], + ddata: &mut [T], + stop: &[T], + dstop: &mut [T], + ) { + if self.number_of_stop == 0 { + panic!("Model does not have a stop function"); + } + self.check_state_len(yy, "yy"); + self.check_state_len(dyy, "dyy"); + self.check_stop_len(stop, "stop"); + self.check_stop_len(dstop, "dstop"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_grad_r_functions + .as_ref() + .expect("module does not support reverse autograd") + .stop_rgrad)( + t, + yy.as_ptr(), + dyy.as_ptr() as *mut T, + data.as_ptr(), + ddata.as_ptr() as *mut T, + stop.as_ptr(), + dstop.as_ptr() as *mut T, + i, + dim, + ) + }); + } + + pub fn calc_stop_sgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + stop: &[T], + dstop: &mut [T], + ) { + if self.number_of_stop == 0 { + panic!("Model does not have a stop function"); + } + self.check_state_len(yy, "yy"); + self.check_stop_len(stop, "stop"); + self.check_stop_len(dstop, "dstop"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_sens_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .stop_sgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut T, + stop.as_ptr(), + dstop.as_ptr() as *mut T, + i, + dim, + ) + }); + } + + pub fn calc_stop_srgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + stop: &[T], + dstop: &mut [T], + ) { + if self.number_of_stop == 0 { + panic!("Model does not have a stop function"); + } + self.check_state_len(yy, "yy"); + self.check_stop_len(stop, "stop"); + self.check_stop_len(dstop, "dstop"); + self.check_data_len(data, "data"); + self.check_data_len(ddata, "ddata"); + self.with_threading(|i, dim| unsafe { + (self + .jit_sens_rev_grad_functions + .as_ref() + .expect("module does not support sens autograd") + .stop_rgrad)( + t, + yy.as_ptr(), + data.as_ptr(), + ddata.as_ptr() as *mut T, + stop.as_ptr(), + dstop.as_ptr() as *mut T, + i, + dim, + ) + }); + } + pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T]) { if reset.is_empty() { return; @@ -1197,6 +1341,7 @@ mod tests { } generate_tests!(test_stop); + generate_tests!(test_stop_gradients); generate_tests!(test_reset); generate_tests!(test_reset_gradients); generate_tests!(test_reset_without_reset_tensor_is_noop); @@ -1252,6 +1397,142 @@ mod tests { assert_eq!(stop.len(), 1); } + #[allow(dead_code)] + fn test_stop_gradients() { + let full_text = " + in { + a = 1, + } + u_i { + y = a, + z = 2, + } + F_i { + y, + z, + } + stop_i { + 2 * y + a, + z + a, + } + out_i { + y, + z, + } + "; + let model = parse_ds_string(full_text).unwrap(); + let discrete_model = DiscreteModel::build("$name", &model).unwrap(); + let compiler = Compiler::::from_discrete_model( + &discrete_model, + Default::default(), + Some(full_text), + ) + .unwrap(); + + let mut data = compiler.get_new_data(); + let inputs = vec![T::from_f64(3.0).unwrap()]; + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); + + let mut yy = vec![T::zero(), T::zero()]; + compiler.set_u0(yy.as_mut_slice(), data.as_mut_slice()); + + let mut stop = vec![T::zero(), T::zero()]; + compiler.calc_stop( + T::zero(), + yy.as_slice(), + data.as_mut_slice(), + stop.as_mut_slice(), + ); + assert_relative_eq!(stop[0], T::from_f64(9.0).unwrap()); + assert_relative_eq!(stop[1], T::from_f64(5.0).unwrap()); + + let mut ddata = compiler.get_new_data(); + let dinputs = vec![T::one()]; + compiler.set_inputs_grad( + inputs.as_slice(), + dinputs.as_slice(), + data.as_slice(), + ddata.as_mut_slice(), + 0, + ); + let dyy = vec![T::one(), T::zero()]; + let mut dstop = vec![T::zero(), T::zero()]; + compiler.calc_stop_grad( + T::zero(), + yy.as_slice(), + dyy.as_slice(), + data.as_slice(), + ddata.as_mut_slice(), + stop.as_slice(), + dstop.as_mut_slice(), + ); + assert_relative_eq!(dstop[0], T::from_f64(3.0).unwrap()); + assert_relative_eq!(dstop[1], T::from_f64(1.0).unwrap()); + + if compiler.supports_reverse_autodiff() { + let mut dyy_rev = vec![T::zero(), T::zero()]; + let mut ddata_rev = compiler.get_new_data(); + let mut dstop_rev = vec![T::one(), T::one()]; + compiler.calc_stop_rgrad( + T::zero(), + yy.as_slice(), + dyy_rev.as_mut_slice(), + data.as_slice(), + ddata_rev.as_mut_slice(), + stop.as_slice(), + dstop_rev.as_mut_slice(), + ); + assert_relative_eq!(dyy_rev[0], T::from_f64(2.0).unwrap()); + assert_relative_eq!(dyy_rev[1], T::one()); + + let mut dinputs_rev = vec![T::zero(); inputs.len()]; + compiler.set_inputs_rgrad( + inputs.as_slice(), + dinputs_rev.as_mut_slice(), + data.as_slice(), + ddata_rev.as_mut_slice(), + 0, + ); + assert_relative_eq!(dinputs_rev[0], T::from_f64(2.0).unwrap()); + + let mut ddata_s = compiler.get_new_data(); + let dinputs_s = vec![T::one(); inputs.len()]; + compiler.set_inputs(dinputs_s.as_slice(), ddata_s.as_mut_slice(), 0); + let mut dstop_s = vec![T::zero(), T::zero()]; + compiler.calc_stop_sgrad( + T::zero(), + yy.as_slice(), + data.as_slice(), + ddata_s.as_mut_slice(), + stop.as_slice(), + dstop_s.as_mut_slice(), + ); + assert_relative_eq!(dstop_s[0], T::one()); + assert_relative_eq!(dstop_s[1], T::one()); + + let mut ddata_sr = compiler.get_new_data(); + let mut dstop_sr = vec![T::one(), T::one()]; + compiler.calc_stop_srgrad( + T::zero(), + yy.as_slice(), + data.as_slice(), + ddata_sr.as_mut_slice(), + stop.as_slice(), + dstop_sr.as_mut_slice(), + ); + + let mut dinputs_sr = vec![T::zero(); inputs.len()]; + compiler.set_inputs_rgrad( + inputs.as_slice(), + dinputs_sr.as_mut_slice(), + data.as_slice(), + ddata_sr.as_mut_slice(), + 0, + ); + assert_relative_eq!(dinputs_sr[0], T::from_f64(2.0).unwrap()); + } + } + #[allow(dead_code)] fn test_reset() { let full_text = " @@ -1332,6 +1613,10 @@ mod tests { 2 * y + a, z + a, } + stop_i { + y - 0.5, + z - 1, + } out_i { y, z, diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index 345f230..2a6405a 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -347,6 +347,66 @@ impl CraneliftModule { self.declare_function("reset_grad") } + fn compile_calc_stop_grad( + &mut self, + _func_id: &FuncId, + model: &DiscreteModel, + ) -> Result { + let arg_types = &[ + self.real_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.real_ptr_type, + self.int_type, + self.int_type, + ]; + let arg_names = &[ + "t", + "u", + "du", + "data", + "ddata", + "root", + "droot", + "threadId", + "threadDim", + ]; + { + let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); + + if let Some(stop) = model.stop() { + let mut nbarrier = 0; + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + for tensor in model.state_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + for tensor in model.state_dep_post_f_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + let droot_ptr = *codegen.variables.get("droot").unwrap(); + codegen.jit_compile_tensor(stop, Some(droot_ptr), true)?; + } + + codegen.builder.ins().return_(&[]); + codegen.builder.finalize(); + } + self.declare_function("calc_stop_grad") + } + fn compile_set_inputs_grad( &mut self, _func_id: &FuncId, @@ -557,6 +617,7 @@ impl CraneliftModule { let _set_u0_grad = ret.compile_set_u0_grad(&set_u0, model)?; let _rhs_grad = ret.compile_rhs_grad(&rhs, model)?; let _reset_grad = ret.compile_reset_grad(&reset, model)?; + let _calc_stop_grad = ret.compile_calc_stop_grad(&_calc_stop, model)?; let _calc_out_grad = ret.compile_calc_out_grad(&calc_out, model)?; let _set_inputs_grad = ret.compile_set_inputs_grad(&set_inputs, model)?; Ok(ret) diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index 30d0c6c..d43d617 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -246,6 +246,52 @@ macro_rules! define_symbol_module { thread_id: UIntType, thread_dim: UIntType, ); + #[link_name = "calc_stop_grad"] + pub fn calc_stop_grad( + time: $ty, + u: *const $ty, + du: *const $ty, + data: *const $ty, + ddata: *mut $ty, + root: *const $ty, + droot: *mut $ty, + thread_id: UIntType, + thread_dim: UIntType, + ); + #[link_name = "calc_stop_rgrad"] + pub fn calc_stop_rgrad( + time: $ty, + u: *const $ty, + du: *mut $ty, + data: *const $ty, + ddata: *mut $ty, + root: *const $ty, + droot: *mut $ty, + thread_id: UIntType, + thread_dim: UIntType, + ); + #[link_name = "calc_stop_sgrad"] + pub fn calc_stop_sgrad( + time: $ty, + u: *const $ty, + data: *const $ty, + ddata: *mut $ty, + root: *const $ty, + droot: *mut $ty, + thread_id: UIntType, + thread_dim: UIntType, + ); + #[link_name = "calc_stop_srgrad"] + pub fn calc_stop_srgrad( + time: $ty, + u: *const $ty, + data: *const $ty, + ddata: *mut $ty, + root: *const $ty, + droot: *mut $ty, + thread_id: UIntType, + thread_dim: UIntType, + ); #[link_name = "set_id"] pub fn set_id(id: *mut $ty); #[link_name = "get_dims"] @@ -360,6 +406,10 @@ impl_extern_symbols!(f64, f64_symbols, { "calc_out_sgrad" => calc_out_sgrad, "calc_out_srgrad" => calc_out_srgrad, "calc_stop" => calc_stop, + "calc_stop_grad" => calc_stop_grad, + "calc_stop_rgrad" => calc_stop_rgrad, + "calc_stop_sgrad" => calc_stop_sgrad, + "calc_stop_srgrad" => calc_stop_srgrad, "set_id" => set_id, "get_dims" => get_dims, "set_inputs" => set_inputs, @@ -394,6 +444,10 @@ impl_extern_symbols!(f32, f32_symbols, { "calc_out_sgrad" => calc_out_sgrad, "calc_out_srgrad" => calc_out_srgrad, "calc_stop" => calc_stop, + "calc_stop_grad" => calc_stop_grad, + "calc_stop_rgrad" => calc_stop_rgrad, + "calc_stop_sgrad" => calc_stop_sgrad, + "calc_stop_srgrad" => calc_stop_srgrad, "set_id" => set_id, "get_dims" => get_dims, "set_inputs" => set_inputs, diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index 25dbfb5..7e17ab3 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -15,6 +15,48 @@ pub type StopFunc = unsafe extern "C" fn( thread_id: UIntType, thread_dim: UIntType, ); +pub type StopGradFunc = unsafe extern "C" fn( + time: T, + u: *const T, + du: *const T, + data: *const T, + ddata: *mut T, + root: *const T, + droot: *mut T, + thread_id: UIntType, + thread_dim: UIntType, +); +pub type StopRevGradFunc = unsafe extern "C" fn( + time: T, + u: *const T, + du: *mut T, + data: *const T, + ddata: *mut T, + root: *const T, + droot: *mut T, + thread_id: UIntType, + thread_dim: UIntType, +); +pub type StopSensGradFunc = unsafe extern "C" fn( + time: T, + u: *const T, + data: *const T, + ddata: *mut T, + root: *const T, + droot: *mut T, + thread_id: UIntType, + thread_dim: UIntType, +); +pub type StopSensRevGradFunc = unsafe extern "C" fn( + time: T, + u: *const T, + data: *const T, + ddata: *mut T, + root: *const T, + droot: *mut T, + thread_id: UIntType, + thread_dim: UIntType, +); pub type ResetFunc = unsafe extern "C" fn( time: T, u: *const T, @@ -321,6 +363,7 @@ impl JitFunctions { pub(crate) struct JitGradFunctions { pub(crate) set_u0_grad: U0GradFunc, + pub(crate) stop_grad: StopGradFunc, pub(crate) reset_grad: ResetGradFunc, pub(crate) rhs_grad: RhsGradFunc, pub(crate) calc_out_grad: CalcOutGradFunc, @@ -332,6 +375,7 @@ impl JitGradFunctions { // check if all required symbols are present let required_symbols = [ "set_u0_grad", + "calc_stop_grad", "reset_grad", "rhs_grad", "calc_out_grad", @@ -344,6 +388,9 @@ impl JitGradFunctions { } let set_u0_grad = unsafe { std::mem::transmute::<*const u8, U0GradFunc>(symbol_map["set_u0_grad"]) }; + let stop_grad = unsafe { + std::mem::transmute::<*const u8, StopGradFunc>(symbol_map["calc_stop_grad"]) + }; let reset_grad = unsafe { std::mem::transmute::<*const u8, ResetGradFunc>(symbol_map["reset_grad"]) }; let rhs_grad = @@ -357,6 +404,7 @@ impl JitGradFunctions { Ok(Self { set_u0_grad, + stop_grad, reset_grad, rhs_grad, calc_out_grad, @@ -367,6 +415,7 @@ impl JitGradFunctions { pub(crate) struct JitGradRFunctions { pub(crate) set_u0_rgrad: U0RevGradFunc, + pub(crate) stop_rgrad: StopRevGradFunc, pub(crate) reset_rgrad: ResetRevGradFunc, pub(crate) rhs_rgrad: RhsRevGradFunc, pub(crate) mass_rgrad: MassRevGradFunc, @@ -378,6 +427,7 @@ impl JitGradRFunctions { pub(crate) fn new(symbol_map: &HashMap) -> Result { let required_symbols = [ "set_u0_rgrad", + "calc_stop_rgrad", "reset_rgrad", "rhs_rgrad", "mass_rgrad", @@ -392,6 +442,9 @@ impl JitGradRFunctions { let set_u0_rgrad = unsafe { std::mem::transmute::<*const u8, U0RevGradFunc>(symbol_map["set_u0_rgrad"]) }; + let stop_rgrad = unsafe { + std::mem::transmute::<*const u8, StopRevGradFunc>(symbol_map["calc_stop_rgrad"]) + }; let reset_rgrad = unsafe { std::mem::transmute::<*const u8, ResetRevGradFunc>(symbol_map["reset_rgrad"]) }; @@ -411,6 +464,7 @@ impl JitGradRFunctions { Ok(Self { set_u0_rgrad, + stop_rgrad, reset_rgrad, rhs_rgrad, mass_rgrad, @@ -422,6 +476,7 @@ impl JitGradRFunctions { pub(crate) struct JitSensGradFunctions { pub(crate) set_u0_sgrad: U0SensGradFunc, + pub(crate) stop_sgrad: StopSensGradFunc, pub(crate) reset_sgrad: ResetSensGradFunc, pub(crate) rhs_sgrad: RhsSensGradFunc, pub(crate) calc_out_sgrad: CalcOutSensGradFunc, @@ -429,7 +484,13 @@ pub(crate) struct JitSensGradFunctions { impl JitSensGradFunctions { pub(crate) fn new(symbol_map: &HashMap) -> Result { - let required_symbols = ["rhs_sgrad", "calc_out_sgrad", "set_u0_sgrad", "reset_sgrad"]; + let required_symbols = [ + "rhs_sgrad", + "calc_out_sgrad", + "set_u0_sgrad", + "calc_stop_sgrad", + "reset_sgrad", + ]; for symbol in &required_symbols { if !symbol_map.contains_key(*symbol) { return Err(anyhow!("Missing required symbol: {}", symbol)); @@ -444,6 +505,9 @@ impl JitSensGradFunctions { let set_u0_sgrad = unsafe { std::mem::transmute::<*const u8, U0SensGradFunc>(symbol_map["set_u0_sgrad"]) }; + let stop_sgrad = unsafe { + std::mem::transmute::<*const u8, StopSensGradFunc>(symbol_map["calc_stop_sgrad"]) + }; let reset_sgrad = unsafe { std::mem::transmute::<*const u8, ResetSensGradFunc>(symbol_map["reset_sgrad"]) }; @@ -452,12 +516,14 @@ impl JitSensGradFunctions { rhs_sgrad, calc_out_sgrad, set_u0_sgrad, + stop_sgrad, reset_sgrad, }) } } pub(crate) struct JitSensRevGradFunctions { + pub(crate) stop_rgrad: StopSensRevGradFunc, pub(crate) reset_rgrad: ResetSensRevGradFunc, pub(crate) rhs_rgrad: RhsSensRevGradFunc, pub(crate) calc_out_rgrad: CalcOutSensRevGradFunc, @@ -465,7 +531,12 @@ pub(crate) struct JitSensRevGradFunctions { impl JitSensRevGradFunctions { pub(crate) fn new(symbol_map: &HashMap) -> Result { - let required_symbols = ["rhs_srgrad", "calc_out_srgrad", "reset_srgrad"]; + let required_symbols = [ + "rhs_srgrad", + "calc_out_srgrad", + "calc_stop_srgrad", + "reset_srgrad", + ]; for symbol in &required_symbols { if !symbol_map.contains_key(*symbol) { return Err(anyhow!("Missing required symbol: {}", symbol)); @@ -474,6 +545,9 @@ impl JitSensRevGradFunctions { let reset_rgrad = unsafe { std::mem::transmute::<*const u8, ResetSensRevGradFunc>(symbol_map["reset_srgrad"]) }; + let stop_rgrad = unsafe { + std::mem::transmute::<*const u8, StopSensRevGradFunc>(symbol_map["calc_stop_srgrad"]) + }; let rhs_rgrad = unsafe { std::mem::transmute::<*const u8, RhsSensRevGradFunc>(symbol_map["rhs_srgrad"]) }; @@ -484,6 +558,7 @@ impl JitSensRevGradFunctions { }; Ok(Self { + stop_rgrad, reset_rgrad, rhs_rgrad, calc_out_rgrad, diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 903bdc1..24fb3bd 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -289,7 +289,7 @@ impl CodegenModuleCompile for LlvmModule { let mut module = Self::new(triple, model, threaded, real_type, options.debug)?; let set_u0 = module.codegen_mut().compile_set_u0(model, code)?; - let _calc_stop = module.codegen_mut().compile_calc_stop(model, code)?; + let calc_stop = module.codegen_mut().compile_calc_stop(model, code)?; let reset = module.codegen_mut().compile_reset(model, code)?; let rhs = module.codegen_mut().compile_rhs(model, false, code)?; let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?; @@ -361,6 +361,20 @@ impl CodegenModuleCompile for LlvmModule { "reset_grad", )?; + module.codegen_mut().compile_gradient( + calc_stop, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Forward, + "calc_stop_grad", + )?; + module.codegen_mut().compile_gradient( calc_out, &[ @@ -437,6 +451,19 @@ impl CodegenModuleCompile for LlvmModule { CompileMode::Reverse, "reset_rgrad", )?; + module.codegen_mut().compile_gradient( + calc_stop, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::Reverse, + "calc_stop_rgrad", + )?; module.codegen_mut().compile_gradient( calc_out, &[ @@ -490,6 +517,20 @@ impl CodegenModuleCompile for LlvmModule { "reset_sgrad", )?; + module.codegen_mut().compile_gradient( + calc_stop, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::ForwardSens, + "calc_stop_sgrad", + )?; + module.codegen_mut().compile_gradient( set_u0, &[ @@ -557,6 +598,20 @@ impl CodegenModuleCompile for LlvmModule { "reset_srgrad", )?; + module.codegen_mut().compile_gradient( + calc_stop, + &[ + CompileGradientArgType::Const, + CompileGradientArgType::Const, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::DupNoNeed, + CompileGradientArgType::Const, + CompileGradientArgType::Const, + ], + CompileMode::ReverseSens, + "calc_stop_srgrad", + )?; + module.post_autodiff_optimisation()?; Ok(module) } diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 2b68baf..24ece34 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -296,6 +296,80 @@ macro_rules! define_external_test { *root = *u - (0.5 as $ty); } + #[no_mangle] + pub unsafe extern "C" fn calc_stop_grad( + _time: $ty, + _u: *const $ty, + du: *const $ty, + _data: *const $ty, + ddata: *mut $ty, + _root: *const $ty, + droot: *mut $ty, + _thread_id: u32, + _thread_dim: u32, + ) { + if du.is_null() || ddata.is_null() || droot.is_null() { + return; + } + *droot = *du; + *ddata = 0.0 as $ty; + } + + #[no_mangle] + pub unsafe extern "C" fn calc_stop_rgrad( + _time: $ty, + _u: *const $ty, + du: *mut $ty, + _data: *const $ty, + ddata: *mut $ty, + _root: *const $ty, + droot: *mut $ty, + _thread_id: u32, + _thread_dim: u32, + ) { + if du.is_null() || ddata.is_null() || droot.is_null() { + return; + } + *du += *droot; + *ddata += 0.0 as $ty; + } + + #[no_mangle] + pub unsafe extern "C" fn calc_stop_sgrad( + _time: $ty, + _u: *const $ty, + _data: *const $ty, + ddata: *mut $ty, + _root: *const $ty, + droot: *mut $ty, + _thread_id: u32, + _thread_dim: u32, + ) { + if ddata.is_null() || droot.is_null() { + return; + } + *droot = 0.0 as $ty; + *ddata = 0.0 as $ty; + } + + #[no_mangle] + pub unsafe extern "C" fn calc_stop_srgrad( + _time: $ty, + _u: *const $ty, + _data: *const $ty, + ddata: *mut $ty, + _root: *const $ty, + droot: *mut $ty, + _thread_id: u32, + _thread_dim: u32, + ) { + if ddata.is_null() || droot.is_null() { + return; + } + *droot = 0.0 as $ty; + *ddata = 0.0 as $ty; + } + #[no_mangle] pub unsafe extern "C" fn reset( _time: $ty, @@ -509,6 +583,58 @@ macro_rules! define_external_test { compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop); assert_eq!(stop[0], 0.5 as $ty); + let du_stop = vec![1.0 as $ty; n_states]; + let mut ddata_stop = vec![-5.15 as $ty; n_data]; + let mut dstop = vec![-5.25 as $ty; n_stop]; + compiler.calc_stop_grad( + 0.0 as $ty, + &u, + &du_stop, + &data, + &mut ddata_stop, + &stop, + &mut dstop, + ); + assert_eq!(dstop[0], 1.0 as $ty); + + let mut du_stop_rev = vec![-5.35 as $ty; n_states]; + let mut ddata_stop_rev = vec![-5.45 as $ty; n_data]; + let mut dstop_rev = vec![1.0 as $ty; n_stop]; + compiler.calc_stop_rgrad( + 0.0 as $ty, + &u, + &mut du_stop_rev, + &data, + &mut ddata_stop_rev, + &stop, + &mut dstop_rev, + ); + assert!((du_stop_rev[0] - (-4.35 as $ty)).abs() < (1e-6 as $ty)); + + let mut ddata_stop_s = vec![-5.55 as $ty; n_data]; + let mut dstop_s = vec![-5.65 as $ty; n_stop]; + compiler.calc_stop_sgrad( + 0.0 as $ty, + &u, + &data, + &mut ddata_stop_s, + &stop, + &mut dstop_s, + ); + assert_eq!(dstop_s[0], 0.0 as $ty); + + let mut ddata_stop_sr = vec![-5.75 as $ty; n_data]; + let mut dstop_sr = vec![1.0 as $ty; n_stop]; + compiler.calc_stop_srgrad( + 0.0 as $ty, + &u, + &data, + &mut ddata_stop_sr, + &stop, + &mut dstop_sr, + ); + assert_eq!(dstop_sr[0], 0.0 as $ty); + let mut reset = vec![-5.5 as $ty; n_states]; compiler.reset(0.0 as $ty, &u, &mut data, &mut reset); assert_eq!(reset[0], 2.0 as $ty);