Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 48 additions & 10 deletions diffsl/src/execution/llvm/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,10 @@ impl CodegenModuleCompile for LlvmModule {
let mut module = Self::new(triple, model, threaded, real_type, options.debug)?;

let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
let calc_stop = module.codegen_mut().compile_calc_stop(model, code)?;
let reset = module.codegen_mut().compile_reset(model, code)?;
let calc_stop = module.codegen_mut().compile_calc_stop(model, false, code)?;
let calc_stop_full = module.codegen_mut().compile_calc_stop(model, true, code)?;
let reset = module.codegen_mut().compile_reset(model, false, code)?;
let reset_full = module.codegen_mut().compile_reset(model, true, code)?;
let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
let mass = module.codegen_mut().compile_mass(model, code)?;
Expand Down Expand Up @@ -504,7 +506,7 @@ impl CodegenModuleCompile for LlvmModule {
)?;

module.codegen_mut().compile_gradient(
reset,
reset_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
Expand All @@ -518,7 +520,7 @@ impl CodegenModuleCompile for LlvmModule {
)?;

module.codegen_mut().compile_gradient(
calc_stop,
calc_stop_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
Expand Down Expand Up @@ -585,7 +587,7 @@ impl CodegenModuleCompile for LlvmModule {
)?;

module.codegen_mut().compile_gradient(
reset,
reset_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
Expand All @@ -599,7 +601,7 @@ impl CodegenModuleCompile for LlvmModule {
)?;

module.codegen_mut().compile_gradient(
calc_stop,
calc_stop_full,
&[
CompileGradientArgType::Const,
CompileGradientArgType::Const,
Expand Down Expand Up @@ -3371,15 +3373,21 @@ impl<'ctx> CodeGen<'ctx> {
pub fn compile_calc_stop<'m>(
&mut self,
model: &'m DiscreteModel,
include_constants: bool,
code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
self.clear();
let fn_arg_names = &["t", "u", "data", "root", "thread_id", "thread_dim"];
let function_name = if include_constants {
"calc_stop_full"
} else {
"calc_stop"
};
let function = self.add_function(
"calc_stop",
function_name,
fn_arg_names,
&[
self.real_type.into(),
Expand Down Expand Up @@ -3415,11 +3423,23 @@ impl<'ctx> CodeGen<'ctx> {
if let Some(stop) = model.stop() {
// calculate time dependent definitions
let mut nbarriers = 0;
let total_barriers = (model.time_dep_defns().len()
let mut total_barriers = (model.time_dep_defns().len()
+ model.state_dep_defns().len()
+ model.state_dep_post_f_defns().len()
+ 1) as u64;
if include_constants {
total_barriers += model.input_dep_defns().len() as u64;
}
let total_barriers_val = self.int_type.const_int(total_barriers, false);
if include_constants {
// calculate input dependent (time independent) definitions
for tensor in model.input_dep_defns() {
self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
let barrier_num = self.int_type.const_int(nbarriers + 1, false);
self.jit_compile_call_barrier(barrier_num, total_barriers_val);
nbarriers += 1;
}
}
if !model.time_dep_defns().is_empty() {
self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
nbarriers += model.time_dep_defns().len() as u64;
Expand Down Expand Up @@ -3456,15 +3476,21 @@ impl<'ctx> CodeGen<'ctx> {
pub fn compile_reset<'m>(
&mut self,
model: &'m DiscreteModel,
include_constants: bool,
code: Option<&str>,
) -> Result<FunctionValue<'ctx>> {
let time_dep_fn = self.ensure_time_dep_fn(model, code)?;
let state_dep_fn = self.ensure_state_dep_fn(model, code)?;
let state_dep_post_f_fn = self.ensure_state_dep_post_f_fn(model, code)?;
self.clear();
let fn_arg_names = &["t", "u", "data", "reset", "thread_id", "thread_dim"];
let function_name = if include_constants {
"reset_full"
} else {
"reset"
};
let function = self.add_function(
"reset",
function_name,
fn_arg_names,
&[
self.real_type.into(),
Expand Down Expand Up @@ -3498,11 +3524,23 @@ impl<'ctx> CodeGen<'ctx> {

if let Some(reset) = model.reset() {
let mut nbarriers = 0;
let total_barriers = (model.time_dep_defns().len()
let mut total_barriers = (model.time_dep_defns().len()
+ model.state_dep_defns().len()
+ model.state_dep_post_f_defns().len()
+ 1) as u64;
if include_constants {
total_barriers += model.input_dep_defns().len() as u64;
}
let total_barriers_val = self.int_type.const_int(total_barriers, false);
if include_constants {
// calculate input dependent (time independent) definitions
for tensor in model.input_dep_defns() {
self.jit_compile_tensor(tensor, Some(*self.get_var(tensor)), code)?;
let barrier_num = self.int_type.const_int(nbarriers + 1, false);
self.jit_compile_call_barrier(barrier_num, total_barriers_val);
nbarriers += 1;
}
}
if !model.time_dep_defns().is_empty() {
self.build_dep_call(time_dep_fn, "time_dep", nbarriers, total_barriers)?;
nbarriers += model.time_dep_defns().len() as u64;
Expand Down
Loading