From 5288989f27ab250bd53ac55520dba7f0d9dac081 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Mon, 6 Apr 2026 21:48:52 +0000 Subject: [PATCH] fix: sgrad for reset and stop --- diffsl/src/execution/llvm/codegen.rs | 58 +++++++++++++++++++++++----- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/diffsl/src/execution/llvm/codegen.rs b/diffsl/src/execution/llvm/codegen.rs index 24fb3bd..652bc0d 100644 --- a/diffsl/src/execution/llvm/codegen.rs +++ b/diffsl/src/execution/llvm/codegen.rs @@ -289,8 +289,10 @@ impl CodegenModuleCompile for LlvmModule { let mut module = Self::new(triple, model, threaded, real_type, options.debug)?; let set_u0 = module.codegen_mut().compile_set_u0(model, code)?; - let calc_stop = module.codegen_mut().compile_calc_stop(model, code)?; - let reset = module.codegen_mut().compile_reset(model, code)?; + let calc_stop = module.codegen_mut().compile_calc_stop(model, false, code)?; + let calc_stop_full = module.codegen_mut().compile_calc_stop(model, true, code)?; + let reset = module.codegen_mut().compile_reset(model, false, code)?; + let reset_full = module.codegen_mut().compile_reset(model, true, code)?; let rhs = module.codegen_mut().compile_rhs(model, false, code)?; let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?; let mass = module.codegen_mut().compile_mass(model, code)?; @@ -504,7 +506,7 @@ impl CodegenModuleCompile for LlvmModule { )?; module.codegen_mut().compile_gradient( - reset, + reset_full, &[ CompileGradientArgType::Const, CompileGradientArgType::Const, @@ -518,7 +520,7 @@ impl CodegenModuleCompile for LlvmModule { )?; module.codegen_mut().compile_gradient( - calc_stop, + calc_stop_full, &[ CompileGradientArgType::Const, CompileGradientArgType::Const, @@ -585,7 +587,7 @@ impl CodegenModuleCompile for LlvmModule { )?; module.codegen_mut().compile_gradient( - reset, + reset_full, &[ CompileGradientArgType::Const, CompileGradientArgType::Const, @@ -599,7 +601,7 @@ impl CodegenModuleCompile for LlvmModule { )?; module.codegen_mut().compile_gradient( - calc_stop, + calc_stop_full, &[ CompileGradientArgType::Const, CompileGradientArgType::Const, @@ -3371,6 +3373,7 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_calc_stop<'m>( &mut self, model: &'m DiscreteModel, + include_constants: bool, code: Option<&str>, ) -> Result> { let time_dep_fn = self.ensure_time_dep_fn(model, code)?; @@ -3378,8 +3381,13 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"]; + let function_name = if include_constants { + "calc_stop_full" + } else { + "calc_stop" + }; let function = self.add_function( - "calc_stop", + function_name, fn_arg_names, &[ self.real_type.into(), @@ -3415,11 +3423,23 @@ impl<'ctx> CodeGen<'ctx> { if let Some(stop) = model.stop() { // calculate time dependant definitions let mut nbarriers = 0; - let total_barriers = (model.time_dep_defns().len() + let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + model.state_dep_post_f_defns().len() + 1) as u64; + if include_constants { + total_barriers += model.input_dep_defns().len() as u64; + } let total_barriers_val = self.int_type.const_int(total_barriers, false); + if include_constants { + // calculate time independent definitions + for tensor in model.input_dep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?; + let barrier_num = self.int_type.const_int(nbarriers + 1, false); + self.jit_compile_call_barrier(barrier_num, total_barriers_val); + nbarriers += 1; + } + } if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; nbarriers += model.time_dep_defns().len() as u64; @@ -3456,6 +3476,7 @@ impl<'ctx> CodeGen<'ctx> { pub fn compile_reset<'m>( &mut self, model: &'m DiscreteModel, + include_constants: bool, code: Option<&str>, ) -> Result> { let time_dep_fn = self.ensure_time_dep_fn(model, code)?; @@ -3463,8 +3484,13 @@ impl<'ctx> CodeGen<'ctx> { let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?; self.clear(); let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"]; + let function_name = if include_constants { + "reset_full" + } else { + "reset" + }; let function = self.add_function( - "reset", + function_name, fn_arg_names, &[ self.real_type.into(), @@ -3498,11 +3524,23 @@ impl<'ctx> CodeGen<'ctx> { if let Some(reset) = model.reset() { let mut nbarriers = 0; - let total_barriers = (model.time_dep_defns().len() + let mut total_barriers = (model.time_dep_defns().len() + model.state_dep_defns().len() + model.state_dep_post_f_defns().len() + 1) as u64; + if include_constants { + total_barriers += model.input_dep_defns().len() as u64; + } let total_barriers_val = self.int_type.const_int(total_barriers, false); + if include_constants { + // calculate time independent definitions + for tensor in model.input_dep_defns() { + self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?; + let barrier_num = self.int_type.const_int(nbarriers + 1, false); + self.jit_compile_call_barrier(barrier_num, total_barriers_val); + nbarriers += 1; + } + } if !model.time_dep_defns().is_empty() { self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?; nbarriers += model.time_dep_defns().len() as u64;