From a0461f7cfc75ed06e4d47158ee0c22f6b57c15bd Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 15 Mar 2026 20:16:23 +0000 Subject: [PATCH 1/4] draft out changes --- diffsl/src/discretise/discrete_model.rs | 16 +++ diffsl/src/execution/compiler.rs | 71 +++++++--- diffsl/src/execution/cranelift/codegen.rs | 165 +++++++++++++--------- diffsl/src/execution/data_layout.rs | 5 + diffsl/src/execution/external/mod.rs | 26 +++- diffsl/src/execution/interface.rs | 35 ++++- diffsl/src/execution/llvm/codegen.rs | 3 + 7 files changed, 230 insertions(+), 91 deletions(-) diff --git a/diffsl/src/discretise/discrete_model.rs b/diffsl/src/discretise/discrete_model.rs index 7dc9d3c..49328a7 100644 --- a/diffsl/src/discretise/discrete_model.rs +++ b/diffsl/src/discretise/discrete_model.rs @@ -35,6 +35,7 @@ pub struct DiscreteModel<'s> { out: Option>, constant_defns: Vec>, input_dep_defns: Vec>, + model_dep_defns: Vec>, time_dep_defns: Vec>, state_dep_defns: Vec>, state_dep_post_f_defns: Vec>, @@ -66,6 +67,9 @@ impl fmt::Display for DiscreteModel<'_> { for defn in &self.input_dep_defns { writeln!(f, "{defn}")?; } + for defn in &self.model_dep_defns { + writeln!(f, "{defn}")?; + } for defn in &self.time_dep_defns { writeln!(f, "{defn}")?; } @@ -107,6 +111,7 @@ impl<'s> DiscreteModel<'s> { out: None, constant_defns: Vec::new(), input_dep_defns: Vec::new(), + model_dep_defns: Vec::new(), time_dep_defns: Vec::new(), state_dep_defns: Vec::new(), state_dep_post_f_defns: Vec::new(), @@ -538,6 +543,12 @@ impl<'s> DiscreteModel<'s> { && !dependent_on_model { ret.constant_defns.push(built); + } else if dependent_on_model + && !dependent_on_time + && !dependent_on_state + && !dependent_on_dudt + { + ret.model_dep_defns.push(built); } else if !dependent_on_time { ret.input_dep_defns.push(built); } else if !dependent_on_state && !dependent_on_dudt { @@ -868,6 +879,7 @@ impl<'s> DiscreteModel<'s> { out: Some(out_array), constant_defns, input_dep_defns: Vec::new(), // todo: need to implement + model_dep_defns: Vec::new(), time_dep_defns, state_dep_defns, state_dep_post_f_defns, @@ -898,6 +910,10 @@ impl<'s> DiscreteModel<'s> { self.input_dep_defns.as_ref() } + pub fn model_dep_defns(&self) -> &[Tensor<'_>] { + self.model_dep_defns.as_ref() + } + pub fn time_dep_defns(&self) -> &[Tensor<'_>] { self.time_dep_defns.as_ref() } diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index af4c4cf..776b2e0 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -346,14 +346,20 @@ impl Compiler { }); } - pub fn set_u0(&self, yy: &mut [T], data: &mut [T]) { + pub fn set_u0(&self, yy: &mut [T], data: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.with_threading(|i, dim| unsafe { - (self.jit_functions.set_u0)(yy.as_ptr() as *mut T, data.as_ptr() as *mut T, i, dim); + (self.jit_functions.set_u0)( + yy.as_ptr() as *mut T, + data.as_ptr() as *mut T, + model_index, + i, + dim, + ); }); } - pub fn set_u0_sgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) { + pub fn set_u0_sgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -368,13 +374,14 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, + model_index, i, dim, ); }); } - pub fn set_u0_rgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) { + pub fn set_u0_rgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -389,13 +396,14 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, + model_index, i, dim, ); }); } - pub fn set_u0_grad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T]) { + pub fn set_u0_grad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -407,6 +415,7 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, + model_index, i, dim, ) @@ -414,7 +423,7 @@ impl Compiler { }) } - pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T]) { + pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T], model_index: u32) { if self.number_of_stop == 0 { panic!("Model does not have a stop function"); } @@ -427,13 +436,14 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, stop.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T]) { + pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T], model_index: u32) { if reset.is_empty() { return; } @@ -447,6 +457,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, reset.as_ptr() as *mut T, + model_index, i, dim, ) @@ -463,6 +474,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], + model_index: u32, ) { if dreset.is_empty() { return; @@ -483,6 +495,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, + model_index, i, dim, ) @@ -499,6 +512,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], + model_index: u32, ) { if dreset.is_empty() { return; @@ -523,6 +537,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, + model_index, i, dim, ) @@ -537,6 +552,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], + model_index: u32, ) { if dreset.is_empty() { return; @@ -559,6 +575,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, + model_index, i, dim, ) @@ -573,6 +590,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], + model_index: u32, ) { if dreset.is_empty() { return; @@ -595,13 +613,14 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T]) { + pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_data_len(data, "data"); @@ -611,6 +630,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, rr.as_ptr() as *mut T, + model_index, i, dim, ) @@ -625,7 +645,7 @@ impl Compiler { self.has_reset } - pub fn mass(&self, t: T, v: &[T], data: &mut [T], mv: &mut [T]) { + pub fn mass(&self, t: T, v: &[T], data: &mut [T], mv: &mut [T], model_index: u32) { if !self.has_mass { panic!("Model does not have a mass function"); } @@ -638,6 +658,7 @@ impl Compiler { v.as_ptr(), data.as_ptr() as *mut T, mv.as_ptr() as *mut T, + model_index, i, dim, ) @@ -662,6 +683,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -678,6 +700,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) @@ -694,6 +717,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -714,13 +738,14 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn mass_rgrad(&self, t: T, dv: &mut [T], data: &[T], ddata: &mut [T], dmv: &mut [T]) { + pub fn mass_rgrad(&self, t: T, dv: &mut [T], data: &[T], ddata: &mut [T], dmv: &mut [T], model_index: u32) { self.check_state_len(dv, "dv"); self.check_state_len(dmv, "dmv"); self.check_data_len(data, "data"); @@ -738,13 +763,14 @@ impl Compiler { ddata.as_ptr() as *mut T, std::ptr::null(), dmv.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { + pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -762,13 +788,14 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T]) { + pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -786,13 +813,14 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, + model_index, i, dim, ) }); } - pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T]) { + pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T], model_index: u32) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); self.check_out_len(out, "out"); @@ -802,6 +830,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, out.as_ptr() as *mut T, + model_index, i, dim, ) @@ -818,6 +847,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -834,6 +864,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -850,6 +881,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -870,6 +902,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -884,6 +917,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -902,6 +936,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -916,6 +951,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], + model_index: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -934,6 +970,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, + model_index, i, dim, ) @@ -975,10 +1012,10 @@ impl Compiler { ) } - pub fn set_inputs(&self, inputs: &[T], data: &mut [T], model_index: u32) { + pub fn set_inputs(&self, inputs: &[T], data: &mut [T]) { self.check_inputs_len(inputs, "inputs"); self.check_data_len(data, "data"); - unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr(), model_index) }; + unsafe { (self.jit_functions.set_inputs)(inputs.as_ptr(), data.as_mut_ptr()) }; } pub fn get_inputs(&self, inputs: &mut [T], data: &[T]) { @@ -993,7 +1030,6 @@ impl Compiler { dinputs: &[T], data: &[T], ddata: &mut [T], - model_index: u32, ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); @@ -1005,7 +1041,6 @@ impl Compiler { dinputs.as_ptr(), data.as_ptr(), ddata.as_mut_ptr(), - model_index, ) }; } @@ -1016,7 +1051,6 @@ impl Compiler { dinputs: &mut [T], data: &[T], ddata: &mut [T], - model_index: u32, ) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); @@ -1032,7 +1066,6 @@ impl Compiler { dinputs.as_mut_ptr(), data.as_ptr(), ddata.as_mut_ptr(), - model_index, ) }; } diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index 345f230..cb5f346 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -35,7 +35,6 @@ pub struct CraneliftModule { indices_id: DataId, constants_id: DataId, - model_index_id: DataId, thread_counter: Option, //triple: Triple, @@ -215,6 +214,7 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; let arg_names = &[ "t", @@ -224,6 +224,7 @@ impl CraneliftModule { "ddata", "out", "dout", + "model_index", "threadId", "threadDim", ]; @@ -231,6 +232,30 @@ impl CraneliftModule { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); if let Some(out) = model.out() { + let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + for tensor in model.time_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + + for tensor in model.state_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.state_dep_post_f_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + codegen.jit_compile_tensor(out, None, true)?; } codegen.builder.ins().return_(&[]); @@ -251,6 +276,7 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; let arg_names = &[ "t", @@ -260,6 +286,7 @@ impl CraneliftModule { "ddata", "rr", "drr", + "model_index", "threadId", "threadDim", ]; @@ -268,6 +295,12 @@ impl CraneliftModule { // calculate time dependant definitions let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, true)?; codegen.jit_compile_call_barrier(nbarrier); @@ -302,6 +335,7 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; let arg_names = &[ "t", @@ -311,6 +345,7 @@ impl CraneliftModule { "ddata", "reset", "dreset", + "model_index", "threadId", "threadDim", ]; @@ -319,6 +354,12 @@ impl CraneliftModule { if let Some(reset) = model.reset() { let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, true)?; codegen.jit_compile_call_barrier(nbarrier); @@ -357,24 +398,11 @@ impl CraneliftModule { self.real_ptr_type, self.real_ptr_type, self.real_ptr_type, - self.int_type, ]; - let arg_names = &["inputs", "dinputs", "data", "ddata", "model_index"]; + let arg_names = &["inputs", "dinputs", "data", "ddata"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); - let model_index_ptr = codegen - .builder - .ins() - .global_value(codegen.int_ptr_type, codegen.model_index_global); - let model_index = codegen - .builder - .use_var(*codegen.variables.get("model_index").unwrap()); - codegen - .builder - .ins() - .store(codegen.mem_flags, model_index, model_index_ptr, 0); - let base_data_ptr = codegen.variables.get("ddata").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, true, false); @@ -414,8 +442,9 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["u0", "du0", "data", "ddata", "threadId", "threadDim"]; + let arg_names = &["u0", "du0", "data", "ddata", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -427,6 +456,12 @@ impl CraneliftModule { nbarrier += 1; } + for a in model.model_dep_defns() { + codegen.jit_compile_tensor(a, None, true)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + codegen.jit_compile_tensor( model.state(), Some(*codegen.variables.get("du0").unwrap()), @@ -496,11 +531,6 @@ impl CraneliftModule { let indices_id = module.declare_data("indices", Linkage::Local, false, false)?; module.define_data(indices_id, &data_description)?; - let mut data_description = DataDescription::new(); - data_description.define_zeroinit(int_type.bytes().try_into().unwrap()); - let model_index_id = module.declare_data("model_index", Linkage::Local, true, false)?; - module.define_data(model_index_id, &data_description)?; - let mut thread_counter = None; if threaded { let mut data_description = DataDescription::new(); @@ -517,7 +547,6 @@ impl CraneliftModule { module: Mutex::new(module), indices_id, constants_id, - model_index_id, int_type, real_type: real_type_cranelift, real_ptr_type: ptr_type, @@ -568,8 +597,9 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["u0", "data", "threadId", "threadDim"]; + let arg_names = &["u0", "data", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -581,6 +611,12 @@ impl CraneliftModule { nbarrier += 1; } + for a in model.model_dep_defns() { + codegen.jit_compile_tensor(a, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + codegen.jit_compile_tensor( model.state(), Some(*codegen.variables.get("u0").unwrap()), @@ -603,14 +639,21 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["t", "u", "data", "out", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "out", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); if let Some(out) = model.out() { // calculate time dependant definitions let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, false)?; codegen.jit_compile_call_barrier(nbarrier); @@ -646,14 +689,21 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["t", "u", "data", "root", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "root", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); if let Some(stop) = model.stop() { // calculate time dependant definitions let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, false)?; codegen.jit_compile_call_barrier(nbarrier); @@ -689,13 +739,20 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["t", "u", "data", "reset", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "reset", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); if let Some(reset) = model.reset() { let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, false)?; codegen.jit_compile_call_barrier(nbarrier); @@ -730,13 +787,20 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["t", "u", "data", "rr", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "rr", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); // calculate time dependant definitions let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, false)?; codegen.jit_compile_call_barrier(nbarrier); @@ -767,8 +831,9 @@ impl CraneliftModule { self.real_ptr_type, self.int_type, self.int_type, + self.int_type, ]; - let arg_names = &["t", "dudt", "data", "rr", "threadId", "threadDim"]; + let arg_names = &["t", "dudt", "data", "rr", "model_index", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -776,6 +841,12 @@ impl CraneliftModule { if model.state_dot().is_some() && model.lhs().is_some() { // calculate time dependant definitions let mut nbarrier = 0; + for tensor in model.model_dep_defns() { + codegen.jit_compile_tensor(tensor, None, false)?; + codegen.jit_compile_call_barrier(nbarrier); + nbarrier += 1; + } + for tensor in model.time_dep_defns() { codegen.jit_compile_tensor(tensor, None, false)?; codegen.jit_compile_call_barrier(nbarrier); @@ -916,23 +987,11 @@ impl CraneliftModule { } fn compile_set_inputs(&mut self, model: &DiscreteModel) -> Result { - let arg_types = &[self.real_ptr_type, self.real_ptr_type, self.int_type]; - let arg_names = &["inputs", "data", "model_index"]; + let arg_types = &[self.real_ptr_type, self.real_ptr_type]; + let arg_names = &["inputs", "data"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); - let model_index_ptr = codegen - .builder - .ins() - .global_value(codegen.int_ptr_type, codegen.model_index_global); - let model_index = codegen - .builder - .use_var(*codegen.variables.get("model_index").unwrap()); - codegen - .builder - .ins() - .store(codegen.mem_flags, model_index, model_index_ptr, 0); - let base_data_ptr = codegen.variables.get("data").unwrap(); let base_data_ptr = codegen.builder.use_var(*base_data_ptr); codegen.jit_compile_inputs(model, base_data_ptr, false, false); @@ -1169,7 +1228,6 @@ struct CraneliftCodeGen<'a, M: Module> { layout: &'a DataLayout, indices: GlobalValue, constants: GlobalValue, - model_index_global: GlobalValue, threaded: bool, } @@ -2469,12 +2527,6 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { .unwrap() .declare_data_in_func(module.constants_id, builder.func); - let model_index_global = module - .module - .lock() - .unwrap() - .declare_data_in_func(module.model_index_id, builder.func); - // Create the entry block, to start emitting code in. let entry_block = builder.create_block(); @@ -2507,7 +2559,6 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { functions: HashMap::new(), layout: &module.layout, threaded: module.threaded, - model_index_global, }; // insert arg vars @@ -2516,19 +2567,6 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { codegen.declare_variable(*arg_type, arg_name, val); } - if !codegen.variables.contains_key("model_index") { - let model_index_ptr = codegen - .builder - .ins() - .global_value(codegen.int_ptr_type, codegen.model_index_global); - let model_index = - codegen - .builder - .ins() - .load(codegen.int_type, codegen.mem_flags, model_index_ptr, 0); - codegen.declare_variable(codegen.int_type, "model_index", model_index); - } - // insert u if it exists in args if let Some(u) = codegen.variables.get("u") { let u_ptr = codegen.builder.use_var(*u); @@ -2579,6 +2617,7 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { // insert all tensors in data if it exists in args let tensors = model.input().into_iter(); let tensors = tensors.chain(model.input_dep_defns().iter()); + let tensors = tensors.chain(model.model_dep_defns().iter()); let tensors = tensors.chain(model.time_dep_defns().iter()); let tensors = tensors.chain(model.state_dep_defns().iter()); let tensors = tensors.chain(model.state_dep_post_f_defns().iter()); diff --git a/diffsl/src/execution/data_layout.rs b/diffsl/src/execution/data_layout.rs index be470c4..8ddb1be 100644 --- a/diffsl/src/execution/data_layout.rs +++ b/diffsl/src/execution/data_layout.rs @@ -169,6 +169,11 @@ impl DataLayout { .iter() .for_each(|i| add_tensor(i, true, false)); + model + .model_dep_defns() + .iter() + .for_each(|i| add_tensor(i, true, false)); + model .time_dep_defns() .iter() diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index 30d0c6c..ca1ae44 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -21,6 +21,7 @@ macro_rules! define_symbol_module { pub fn set_u0( u: *mut $ty, data: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -30,6 +31,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, reset: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -42,6 +44,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -54,6 +57,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -65,6 +69,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -76,6 +81,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -85,6 +91,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, rr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -97,6 +104,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -109,6 +117,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -120,6 +129,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -131,6 +141,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -140,6 +151,7 @@ macro_rules! define_symbol_module { v: *const $ty, data: *mut $ty, mv: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -152,6 +164,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, mv: *const $ty, dmv: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -161,6 +174,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -170,6 +184,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -179,6 +194,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -188,6 +204,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, out: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -200,6 +217,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -212,6 +230,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -223,6 +242,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -234,6 +254,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -243,6 +264,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, root: *mut $ty, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -259,7 +281,7 @@ macro_rules! define_symbol_module { has_reset: *mut UIntType, ); #[link_name = "set_inputs"] - pub fn set_inputs(inputs: *const $ty, data: *mut $ty, model_index: UIntType); + pub fn set_inputs(inputs: *const $ty, data: *mut $ty); #[link_name = "get_inputs"] pub fn get_inputs(inputs: *mut $ty, data: *const $ty); #[link_name = "set_inputs_grad"] @@ -268,7 +290,6 @@ macro_rules! define_symbol_module { dinputs: *const $ty, data: *const $ty, ddata: *mut $ty, - model_index: UIntType, ); #[link_name = "set_inputs_rgrad"] pub fn set_inputs_rgrad( @@ -276,7 +297,6 @@ macro_rules! define_symbol_module { dinputs: *mut $ty, data: *const $ty, ddata: *mut $ty, - model_index: UIntType, ); } } diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index 25dbfb5..ce0a4d9 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -12,6 +12,7 @@ pub type StopFunc = unsafe extern "C" fn( u: *const T, data: *mut T, root: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -20,6 +21,7 @@ pub type ResetFunc = unsafe extern "C" fn( u: *const T, data: *mut T, reset: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -31,6 +33,7 @@ pub type ResetGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -42,6 +45,7 @@ pub type ResetRevGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -52,6 +56,7 @@ pub type ResetSensGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -62,6 +67,7 @@ pub type ResetSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -70,6 +76,7 @@ pub type RhsFunc = unsafe extern "C" fn( u: *const T, data: *mut T, rr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -81,6 +88,7 @@ pub type RhsGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -92,6 +100,7 @@ pub type RhsRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -102,6 +111,7 @@ pub type RhsSensGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -112,6 +122,7 @@ pub type RhsSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -120,6 +131,7 @@ pub type MassFunc = unsafe extern "C" fn( v: *const T, data: *mut T, mv: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -131,16 +143,23 @@ pub type MassRevGradFunc = unsafe extern "C" fn( ddata: *mut T, mv: *const T, dmv: *mut T, + model_index: UIntType, + thread_id: UIntType, + thread_dim: UIntType, +); +pub type U0Func = unsafe extern "C" fn( + u: *mut T, + data: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); -pub type U0Func = - unsafe extern "C" fn(u: *mut T, data: *mut T, thread_id: UIntType, thread_dim: UIntType); pub type U0SensGradFunc = unsafe extern "C" fn( u: *const T, du: *mut T, data: *const T, ddata: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -149,6 +168,7 @@ pub type U0GradFunc = unsafe extern "C" fn( du: *mut T, data: *const T, ddata: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -157,6 +177,7 @@ pub type U0RevGradFunc = unsafe extern "C" fn( du: *mut T, data: *const T, ddata: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -165,6 +186,7 @@ pub type CalcOutFunc = unsafe extern "C" fn( u: *const T, data: *mut T, out: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -176,6 +198,7 @@ pub type CalcOutGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -187,6 +210,7 @@ pub type CalcOutRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -197,6 +221,7 @@ pub type CalcOutSensGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -207,6 +232,7 @@ pub type CalcOutSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, + model_index: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -219,22 +245,19 @@ pub type GetDimsFunc = unsafe extern "C" fn( has_mass: *mut UIntType, has_reset: *mut UIntType, ); -pub type SetInputsFunc = - unsafe extern "C" fn(inputs: *const T, data: *mut T, model_index: UIntType); +pub type SetInputsFunc = unsafe extern "C" fn(inputs: *const T, data: *mut T); pub type GetInputsFunc = unsafe extern "C" fn(inputs: *mut T, data: *const T); pub type SetInputsGradFunc = unsafe extern "C" fn( inputs: *const T, dinputs: *const T, data: *const T, ddata: *mut T, - model_index: UIntType, ); pub type SetInputsRevGradFunc = unsafe extern "C" fn( inputs: *const T, dinputs: *mut T, data: *const T, ddata: *mut T, - model_index: UIntType, ); pub type SetIdFunc = unsafe extern "C" fn(id: *mut T); pub type GetTensorFunc = diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 903bdc1..1d0a2fc 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -1413,6 +1413,9 @@ impl<'ctx> CodeGen<'ctx> { } fn insert_model_index(&mut self) { + if self.variables.contains_key("model_index") { + return; + } self.insert_param("model_index", self.globals.model_index.as_pointer_value()); } From 79df6421a339713016701a0f8e7308557477860a Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 15 Mar 2026 20:40:45 +0000 Subject: [PATCH 2/4] finish model_index time dep --- diffsl/src/execution/compiler.rs | 221 +++++++++++++------ diffsl/src/execution/llvm/codegen.rs | 198 ++++++++++++----- diffsl/tests/pybamm_dfn.rs | 8 +- diffsl/tests/support/external_test_macros.rs | 94 ++++++-- 4 files changed, 369 insertions(+), 152 deletions(-) diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index 776b2e0..64376ab 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -359,7 +359,14 @@ impl Compiler { }); } - pub fn set_u0_sgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { + pub fn set_u0_sgrad( + &self, + yy: &[T], + dyy: &mut [T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -381,7 +388,14 @@ impl Compiler { }); } - pub fn set_u0_rgrad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { + pub fn set_u0_rgrad( + &self, + yy: &[T], + dyy: &mut [T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -403,7 +417,14 @@ impl Compiler { }); } - pub fn set_u0_grad(&self, yy: &[T], dyy: &mut [T], data: &[T], ddata: &mut [T], model_index: u32) { + pub fn set_u0_grad( + &self, + yy: &[T], + dyy: &mut [T], + data: &[T], + ddata: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); self.check_data_len(data, "data"); @@ -745,7 +766,15 @@ impl Compiler { }); } - pub fn mass_rgrad(&self, t: T, dv: &mut [T], data: &[T], ddata: &mut [T], dmv: &mut [T], model_index: u32) { + pub fn mass_rgrad( + &self, + t: T, + dv: &mut [T], + data: &[T], + ddata: &mut [T], + dmv: &mut [T], + model_index: u32, + ) { self.check_state_len(dv, "dv"); self.check_state_len(dmv, "dmv"); self.check_data_len(data, "data"); @@ -770,7 +799,16 @@ impl Compiler { }); } - pub fn rhs_sgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T], model_index: u32) { + pub fn rhs_sgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + rr: &[T], + drr: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -795,7 +833,16 @@ impl Compiler { }); } - pub fn rhs_srgrad(&self, t: T, yy: &[T], data: &[T], ddata: &mut [T], rr: &[T], drr: &mut [T], model_index: u32) { + pub fn rhs_srgrad( + &self, + t: T, + yy: &[T], + data: &[T], + ddata: &mut [T], + rr: &[T], + drr: &mut [T], + model_index: u32, + ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_state_len(drr, "drr"); @@ -1024,13 +1071,7 @@ impl Compiler { unsafe { (self.jit_functions.get_inputs)(inputs.as_mut_ptr(), data.as_ptr()) }; } - pub fn set_inputs_grad( - &self, - inputs: &[T], - dinputs: &[T], - data: &[T], - ddata: &mut [T], - ) { + pub fn set_inputs_grad(&self, inputs: &[T], dinputs: &[T], data: &[T], ddata: &mut [T]) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -1045,13 +1086,7 @@ impl Compiler { }; } - pub fn set_inputs_rgrad( - &self, - inputs: &[T], - dinputs: &mut [T], - data: &[T], - ddata: &mut [T], - ) { + pub fn set_inputs_rgrad(&self, inputs: &[T], dinputs: &mut [T], data: &[T], ddata: &mut [T]) { self.check_inputs_len(inputs, "inputs"); self.check_inputs_len(dinputs, "dinputs"); self.check_data_len(data, "data"); @@ -1175,9 +1210,9 @@ mod tests { assert_relative_eq!(a2[0], T::zero()); // set the inputs and u0 let inputs = vec![T::one()]; - compiler.set_inputs(&inputs, data.as_mut_slice(), 0); + compiler.set_inputs(&inputs, data.as_mut_slice()); let mut u0 = vec![T::zero()]; - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); // now a and a2 should be set let a = compiler.get_tensor_data("in", &data).unwrap(); let a2 = compiler.get_tensor_data("a2", &data).unwrap(); @@ -1217,13 +1252,14 @@ mod tests { let mut u0 = vec![T::zero()]; let mut res = vec![T::zero()]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); assert_relative_eq!(u0.as_slice(), vec![T::one()].as_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); } @@ -1268,18 +1304,20 @@ mod tests { let mut res = vec![T::zero()]; let mut stop = vec![T::zero()]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); assert_eq!(stop.len(), 1); @@ -1332,12 +1370,13 @@ mod tests { let mut u0 = vec![T::zero(), T::zero()]; let mut reset = vec![T::zero(), T::zero()]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.reset( T::zero(), u0.as_slice(), data.as_mut_slice(), reset.as_mut_slice(), + 0, ); assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); @@ -1381,10 +1420,10 @@ mod tests { let mut data = compiler.get_new_data(); let inputs = vec![T::from_f64(3.0).unwrap()]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); let mut yy = vec![T::zero(), T::zero()]; - compiler.set_u0(yy.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(yy.as_mut_slice(), data.as_mut_slice(), 0); let mut reset = vec![T::zero(), T::zero()]; compiler.reset( @@ -1392,6 +1431,7 @@ mod tests { yy.as_slice(), data.as_mut_slice(), reset.as_mut_slice(), + 0, ); assert_relative_eq!(reset[0], T::from_f64(9.0).unwrap()); assert_relative_eq!(reset[1], T::from_f64(5.0).unwrap()); @@ -1403,7 +1443,6 @@ mod tests { dinputs.as_slice(), data.as_slice(), ddata.as_mut_slice(), - 0, ); let dyy = vec![T::one(), T::zero()]; let mut dreset = vec![T::zero(), T::zero()]; @@ -1415,6 +1454,7 @@ mod tests { ddata.as_mut_slice(), reset.as_slice(), dreset.as_mut_slice(), + 0, ); assert_relative_eq!(dreset[0], T::from_f64(3.0).unwrap()); assert_relative_eq!(dreset[1], T::from_f64(1.0).unwrap()); @@ -1431,6 +1471,7 @@ mod tests { ddata_rev.as_mut_slice(), reset.as_slice(), dreset_rev.as_mut_slice(), + 0, ); assert_relative_eq!(dyy_rev[0], T::from_f64(2.0).unwrap()); assert_relative_eq!(dyy_rev[1], T::one()); @@ -1441,13 +1482,12 @@ mod tests { dinputs_rev.as_mut_slice(), data.as_slice(), ddata_rev.as_mut_slice(), - 0, ); assert_relative_eq!(dinputs_rev[0], T::from_f64(2.0).unwrap()); let mut ddata_s = compiler.get_new_data(); let dinputs_s = vec![T::one(); inputs.len()]; - compiler.set_inputs(dinputs_s.as_slice(), ddata_s.as_mut_slice(), 0); + compiler.set_inputs(dinputs_s.as_slice(), ddata_s.as_mut_slice()); let mut dreset_s = vec![T::zero(), T::zero()]; compiler.reset_sgrad( T::zero(), @@ -1456,6 +1496,7 @@ mod tests { ddata_s.as_mut_slice(), reset.as_slice(), dreset_s.as_mut_slice(), + 0, ); assert_relative_eq!(dreset_s[0], T::one()); assert_relative_eq!(dreset_s[1], T::one()); @@ -1469,6 +1510,7 @@ mod tests { ddata_sr.as_mut_slice(), reset.as_slice(), dreset_sr.as_mut_slice(), + 0, ); let mut dinputs_sr = vec![T::zero(); inputs.len()]; @@ -1477,7 +1519,6 @@ mod tests { dinputs_sr.as_mut_slice(), data.as_slice(), ddata_sr.as_mut_slice(), - 0, ); assert_relative_eq!(dinputs_sr[0], T::from_f64(2.0).unwrap()); } @@ -1511,13 +1552,14 @@ mod tests { let mut u0 = vec![T::zero()]; let mut reset: Vec = vec![]; let mut data = compiler.get_new_data(); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.reset( T::zero(), u0.as_slice(), data.as_mut_slice(), reset.as_mut_slice(), + 0, ); assert_eq!(reset.len(), 0); } @@ -1566,33 +1608,37 @@ mod tests { let mut stop1 = vec![T::zero(); 1]; let mut data = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice(), 0); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_inputs(&[], data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), + 0, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop0.as_mut_slice(), + 0, ); - compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.set_inputs(&[], data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), + 1, ); compiler.calc_stop( T::zero(), u0.as_slice(), data.as_mut_slice(), stop1.as_mut_slice(), + 1, ); assert_relative_eq!(u0[0], T::from_f64(1.0).unwrap()); @@ -1638,20 +1684,22 @@ mod tests { let mut ddata0 = compiler.get_new_data(); let mut ddata1 = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice(), 0); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_inputs(&[], data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), + 0, ); - compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.set_inputs(&[], data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), + 1, ); assert_relative_eq!(rr0[0], T::from_f64(3.0).unwrap()); @@ -1659,7 +1707,7 @@ mod tests { let dyy0 = vec![T::one(), T::zero()]; let dyy1 = vec![T::zero(), T::one()]; - compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_inputs(&[], data.as_mut_slice()); compiler.rhs_grad( T::zero(), u0.as_slice(), @@ -1668,8 +1716,9 @@ mod tests { ddata0.as_mut_slice(), rr0.as_slice(), drr0.as_mut_slice(), + 0, ); - compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.set_inputs(&[], data.as_mut_slice()); compiler.rhs_grad( T::zero(), u0.as_slice(), @@ -1678,6 +1727,7 @@ mod tests { ddata1.as_mut_slice(), rr1.as_slice(), drr1.as_mut_slice(), + 1, ); assert_relative_eq!(drr0[0], T::one()); @@ -1711,20 +1761,22 @@ mod tests { let mut rr1 = vec![T::zero(); 2]; let mut data = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice(), 0); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_inputs(&[], data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr0.as_mut_slice(), + 0, ); - compiler.set_inputs(&[], data.as_mut_slice(), 1); + compiler.set_inputs(&[], data.as_mut_slice()); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), rr1.as_mut_slice(), + 1, ); assert_relative_eq!(rr0[0], T::from_f64(-1.0).unwrap()); @@ -1756,13 +1808,14 @@ mod tests { let mut u0 = vec![T::one()]; let mut data = compiler.get_new_data(); // need this to set the constants - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); let mut out = vec![T::zero()]; compiler.calc_out( T::zero(), u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); assert_relative_eq!(out[0], T::from_f64(2.).unwrap()); u0[0] = T::from_f64(2.).unwrap(); @@ -1771,6 +1824,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); assert_relative_eq!(out[0], T::from_f64(4.).unwrap()); let mut stop = vec![T::zero()]; @@ -1779,6 +1833,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(3.5).unwrap()); u0[0] = T::from_f64(0.5).unwrap(); @@ -1787,6 +1842,7 @@ mod tests { u0.as_slice(), data.as_mut_slice(), stop.as_mut_slice(), + 0, ); assert_relative_eq!(stop[0], T::from_f64(0.5).unwrap()); } @@ -1863,19 +1919,21 @@ mod tests { let mut results = Vec::new(); let inputs = vec![T::one(); n_inputs]; let mut out = vec![T::zero(); n_outputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); compiler.calc_out( T::zero(), u0.as_slice(), data.as_mut_slice(), out.as_mut_slice(), + 0, ); let (tensor_len, tensor_is_constant) = if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, data.as_slice()) { @@ -1905,13 +1963,13 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), - 0, ); compiler.set_u0_grad( u0.as_mut_slice(), du0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.rhs_grad( T::zero(), @@ -1921,6 +1979,7 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), + 0, ); compiler.calc_out_grad( T::zero(), @@ -1930,6 +1989,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); if let Some(tensor_data) = compiler.get_tensor_data(tensor_name, ddata.as_slice()) { results.push(tensor_data.to_vec()); @@ -1958,6 +2018,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); compiler.rhs_rgrad( T::zero(), @@ -1967,12 +2028,14 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); compiler.set_u0_rgrad( u0.as_mut_slice(), du0.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); compiler.get_inputs(dinputs.as_mut_slice(), ddata.as_slice()); results.push(dinputs.to_vec()); @@ -1981,7 +2044,7 @@ mod tests { let mut ddata = compiler.get_new_data(); let mut dres = vec![T::zero(); n_states]; let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); compiler.rhs_sgrad( T::zero(), u0.as_slice(), @@ -1989,6 +2052,7 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); results.push( compiler @@ -2000,7 +2064,7 @@ mod tests { // forward mode sens (calc_out) let mut ddata = compiler.get_new_data(); let dinputs = vec![T::one(); n_inputs]; - compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice(), 0); + compiler.set_inputs(dinputs.as_slice(), ddata.as_mut_slice()); compiler.calc_out_sgrad( T::zero(), u0.as_slice(), @@ -2008,6 +2072,7 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); results.push( compiler @@ -2031,13 +2096,13 @@ mod tests { ddata.as_mut_slice(), res.as_slice(), dres.as_mut_slice(), + 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), - 0, ); results.push(dinputs.to_vec()); @@ -2055,13 +2120,13 @@ mod tests { ddata.as_mut_slice(), out.as_slice(), dout.as_mut_slice(), + 0, ); compiler.set_inputs_rgrad( inputs.as_slice(), dinputs.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), - 0, ); results.push(dinputs.to_vec()); } else { @@ -2584,13 +2649,14 @@ mod tests { let (_n_states, n_inputs, _n_outputs, _n_data, _n_stop, _has_mass, _has_reset) = compiler.get_dims(); let inputs = vec![T::from_f64(2.).unwrap(); n_inputs]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); for _i in 0..3 { @@ -2600,13 +2666,13 @@ mod tests { dinputs.as_slice(), data.as_mut_slice(), ddata.as_mut_slice(), - 0, ); compiler.set_u0_grad( u0.as_mut_slice(), du0.as_mut_slice(), data.as_mut_slice(), ddata.as_mut_slice(), + 0, ); compiler.rhs_grad( T::zero(), @@ -2616,6 +2682,7 @@ mod tests { ddata.as_mut_slice(), res.as_mut_slice(), dres.as_mut_slice(), + 0, ); assert_relative_eq!(dres.as_slice(), vec![T::from_f64(8.).unwrap()].as_slice()); } @@ -2677,10 +2744,22 @@ mod tests { let mut data = compiler.get_new_data(); let (_n_states, _n_inputs, _n_outputs, _n_data, _n_stop, _has_mass, _has_reset) = compiler.get_dims(); - compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); - compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.set_u0(u.as_mut_slice(), data.as_mut_slice(), 0); + compiler.rhs( + 0.0, + u.as_slice(), + data.as_mut_slice(), + res.as_mut_slice(), + 0, + ); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); - compiler.rhs(0.0, u.as_slice(), data.as_mut_slice(), res.as_mut_slice()); + compiler.rhs( + 0.0, + u.as_slice(), + data.as_mut_slice(), + res.as_mut_slice(), + 0, + ); assert_relative_eq!(res.as_slice(), vec![3.0, 0.0, 0.0, 3.0].as_slice()); } @@ -2731,7 +2810,7 @@ mod tests { let mut data = compiler.get_new_data(); let inputs = vec![1.1]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.1].as_slice()); @@ -2741,20 +2820,20 @@ mod tests { assert_eq!(id, vec![1.0, 0.0]); let mut u = vec![0., 0.]; - compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u.as_mut_slice(), data.as_mut_slice(), 0); assert_relative_eq!(u.as_slice(), vec![1., 2.].as_slice()); let mut rr = vec![1., 1.]; - compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); + compiler.rhs(0., u.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); assert_relative_eq!(rr.as_slice(), vec![0., 0.].as_slice()); let up = vec![2., 3.]; rr = vec![1., 1.]; - compiler.mass(0., up.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); + compiler.mass(0., up.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); assert_relative_eq!(rr.as_slice(), vec![2., 0.].as_slice()); let mut out = vec![0.; 3]; - compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice()); + compiler.calc_out(0., u.as_slice(), data.as_mut_slice(), out.as_mut_slice(), 0); assert_relative_eq!(out.as_slice(), vec![1., 2., 4.].as_slice()); } } @@ -2782,7 +2861,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2798,7 +2877,7 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let inputs = vec![1.0, 2.0, 3.0, 4.0]; - compiler.set_inputs(inputs.as_slice(), data.as_mut_slice(), 0); + compiler.set_inputs(inputs.as_slice(), data.as_mut_slice()); let inputs = compiler.get_tensor_data("in", data.as_slice()).unwrap(); assert_relative_eq!(inputs, vec![1.0, 2.0, 3.0, 4.0].as_slice()); @@ -2825,10 +2904,10 @@ mod tests { .unwrap(); let mut data = compiler.get_new_data(); let mut u0 = vec![0.0, 0.0, 0.0]; - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); let mut mv = vec![0.0, 0.0, 0.0]; let mut v = vec![1.0, 1.0, 1.0]; - compiler.mass(0.0, v.as_slice(), data.as_mut_slice(), mv.as_mut_slice()); + compiler.mass(0.0, v.as_slice(), data.as_mut_slice(), mv.as_mut_slice(), 0); assert_relative_eq!(mv.as_slice(), vec![2.0, 1.0, 1.0].as_slice()); mv = vec![1.0, 1.0, 1.0]; let mut ddata = compiler.get_new_data(); @@ -2838,6 +2917,7 @@ mod tests { data.as_mut_slice(), ddata.as_mut_slice(), mv.as_mut_slice(), + 0, ); assert_relative_eq!(v.as_slice(), vec![2.0, 3.0, 2.0].as_slice()); } @@ -2866,13 +2946,14 @@ mod tests { let handle = thread::spawn(move || { let mut data = compiler_clone.get_new_data(); let mut u0 = vec![T::zero()]; - compiler_clone.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler_clone.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); let mut res = vec![T::zero()]; compiler_clone.rhs( T::zero(), u0.as_slice(), data.as_mut_slice(), res.as_mut_slice(), + 0, ); assert_relative_eq!(res.as_slice(), vec![-T::one()].as_slice()); }); @@ -2906,22 +2987,22 @@ mod tests { let mut ddata = compiler.get_new_data(); let a = vec![T::from_f64(0.6).unwrap()]; let da = vec![T::one()]; - compiler.set_inputs(a.as_slice(), data.as_mut_slice(), 0); + compiler.set_inputs(a.as_slice(), data.as_mut_slice()); compiler.set_inputs_grad( a.as_slice(), da.as_slice(), data.as_slice(), ddata.as_mut_slice(), - 0, ); let mut u0 = vec![T::zero()]; let mut du0 = vec![T::zero()]; - compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u0.as_mut_slice(), data.as_mut_slice(), 0); compiler.set_u0_sgrad( u0.as_mut_slice(), du0.as_mut_slice(), data.as_slice(), ddata.as_mut_slice(), + 0, ); assert_relative_eq!( u0.as_slice(), diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 1d0a2fc..474e38a 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -328,6 +328,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "set_u0_grad", @@ -342,6 +343,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "rhs_grad", @@ -356,6 +358,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "reset_grad", @@ -370,6 +373,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Forward, "calc_out_grad", @@ -379,7 +383,6 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, - CompileGradientArgType::Const, ], CompileMode::Forward, "set_inputs_grad", @@ -392,6 +395,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "set_u0_rgrad", @@ -406,6 +410,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "mass_rgrad", @@ -420,6 +425,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "rhs_rgrad", @@ -433,6 +439,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "reset_rgrad", @@ -446,6 +453,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::Reverse, "calc_out_rgrad", @@ -456,7 +464,6 @@ impl CodegenModuleCompile for LlvmModule { &[ CompileGradientArgType::DupNoNeed, CompileGradientArgType::DupNoNeed, - CompileGradientArgType::Const, ], CompileMode::Reverse, "set_inputs_rgrad", @@ -471,6 +478,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "rhs_sgrad", @@ -485,6 +493,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "reset_sgrad", @@ -497,6 +506,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "set_u0_sgrad", @@ -511,6 +521,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ForwardSens, "calc_out_sgrad", @@ -524,6 +535,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ReverseSens, "calc_out_srgrad", @@ -538,6 +550,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ReverseSens, "rhs_srgrad", @@ -552,6 +565,7 @@ impl CodegenModuleCompile for LlvmModule { CompileGradientArgType::DupNoNeed, CompileGradientArgType::Const, CompileGradientArgType::Const, + CompileGradientArgType::Const, ], CompileMode::ReverseSens, "reset_srgrad", @@ -605,7 +619,6 @@ struct Globals<'ctx> { indices: Option>, constants: Option>, thread_counter: Option>, - model_index: GlobalValue<'ctx>, } impl<'ctx> Globals<'ctx> { @@ -671,19 +684,10 @@ impl<'ctx> Globals<'ctx> { indices.set_initializer(&indices_value); Some(indices) }; - let model_index = module.add_global( - int_type, - Some(AddressSpace::default()), - "enzyme_const_model_index", - ); - model_index.set_visibility(GlobalVisibility::Hidden); - model_index.set_constant(false); - model_index.set_initializer(&int_type.const_zero()); Self { indices, thread_counter, constants, - model_index, } } } @@ -1371,7 +1375,6 @@ impl<'ctx> CodeGen<'ctx> { } fn insert_data(&mut self, model: &DiscreteModel) { - self.insert_model_index(); self.insert_constants(model); if let Some(input) = model.input() { @@ -1380,6 +1383,9 @@ impl<'ctx> CodeGen<'ctx> { for tensor in model.input_dep_defns() { self.insert_tensor(tensor, false); } + for tensor in model.model_dep_defns() { + self.insert_tensor(tensor, false); + } for tensor in model.time_dep_defns() { self.insert_tensor(tensor, false); } @@ -1412,13 +1418,6 @@ impl<'ctx> CodeGen<'ctx> { } } - fn insert_model_index(&mut self) { - if self.variables.contains_key("model_index") { - return; - } - self.insert_param("model_index", self.globals.model_index.as_pointer_value()); - } - fn insert_param(&mut self, name: &str, value: PointerValue<'ctx>) { self.variables.insert(name.to_owned(), value); } @@ -2964,6 +2963,9 @@ impl<'ctx> CodeGen<'ctx> { .into_float_value(); let u = *self.get_param("u"); let data = *self.get_param("data"); + let model_index = self + .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + .into_int_value(); let thread_id = self .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")? .into_int_value(); @@ -2978,6 +2980,7 @@ impl<'ctx> CodeGen<'ctx> { t.into(), u.into(), data.into(), + model_index.into(), thread_id.into(), thread_dim.into(), barrier_start.into(), @@ -2999,6 +3002,17 @@ impl<'ctx> CodeGen<'ctx> { self.compile_dep_defns(model, "calc_time_dep", model.time_dep_defns(), code) } + fn ensure_model_dep_fn<'m>( + &mut self, + model: &'m DiscreteModel, + code: Option<&str>, + ) -> Result> { + if let Some(function) = self.module.get_function("calc_model_dep") { + return Ok(function); + } + self.compile_dep_defns(model, "calc_model_dep", model.model_dep_defns(), code) + } + fn ensure_state_dep_fn<'m>( &mut self, model: &'m DiscreteModel, @@ -3055,7 +3069,7 @@ impl<'ctx> CodeGen<'ctx> { code: Option<&str>, ) -> Result> { self.clear(); - let fn_arg_names = &["u0", "data", "thread_id", "thread_dim"]; + let fn_arg_names = &["u0", "data", "model_index", "thread_id", "thread_dim"]; let function = self.add_function( "set_u0", fn_arg_names, @@ -3064,6 +3078,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3088,7 +3103,8 @@ impl<'ctx> CodeGen<'ctx> { self.insert_indices(); let mut nbarriers = 0; - let total_barriers = (model.input_dep_defns().len() + 1) as u64; + let total_barriers = + (model.input_dep_defns().len() + model.model_dep_defns().len() + 1) as u64; let total_barriers_val = self.int_type.const_int(total_barriers, false); #[allow(clippy::explicit_counter_loop)] for a in model.input_dep_defns() { @@ -3098,6 +3114,13 @@ impl<'ctx> CodeGen<'ctx> { nbarriers += 1; } + for a in model.model_dep_defns() { + self.jit_compile_tensor(a, Some(*self.get_var(a)), code)?; + let barrier_num = self.int_type.const_int(nbarriers + 1, false); + self.jit_compile_call_barrier(barrier_num, total_barriers_val); + nbarriers += 1; + } + self.jit_compile_tensor(model.state(), Some(*self.get_param("u0")), code)?; let barrier_num = self.int_type.const_int(nbarriers + 1, false); self.jit_compile_call_barrier(barrier_num, total_barriers_val); @@ -3121,11 +3144,20 @@ impl<'ctx> CodeGen<'ctx> { include_constants: bool, code: Option<&str>, ) -> Result> { + let model_dep_fn = self.ensure_model_dep_fn(model, code)?; let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "out", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "out", + "model_index", + "thread_id", + "thread_dim", + ]; let function_name = if include_constants { "calc_out_full" } else { @@ -3141,6 +3173,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3172,7 +3205,8 @@ impl<'ctx> CodeGen<'ctx> { //self.compile_print_value("thread_dim", PrintValue::Int(thread_dim.into_int_value()))?; if let Some(out) = model.out() { let mut nbarriers = 0; - let mut total_barriers = (model.time_dep_defns().len() + let mut total_barriers = (model.model_dep_defns().len() + + model.time_dep_defns().len() + model.state_dep_defns().len() + model.state_dep_post_f_defns().len() + 1) as u64; @@ -3190,6 +3224,11 @@ impl<'ctx> CodeGen<'ctx> { } } + if !model.model_dep_defns().is_empty() { + self.build_dep_call(model_dep_fn, "model_dep", nbarriers, total_barriers)?; + nbarriers += model.model_dep_defns().len() as u64; + } + // calculate time dependant definitions if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; @@ -3235,6 +3274,7 @@ impl<'ctx> CodeGen<'ctx> { "t", "u", "data", + "model_index", "thread_id", "thread_dim", "barrier_start", @@ -3251,6 +3291,7 @@ impl<'ctx> CodeGen<'ctx> { self.int_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3321,11 +3362,20 @@ impl<'ctx> CodeGen<'ctx> { model: &'m DiscreteModel, code: Option<&str>, ) -> Result> { + let model_dep_fn = self.ensure_model_dep_fn(model, code)?; let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "root", + "model_index", + "thread_id", + "thread_dim", + ]; let function = self.add_function( "calc_stop", fn_arg_names, @@ -3336,6 +3386,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3363,11 +3414,16 @@ impl<'ctx> CodeGen<'ctx> { if let Some(stop) = model.stop() { // calculate time dependant definitions let mut nbarriers = 0; - let total_barriers = (model.time_dep_defns().len() + let total_barriers = (model.model_dep_defns().len() + + model.time_dep_defns().len() + model.state_dep_defns().len() + model.state_dep_post_f_defns().len() + 1) as u64; let total_barriers_val = self.int_type.const_int(total_barriers, false); + if !model.model_dep_defns().is_empty() { + self.build_dep_call(model_dep_fn, "model_dep", nbarriers, total_barriers)?; + nbarriers += model.model_dep_defns().len() as u64; + } if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; nbarriers += model.time_dep_defns().len() as u64; @@ -3406,11 +3462,20 @@ impl<'ctx> CodeGen<'ctx> { model: &'m DiscreteModel, code: Option<&str>, ) -> Result> { + let model_dep_fn = self.ensure_model_dep_fn(model, code)?; let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "reset", + "model_index", + "thread_id", + "thread_dim", + ]; let function = self.add_function( "reset", fn_arg_names, @@ -3421,6 +3486,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3446,11 +3512,16 @@ impl<'ctx> CodeGen<'ctx> { if let Some(reset) = model.reset() { let mut nbarriers = 0; - let total_barriers = (model.time_dep_defns().len() + let total_barriers = (model.model_dep_defns().len() + + model.time_dep_defns().len() + model.state_dep_defns().len() + model.state_dep_post_f_defns().len() + 1) as u64; let total_barriers_val = self.int_type.const_int(total_barriers, false); + if !model.model_dep_defns().is_empty() { + self.build_dep_call(model_dep_fn, "model_dep", nbarriers, total_barriers)?; + nbarriers += model.model_dep_defns().len() as u64; + } if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; nbarriers += model.time_dep_defns().len() as u64; @@ -3489,10 +3560,19 @@ impl<'ctx> CodeGen<'ctx> { include_constants: bool, code: Option<&str>, ) -> Result> { + let model_dep_fn = self.ensure_model_dep_fn(model, code)?; let time_dep_fn = self.ensure_time_dep_fn(model, code)?; let state_dep_fn = self.ensure_state_dep_fn(model, code)?; self.clear(); - let fn_arg_names = &["t", "u", "data", "rr", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "u", + "data", + "rr", + "model_index", + "thread_id", + "thread_dim", + ]; let function_name = if include_constants { "rhs_full" } else { "rhs" }; let function = self.add_function( function_name, @@ -3504,6 +3584,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3529,8 +3610,10 @@ impl<'ctx> CodeGen<'ctx> { self.insert_indices(); let mut nbarriers = 0; - let mut total_barriers = - (model.time_dep_defns().len() + model.state_dep_defns().len() + 1) as u64; + let mut total_barriers = (model.model_dep_defns().len() + + model.time_dep_defns().len() + + model.state_dep_defns().len() + + 1) as u64; if include_constants { total_barriers += model.input_dep_defns().len() as u64; // calculate constant definitions @@ -3543,6 +3626,11 @@ impl<'ctx> CodeGen<'ctx> { } } + if !model.model_dep_defns().is_empty() { + self.build_dep_call(model_dep_fn, "model_dep", nbarriers, total_barriers)?; + nbarriers += model.model_dep_defns().len() as u64; + } + // calculate time dependant definitions if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; @@ -3580,7 +3668,15 @@ impl<'ctx> CodeGen<'ctx> { code: Option<&str>, ) -> Result> { self.clear(); - let fn_arg_names = &["t", "dudt", "data", "rr", "thread_id", "thread_dim"]; + let fn_arg_names = &[ + "t", + "dudt", + "data", + "rr", + "model_index", + "thread_id", + "thread_dim", + ]; let function = self.add_function( "mass", fn_arg_names, @@ -3591,6 +3687,7 @@ impl<'ctx> CodeGen<'ctx> { self.real_ptr_type.into(), self.int_type.into(), self.int_type.into(), + self.int_type.into(), ], None, false, @@ -3622,9 +3719,17 @@ impl<'ctx> CodeGen<'ctx> { // calculate time dependant definitions let mut nbarriers = 0; - let total_barriers = - (model.time_dep_defns().len() + model.dstate_dep_defns().len() + 1) as u64; + let total_barriers = (model.model_dep_defns().len() + + model.time_dep_defns().len() + + model.dstate_dep_defns().len() + + 1) as u64; let total_barriers_val = self.int_type.const_int(total_barriers, false); + for tensor in model.model_dep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?; + let barrier_num = self.int_type.const_int(nbarriers + 1, false); + self.jit_compile_call_barrier(barrier_num, total_barriers_val); + nbarriers += 1; + } for tensor in model.time_dep_defns() { self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?; let barrier_num = self.int_type.const_int(nbarriers + 1, false); @@ -4127,20 +4232,9 @@ impl<'ctx> CodeGen<'ctx> { ) -> Result> { self.clear(); let function_name = if is_get { "get_inputs" } else { "set_inputs" }; - let fn_arg_names: &[&str] = if is_get { - &["inputs", "data"] - } else { - &["inputs", "data", "model_index"] - }; - let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = if is_get { - &[self.real_ptr_type.into(), self.real_ptr_type.into()] - } else { - &[ - self.real_ptr_type.into(), - self.real_ptr_type.into(), - self.int_type.into(), - ] - }; + let fn_arg_names: &[&str] = &["inputs", "data"]; + let fn_arg_types: &[BasicMetadataTypeEnum<'ctx>] = + &[self.real_ptr_type.into(), self.real_ptr_type.into()]; let function = self.add_function(function_name, fn_arg_names, fn_arg_types, None, false); let block = self.start_function(function, None); @@ -4150,14 +4244,6 @@ impl<'ctx> CodeGen<'ctx> { self.insert_param(name, alloca); } - if !is_get { - let model_index = self - .build_load(self.int_type, *self.get_param("model_index"), "model_index")? - .into_int_value(); - self.builder - .build_store(self.globals.model_index.as_pointer_value(), model_index)?; - } - if let Some(input) = model.input() { let name = input.name(); self.insert_tensor(input, false); diff --git a/diffsl/tests/pybamm_dfn.rs b/diffsl/tests/pybamm_dfn.rs index 91cbe52..f2d5d80 100644 --- a/diffsl/tests/pybamm_dfn.rs +++ b/diffsl/tests/pybamm_dfn.rs @@ -34,17 +34,17 @@ fn test_dfn_model_initialization() { let mut data = compiler.get_new_data(); let (n_states, n_inputs, _, _, _, _, _) = compiler.get_dims(); let inputs = vec![1.0; n_inputs]; - compiler.set_inputs(&inputs, &mut data, 0); + compiler.set_inputs(&inputs, &mut data); let mut u = vec![0.0; n_states]; - compiler.set_u0(&mut u, &mut data); + compiler.set_u0(&mut u, &mut data, 0); let mut rr = vec![0.0; n_states]; let t = 0.0; - compiler.rhs(t, &u, &mut data, &mut rr); + compiler.rhs(t, &u, &mut data, &mut rr, 0); let v = vec![1.; n_states]; let mut drr = vec![0.0; n_states]; let mut ddata = compiler.get_new_data(); println!("Computing rhs grad..."); // flush stdout to ensure the print appears before any potential panic std::io::stdout().flush().unwrap(); - compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr); + compiler.rhs_grad(t, &u, &v, &data, &mut ddata, &rr, &mut drr, 0); } diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 2b68baf..69e2e3d 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -19,6 +19,7 @@ macro_rules! define_external_test { pub unsafe extern "C" fn set_u0( u: *mut $ty, _data: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -33,6 +34,7 @@ macro_rules! define_external_test { u: *const $ty, data: *mut $ty, rr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -53,6 +55,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -75,6 +78,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -95,6 +99,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -114,6 +119,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -131,6 +137,7 @@ macro_rules! define_external_test { v: *const $ty, _data: *mut $ty, mv: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -149,6 +156,7 @@ macro_rules! define_external_test { _ddata: *mut $ty, _mv: *const $ty, dmv: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -164,6 +172,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -175,6 +184,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -186,6 +196,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -197,6 +208,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, out: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -215,6 +227,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -234,6 +247,7 @@ macro_rules! define_external_test { _ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -251,6 +265,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -270,6 +285,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -287,6 +303,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, root: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -302,6 +319,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, reset: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -320,6 +338,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -339,6 +358,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -357,6 +377,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -375,6 +396,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, + _model_index: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -426,7 +448,7 @@ macro_rules! define_external_test { } #[no_mangle] - pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty, _model_index: u32) { + pub unsafe extern "C" fn set_inputs(inputs: *const $ty, data: *mut $ty) { if inputs.is_null() || data.is_null() { return; } @@ -447,7 +469,6 @@ macro_rules! define_external_test { dinputs: *const $ty, _data: *const $ty, ddata: *mut $ty, - _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -461,7 +482,6 @@ macro_rules! define_external_test { dinputs: *mut $ty, _data: *const $ty, ddata: *mut $ty, - _model_index: u32, ) { if dinputs.is_null() || ddata.is_null() { return; @@ -487,36 +507,45 @@ macro_rules! define_external_test { let mut data = vec![-1.0 as $ty; n_data]; let inputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs(&inputs, &mut data, 0); + compiler.set_inputs(&inputs, &mut data); let mut inputs_out = vec![-2.0 as $ty; n_inputs]; compiler.get_inputs(&mut inputs_out, &data); assert_eq!(inputs_out, inputs); let mut u = vec![-2.0 as $ty; n_states]; - compiler.set_u0(&mut u, &mut data); + compiler.set_u0(&mut u, &mut data, 0); assert_eq!(u[0], 1.0 as $ty); let mut out = vec![-3.0 as $ty; n_outputs]; - compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out); + compiler.calc_out(0.0 as $ty, &u, &mut data, &mut out, 0); assert_eq!(out[0], u[0]); let mut rr = vec![-4.0 as $ty; n_states]; - compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr); + compiler.rhs(0.0 as $ty, &u, &mut data, &mut rr, 0); assert_eq!(rr[0], 0.0 as $ty); let mut stop = vec![-5.0 as $ty; n_stop]; - compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop); + compiler.calc_stop(0.0 as $ty, &u, &mut data, &mut stop, 0); assert_eq!(stop[0], 0.5 as $ty); let mut reset = vec![-5.5 as $ty; n_states]; - compiler.reset(0.0 as $ty, &u, &mut data, &mut reset); + compiler.reset(0.0 as $ty, &u, &mut data, &mut reset, 0); assert_eq!(reset[0], 2.0 as $ty); let du = vec![1.0 as $ty; n_states]; let mut ddata = vec![-8.0 as $ty; n_data]; let mut dreset = vec![-5.75 as $ty; n_states]; - compiler.reset_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &reset, &mut dreset); + compiler.reset_grad( + 0.0 as $ty, + &u, + &du, + &data, + &mut ddata, + &reset, + &mut dreset, + 0, + ); assert_eq!(dreset[0], 2.0 as $ty); let mut du_reset_rev = vec![-5.85 as $ty; n_states]; @@ -530,6 +559,7 @@ macro_rules! define_external_test { &mut ddata_reset_rev, &reset, &mut dreset_rev, + 0, ); assert!((du_reset_rev[0] - (-3.85 as $ty)).abs() < (1e-6 as $ty)); @@ -542,6 +572,7 @@ macro_rules! define_external_test { &mut ddata_reset_s, &reset, &mut dreset_s, + 0, ); assert_eq!(dreset_s[0], 0.0 as $ty); @@ -554,11 +585,12 @@ macro_rules! define_external_test { &mut ddata_reset_sr, &reset, &mut dreset_sr, + 0, ); assert_eq!(dreset_sr[0], 0.0 as $ty); let mut mv = vec![-6.0 as $ty; n_states]; - compiler.mass(0.0 as $ty, &u, &mut data, &mut mv); + compiler.mass(0.0 as $ty, &u, &mut data, &mut mv, 0); assert_eq!(mv[0], 1.0 as $ty); let mut id = vec![-7.0 as $ty; n_states]; @@ -566,56 +598,74 @@ macro_rules! define_external_test { assert_eq!(id[0], 42.0 as $ty); let mut drr = vec![-9.0 as $ty; n_states]; - compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr); + compiler.rhs_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &rr, &mut drr, 0); assert_eq!(drr[0], -1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); let mut dout = vec![-10.0 as $ty; n_outputs]; - compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout); + compiler.calc_out_grad(0.0 as $ty, &u, &du, &data, &mut ddata, &out, &mut dout, 0); assert_eq!(dout[0], 1.0 as $ty); assert_eq!(ddata[0], 0.0 as $ty); let mut dinputs = vec![1.0 as $ty; n_inputs]; - compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata, 0); + compiler.set_inputs_grad(&inputs, &dinputs, &data, &mut ddata); assert_eq!(ddata[0], 1.0 as $ty); let mut du_rev = vec![-11.0 as $ty; n_states]; let mut ddata_rev = vec![-12.0 as $ty; n_data]; let mut drr_rev = vec![1.0 as $ty; n_states]; - compiler.rhs_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &rr, &mut drr_rev); + compiler.rhs_rgrad( + 0.0 as $ty, + &u, + &mut du_rev, + &data, + &mut ddata_rev, + &rr, + &mut drr_rev, + 0, + ); assert_eq!(du_rev[0], -12.0 as $ty); assert_eq!(ddata_rev[0], -12.0 as $ty); let mut dv = vec![-13.0 as $ty; n_states]; let mut dmv = vec![1.0 as $ty; n_states]; - compiler.mass_rgrad(0.0 as $ty, &mut dv, &data, &mut ddata_rev, &mut dmv); + compiler.mass_rgrad(0.0 as $ty, &mut dv, &data, &mut ddata_rev, &mut dmv, 0); assert_eq!(dv[0], -12.0 as $ty); let mut dout_rev = vec![1.0 as $ty; n_outputs]; - compiler.calc_out_rgrad(0.0 as $ty, &u, &mut du_rev, &data, &mut ddata_rev, &out, &mut dout_rev); + compiler.calc_out_rgrad( + 0.0 as $ty, + &u, + &mut du_rev, + &data, + &mut ddata_rev, + &out, + &mut dout_rev, + 0, + ); assert_eq!(du_rev[0], -11.0 as $ty); - compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev, 0); + compiler.set_inputs_rgrad(&inputs, &mut dinputs, &data, &mut ddata_rev); assert_eq!(dinputs[0], -11.0 as $ty); let mut ddata_s = vec![-14.0 as $ty; n_data]; let mut drr_s = vec![-15.0 as $ty; n_states]; - compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s); + compiler.rhs_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &rr, &mut drr_s, 0); assert_eq!(drr_s[0], 0.0 as $ty); assert_eq!(ddata_s[0], 0.0 as $ty); let mut dout_s = vec![-16.0 as $ty; n_outputs]; - compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s); + compiler.calc_out_sgrad(0.0 as $ty, &u, &data, &mut ddata_s, &out, &mut dout_s, 0); assert_eq!(dout_s[0], 0.0 as $ty); let mut ddata_sr = vec![-17.0 as $ty; n_data]; let mut drr_sr = vec![-18.0 as $ty; n_states]; - compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr); + compiler.rhs_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &rr, &mut drr_sr, 0); assert_eq!(drr_sr[0], 0.0 as $ty); assert_eq!(ddata_sr[0], 0.0 as $ty); let mut dout_sr = vec![-19.0 as $ty; n_outputs]; - compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr); + compiler.calc_out_srgrad(0.0 as $ty, &u, &data, &mut ddata_sr, &out, &mut dout_sr, 0); assert_eq!(dout_sr[0], 0.0 as $ty); } }; From aec32219782bc86e3a83cd7dd6ca142aa708c92f Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 15 Mar 2026 20:45:34 +0000 Subject: [PATCH 3/4] model-index -> model --- diffsl/src/execution/compiler.rs | 106 +++++++++---------- diffsl/src/execution/cranelift/codegen.rs | 32 +++--- diffsl/src/execution/external/mod.rs | 44 ++++---- diffsl/src/execution/interface.rs | 44 ++++---- diffsl/src/execution/llvm/codegen.rs | 28 ++--- diffsl/tests/support/external_test_macros.rs | 44 ++++---- 6 files changed, 149 insertions(+), 149 deletions(-) diff --git a/diffsl/src/execution/compiler.rs b/diffsl/src/execution/compiler.rs index 64376ab..e6af629 100644 --- a/diffsl/src/execution/compiler.rs +++ b/diffsl/src/execution/compiler.rs @@ -346,13 +346,13 @@ impl Compiler { }); } - pub fn set_u0(&self, yy: &mut [T], data: &mut [T], model_index: u32) { + pub fn set_u0(&self, yy: &mut [T], data: &mut [T], model: u32) { self.check_state_len(yy, "yy"); self.with_threading(|i, dim| unsafe { (self.jit_functions.set_u0)( yy.as_ptr() as *mut T, data.as_ptr() as *mut T, - model_index, + model, i, dim, ); @@ -365,7 +365,7 @@ impl Compiler { dyy: &mut [T], data: &[T], ddata: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -381,7 +381,7 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, - model_index, + model, i, dim, ); @@ -394,7 +394,7 @@ impl Compiler { dyy: &mut [T], data: &[T], ddata: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -410,7 +410,7 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, - model_index, + model, i, dim, ); @@ -423,7 +423,7 @@ impl Compiler { dyy: &mut [T], data: &[T], ddata: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -436,7 +436,7 @@ impl Compiler { dyy.as_ptr() as *mut T, data.as_ptr(), ddata.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -444,7 +444,7 @@ impl Compiler { }) } - pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T], model_index: u32) { + pub fn calc_stop(&self, t: T, yy: &[T], data: &mut [T], stop: &mut [T], model: u32) { if self.number_of_stop == 0 { panic!("Model does not have a stop function"); } @@ -457,14 +457,14 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, stop.as_ptr() as *mut T, - model_index, + model, i, dim, ) }); } - pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T], model_index: u32) { + pub fn reset(&self, t: T, yy: &[T], data: &mut [T], reset: &mut [T], model: u32) { if reset.is_empty() { return; } @@ -478,7 +478,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, reset.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -495,7 +495,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], - model_index: u32, + model: u32, ) { if dreset.is_empty() { return; @@ -516,7 +516,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -533,7 +533,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], - model_index: u32, + model: u32, ) { if dreset.is_empty() { return; @@ -558,7 +558,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -573,7 +573,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], - model_index: u32, + model: u32, ) { if dreset.is_empty() { return; @@ -596,7 +596,7 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -611,7 +611,7 @@ impl Compiler { ddata: &mut [T], reset: &[T], dreset: &mut [T], - model_index: u32, + model: u32, ) { if dreset.is_empty() { return; @@ -634,14 +634,14 @@ impl Compiler { ddata.as_ptr() as *mut T, reset.as_ptr(), dreset.as_ptr() as *mut T, - model_index, + model, i, dim, ) }); } - pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T], model_index: u32) { + pub fn rhs(&self, t: T, yy: &[T], data: &mut [T], rr: &mut [T], model: u32) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); self.check_data_len(data, "data"); @@ -651,7 +651,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, rr.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -666,7 +666,7 @@ impl Compiler { self.has_reset } - pub fn mass(&self, t: T, v: &[T], data: &mut [T], mv: &mut [T], model_index: u32) { + pub fn mass(&self, t: T, v: &[T], data: &mut [T], mv: &mut [T], model: u32) { if !self.has_mass { panic!("Model does not have a mass function"); } @@ -679,7 +679,7 @@ impl Compiler { v.as_ptr(), data.as_ptr() as *mut T, mv.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -704,7 +704,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -721,7 +721,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -738,7 +738,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -759,7 +759,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -773,7 +773,7 @@ impl Compiler { data: &[T], ddata: &mut [T], dmv: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(dv, "dv"); self.check_state_len(dmv, "dmv"); @@ -792,7 +792,7 @@ impl Compiler { ddata.as_ptr() as *mut T, std::ptr::null(), dmv.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -807,7 +807,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); @@ -826,7 +826,7 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -841,7 +841,7 @@ impl Compiler { ddata: &mut [T], rr: &[T], drr: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(rr, "rr"); @@ -860,14 +860,14 @@ impl Compiler { ddata.as_ptr() as *mut T, rr.as_ptr(), drr.as_ptr() as *mut T, - model_index, + model, i, dim, ) }); } - pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T], model_index: u32) { + pub fn calc_out(&self, t: T, yy: &[T], data: &mut [T], out: &mut [T], model: u32) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); self.check_out_len(out, "out"); @@ -877,7 +877,7 @@ impl Compiler { yy.as_ptr(), data.as_ptr() as *mut T, out.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -894,7 +894,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -911,7 +911,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -928,7 +928,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_state_len(dyy, "dyy"); @@ -949,7 +949,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -964,7 +964,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -983,7 +983,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -998,7 +998,7 @@ impl Compiler { ddata: &mut [T], out: &[T], dout: &mut [T], - model_index: u32, + model: u32, ) { self.check_state_len(yy, "yy"); self.check_data_len(data, "data"); @@ -1017,7 +1017,7 @@ impl Compiler { ddata.as_ptr() as *mut T, out.as_ptr(), dout.as_ptr() as *mut T, - model_index, + model, i, dim, ) @@ -1566,12 +1566,12 @@ mod tests { generate_tests!(test_out_depends_on_internal_tensor); - generate_tests!(test_model_index_n_depends_on_model_index); - generate_tests!(test_model_index_n_dynamic_index_grad); - generate_tests!(test_model_index_n_dynamic_range_width_const); + generate_tests!(test_model_n_depends_on_model); + generate_tests!(test_model_n_dynamic_index_grad); + generate_tests!(test_model_n_dynamic_range_width_const); #[allow(dead_code)] - fn test_model_index_n_depends_on_model_index< + fn test_model_n_depends_on_model< M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq, >() { @@ -1581,9 +1581,9 @@ mod tests { // - tensor indexing can use expression indices, e.g. amp_i[N % 2] // // Expected behavior once implemented: - // - N is taken from model_index. - // - model_index = 0 => N % 2 = 0 - // - model_index = 1 => N % 2 = 1 + // - N is taken from model. + // - model = 0 => N % 2 = 0 + // - model = 1 => N % 2 = 1 let full_text = " amp_i { 0, 10 } dur_i { 10, 5 } @@ -1657,7 +1657,7 @@ mod tests { } #[allow(dead_code)] - fn test_model_index_n_dynamic_index_grad< + fn test_model_n_dynamic_index_grad< M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq, >() { @@ -1737,7 +1737,7 @@ mod tests { } #[allow(dead_code)] - fn test_model_index_n_dynamic_range_width_const< + fn test_model_n_dynamic_range_width_const< M: CodegenModuleCompile + CodegenModuleJit, T: Scalar + RelativeEq, >() { diff --git a/diffsl/src/execution/cranelift/codegen.rs b/diffsl/src/execution/cranelift/codegen.rs index cb5f346..c97e666 100644 --- a/diffsl/src/execution/cranelift/codegen.rs +++ b/diffsl/src/execution/cranelift/codegen.rs @@ -224,7 +224,7 @@ impl CraneliftModule { "ddata", "out", "dout", - "model_index", + "model", "threadId", "threadDim", ]; @@ -286,7 +286,7 @@ impl CraneliftModule { "ddata", "rr", "drr", - "model_index", + "model", "threadId", "threadDim", ]; @@ -345,7 +345,7 @@ impl CraneliftModule { "ddata", "reset", "dreset", - "model_index", + "model", "threadId", "threadDim", ]; @@ -444,7 +444,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["u0", "du0", "data", "ddata", "model_index", "threadId", "threadDim"]; + let arg_names = &["u0", "du0", "data", "ddata", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -599,7 +599,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["u0", "data", "model_index", "threadId", "threadDim"]; + let arg_names = &["u0", "data", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -641,7 +641,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["t", "u", "data", "out", "model_index", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "out", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -691,7 +691,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["t", "u", "data", "root", "model_index", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "root", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -741,7 +741,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["t", "u", "data", "reset", "model_index", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "reset", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -789,7 +789,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["t", "u", "data", "rr", "model_index", "threadId", "threadDim"]; + let arg_names = &["t", "u", "data", "rr", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -833,7 +833,7 @@ impl CraneliftModule { self.int_type, self.int_type, ]; - let arg_names = &["t", "dudt", "data", "rr", "model_index", "threadId", "threadDim"]; + let arg_names = &["t", "dudt", "data", "rr", "model", "threadId", "threadDim"]; { let mut codegen = CraneliftCodeGen::new(self, model, arg_names, arg_types); @@ -1328,13 +1328,13 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { } let var = self .variables - .get("model_index") - .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; - let model_index = self.builder.use_var(*var); + .get("model") + .ok_or_else(|| anyhow!("N used where model is unavailable"))?; + let model = self.builder.use_var(*var); return Ok(self .builder .ins() - .fcvt_from_sint(self.real_type, model_index)); + .fcvt_from_sint(self.real_type, model)); } let ptr = if iname.is_tangent { // tangent of a constant is zero @@ -1556,8 +1556,8 @@ impl<'ctx, M: Module> CraneliftCodeGen<'ctx, M> { if iname.name == "N" { let var = self .variables - .get("model_index") - .ok_or_else(|| anyhow!("N used where model_index is unavailable"))?; + .get("model") + .ok_or_else(|| anyhow!("N used where model is unavailable"))?; Ok(self.builder.use_var(*var)) } else { Err(anyhow!( diff --git a/diffsl/src/execution/external/mod.rs b/diffsl/src/execution/external/mod.rs index ca1ae44..d1edd34 100644 --- a/diffsl/src/execution/external/mod.rs +++ b/diffsl/src/execution/external/mod.rs @@ -21,7 +21,7 @@ macro_rules! define_symbol_module { pub fn set_u0( u: *mut $ty, data: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -31,7 +31,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, reset: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -44,7 +44,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -57,7 +57,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -69,7 +69,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -81,7 +81,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, reset: *const $ty, dreset: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -91,7 +91,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, rr: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -104,7 +104,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -117,7 +117,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -129,7 +129,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -141,7 +141,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, rr: *const $ty, drr: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -151,7 +151,7 @@ macro_rules! define_symbol_module { v: *const $ty, data: *mut $ty, mv: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -164,7 +164,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, mv: *const $ty, dmv: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -174,7 +174,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -184,7 +184,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -194,7 +194,7 @@ macro_rules! define_symbol_module { du: *mut $ty, data: *const $ty, ddata: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -204,7 +204,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, out: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -217,7 +217,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -230,7 +230,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -242,7 +242,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -254,7 +254,7 @@ macro_rules! define_symbol_module { ddata: *mut $ty, out: *const $ty, dout: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -264,7 +264,7 @@ macro_rules! define_symbol_module { u: *const $ty, data: *mut $ty, root: *mut $ty, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); diff --git a/diffsl/src/execution/interface.rs b/diffsl/src/execution/interface.rs index ce0a4d9..c259ebc 100644 --- a/diffsl/src/execution/interface.rs +++ b/diffsl/src/execution/interface.rs @@ -12,7 +12,7 @@ pub type StopFunc = unsafe extern "C" fn( u: *const T, data: *mut T, root: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -21,7 +21,7 @@ pub type ResetFunc = unsafe extern "C" fn( u: *const T, data: *mut T, reset: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -33,7 +33,7 @@ pub type ResetGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -45,7 +45,7 @@ pub type ResetRevGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -56,7 +56,7 @@ pub type ResetSensGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -67,7 +67,7 @@ pub type ResetSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, reset: *const T, dreset: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -76,7 +76,7 @@ pub type RhsFunc = unsafe extern "C" fn( u: *const T, data: *mut T, rr: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -88,7 +88,7 @@ pub type RhsGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -100,7 +100,7 @@ pub type RhsRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -111,7 +111,7 @@ pub type RhsSensGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -122,7 +122,7 @@ pub type RhsSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, rr: *const T, drr: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -131,7 +131,7 @@ pub type MassFunc = unsafe extern "C" fn( v: *const T, data: *mut T, mv: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -143,14 +143,14 @@ pub type MassRevGradFunc = unsafe extern "C" fn( ddata: *mut T, mv: *const T, dmv: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); pub type U0Func = unsafe extern "C" fn( u: *mut T, data: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -159,7 +159,7 @@ pub type U0SensGradFunc = unsafe extern "C" fn( du: *mut T, data: *const T, ddata: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -168,7 +168,7 @@ pub type U0GradFunc = unsafe extern "C" fn( du: *mut T, data: *const T, ddata: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -177,7 +177,7 @@ pub type U0RevGradFunc = unsafe extern "C" fn( du: *mut T, data: *const T, ddata: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -186,7 +186,7 @@ pub type CalcOutFunc = unsafe extern "C" fn( u: *const T, data: *mut T, out: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -198,7 +198,7 @@ pub type CalcOutGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -210,7 +210,7 @@ pub type CalcOutRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -221,7 +221,7 @@ pub type CalcOutSensGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); @@ -232,7 +232,7 @@ pub type CalcOutSensRevGradFunc = unsafe extern "C" fn( ddata: *mut T, out: *const T, dout: *mut T, - model_index: UIntType, + model: UIntType, thread_id: UIntType, thread_dim: UIntType, ); diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 474e38a..a406df9 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -2681,11 +2681,11 @@ impl<'ctx> CodeGen<'ctx> { if iname.is_tangent { return Ok(self.real_type.const_float(0.0)); } - let model_index = self - .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + let model = self + .build_load(self.int_type, *self.get_param("model"), "model")? .into_int_value(); let n_value = self.builder.build_signed_int_to_float( - model_index, + model, self.real_type, "n_as_real", )?; @@ -2896,7 +2896,7 @@ impl<'ctx> CodeGen<'ctx> { AstKind::Name(iname) => { if iname.name == "N" { Ok(self - .build_load(self.int_type, *self.get_param("model_index"), name)? + .build_load(self.int_type, *self.get_param("model"), name)? .into_int_value()) } else { Err(anyhow!( @@ -2963,8 +2963,8 @@ impl<'ctx> CodeGen<'ctx> { .into_float_value(); let u = *self.get_param("u"); let data = *self.get_param("data"); - let model_index = self - .build_load(self.int_type, *self.get_param("model_index"), "model_index")? + let model = self + .build_load(self.int_type, *self.get_param("model"), "model")? .into_int_value(); let thread_id = self .build_load(self.int_type, *self.get_param("thread_id"), "thread_id")? @@ -2980,7 +2980,7 @@ impl<'ctx> CodeGen<'ctx> { t.into(), u.into(), data.into(), - model_index.into(), + model.into(), thread_id.into(), thread_dim.into(), barrier_start.into(), @@ -3069,7 +3069,7 @@ impl<'ctx> CodeGen<'ctx> { code: Option<&str>, ) -> Result> { self.clear(); - let fn_arg_names = &["u0", "data", "model_index", "thread_id", "thread_dim"]; + let fn_arg_names = &["u0", "data", "model", "thread_id", "thread_dim"]; let function = self.add_function( "set_u0", fn_arg_names, @@ -3154,7 +3154,7 @@ impl<'ctx> CodeGen<'ctx> { "u", "data", "out", - "model_index", + "model", "thread_id", "thread_dim", ]; @@ -3274,7 +3274,7 @@ impl<'ctx> CodeGen<'ctx> { "t", "u", "data", - "model_index", + "model", "thread_id", "thread_dim", "barrier_start", @@ -3372,7 +3372,7 @@ impl<'ctx> CodeGen<'ctx> { "u", "data", "root", - "model_index", + "model", "thread_id", "thread_dim", ]; @@ -3472,7 +3472,7 @@ impl<'ctx> CodeGen<'ctx> { "u", "data", "reset", - "model_index", + "model", "thread_id", "thread_dim", ]; @@ -3569,7 +3569,7 @@ impl<'ctx> CodeGen<'ctx> { "u", "data", "rr", - "model_index", + "model", "thread_id", "thread_dim", ]; @@ -3673,7 +3673,7 @@ impl<'ctx> CodeGen<'ctx> { "dudt", "data", "rr", - "model_index", + "model", "thread_id", "thread_dim", ]; diff --git a/diffsl/tests/support/external_test_macros.rs b/diffsl/tests/support/external_test_macros.rs index 69e2e3d..1b9f640 100644 --- a/diffsl/tests/support/external_test_macros.rs +++ b/diffsl/tests/support/external_test_macros.rs @@ -19,7 +19,7 @@ macro_rules! define_external_test { pub unsafe extern "C" fn set_u0( u: *mut $ty, _data: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -34,7 +34,7 @@ macro_rules! define_external_test { u: *const $ty, data: *mut $ty, rr: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -55,7 +55,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -78,7 +78,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -99,7 +99,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -119,7 +119,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _rr: *const $ty, drr: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -137,7 +137,7 @@ macro_rules! define_external_test { v: *const $ty, _data: *mut $ty, mv: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -156,7 +156,7 @@ macro_rules! define_external_test { _ddata: *mut $ty, _mv: *const $ty, dmv: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -172,7 +172,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -184,7 +184,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -196,7 +196,7 @@ macro_rules! define_external_test { _du: *mut $ty, _data: *const $ty, _ddata: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -208,7 +208,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, out: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -227,7 +227,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -247,7 +247,7 @@ macro_rules! define_external_test { _ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -265,7 +265,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -285,7 +285,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _out: *const $ty, dout: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -303,7 +303,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, root: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -319,7 +319,7 @@ macro_rules! define_external_test { u: *const $ty, _data: *mut $ty, reset: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -338,7 +338,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -358,7 +358,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -377,7 +377,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { @@ -396,7 +396,7 @@ macro_rules! define_external_test { ddata: *mut $ty, _reset: *const $ty, dreset: *mut $ty, - _model_index: u32, + _model: u32, _thread_id: u32, _thread_dim: u32, ) { From 641de3ccae4b947ebcaa67fd67bed5aceb9cbb2d Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Sun, 15 Mar 2026 20:46:30 +0000 Subject: [PATCH 4/4] fix benches --- diffsl/benches/evaluation.rs | 6 +++--- diffsl/benches/pybamm_dfn.rs | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/diffsl/benches/evaluation.rs b/diffsl/benches/evaluation.rs index 380282a..ca98391 100644 --- a/diffsl/benches/evaluation.rs +++ b/diffsl/benches/evaluation.rs @@ -55,14 +55,14 @@ fn execute( let n = N; let compiler = setup::(n, f_text, "execute"); let mut data = compiler.get_new_data(); - compiler.set_inputs(&[], data.as_mut_slice(), 0); + compiler.set_inputs(&[], data.as_mut_slice()); let mut u = vec![1.0; n]; - compiler.set_u0(u.as_mut_slice(), data.as_mut_slice()); + compiler.set_u0(u.as_mut_slice(), data.as_mut_slice(), 0); let mut rr = vec![0.0; n]; let t = 0.0; bencher.bench_local(|| { - compiler.rhs(t, &u, &mut data, &mut rr); + compiler.rhs(t, &u, &mut data, &mut rr, 0); }); } diff --git a/diffsl/benches/pybamm_dfn.rs b/diffsl/benches/pybamm_dfn.rs index a830878..a662088 100644 --- a/diffsl/benches/pybamm_dfn.rs +++ b/diffsl/benches/pybamm_dfn.rs @@ -36,6 +36,7 @@ fn pybamm_dfn_execute_rhs_grad(bench ddata.as_mut_slice(), rr.as_mut_slice(), drr.as_mut_slice(), + 0, ); }); } @@ -69,7 +70,7 @@ fn pybamm_dfn_execute_rhs(bencher: B let mut data = compiler.get_new_data(); let mut rr = vec![0.0; n_states]; bencher.bench_local(move || { - compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice()); + compiler.rhs(t, y.as_slice(), data.as_mut_slice(), rr.as_mut_slice(), 0); }); }