Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,9 @@ jobs:
path: |
${{ env.CARGO_HOME }}
target
key: unit-test-${{ runner.os }}-${{ matrix.toolchain}}-${{ matrix.llvm[0] }}-${{ matrix.features }}
key: unit-test-${{ runner.os }}-${{ matrix.toolchain }}-${{ matrix.llvm != '' && matrix.llvm[0] || 'none' }}-${{ matrix.features }}-${{ hashFiles('**/Cargo.toml', '**/build.rs', 'rust-toolchain', 'rust-toolchain.toml') }}
restore-keys: |
unit-test-${{ runner.os }}-${{ matrix.toolchain }}-${{ matrix.llvm != '' && matrix.llvm[0] || 'none' }}-${{ matrix.features }}-
- name: Set up Rust
run: rustup default ${{ matrix.toolchain }} && rustup update ${{ matrix.toolchain }} --no-self-update && rustup component add clippy rust-docs
- name: Rust version
Expand Down Expand Up @@ -228,4 +230,4 @@ jobs:
if: ${{ github.ref == 'refs/heads/main' }}
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
publish_dir: ./book/book
publish_dir: ./book/book
57 changes: 50 additions & 7 deletions diffsl/src/execution/llvm/codegen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -288,16 +288,16 @@ impl CodegenModuleCompile for LlvmModule {

let mut module = Self::new(triple, model, threaded, real_type, options.debug)?;

let set_u0 = module.codegen_mut().compile_set_u0(model, code)?;
module.codegen_mut().compile_set_u0(model, code)?;
let _calc_stop = module.codegen_mut().compile_calc_stop(model, code)?;
let rhs = module.codegen_mut().compile_rhs(model, false, code)?;
let rhs_full = module.codegen_mut().compile_rhs(model, true, code)?;
let mass = module.codegen_mut().compile_mass(model, code)?;
let calc_out = module.codegen_mut().compile_calc_out(model, false, code)?;
let calc_out_full = module.codegen_mut().compile_calc_out(model, true, code)?;
module.codegen_mut().compile_rhs(model, false, code)?;
module.codegen_mut().compile_rhs(model, true, code)?;
module.codegen_mut().compile_mass(model, code)?;
module.codegen_mut().compile_calc_out(model, false, code)?;
module.codegen_mut().compile_calc_out(model, true, code)?;
let _set_id = module.codegen_mut().compile_set_id(model)?;
let _get_dims = module.codegen_mut().compile_get_dims(model)?;
let set_inputs = module.codegen_mut().compile_inputs(model, false)?;
module.codegen_mut().compile_inputs(model, false)?;
let _get_inputs = module.codegen_mut().compile_inputs(model, true)?;
let _set_constants = module.codegen_mut().compile_set_constants(model, code)?;
let tensor_info = module
Expand All @@ -320,6 +320,49 @@ impl CodegenModuleCompile for LlvmModule {

module.pre_autodiff_optimisation()?;

// Refresh function handles after optimization passes. Some LLVM passes may
// replace or drop function values, and stale handles can cause UB in
// downstream C++ APIs (Enzyme/LLVM).
let set_u0 = module
.codegen()
.module()
.get_function("set_u0")
.ok_or_else(|| anyhow!("Missing function after pre-autodiff optimization: set_u0"))?;
let rhs = module
.codegen()
.module()
.get_function("rhs")
.ok_or_else(|| anyhow!("Missing function after pre-autodiff optimization: rhs"))?;
let rhs_full = module
.codegen()
.module()
.get_function("rhs_full")
.ok_or_else(|| anyhow!("Missing function after pre-autodiff optimization: rhs_full"))?;
let mass = module
.codegen()
.module()
.get_function("mass")
.ok_or_else(|| anyhow!("Missing function after pre-autodiff optimization: mass"))?;
let calc_out = module
.codegen()
.module()
.get_function("calc_out")
.ok_or_else(|| anyhow!("Missing function after pre-autodiff optimization: calc_out"))?;
let calc_out_full = module
.codegen()
.module()
.get_function("calc_out_full")
.ok_or_else(|| {
anyhow!("Missing function after pre-autodiff optimization: calc_out_full")
})?;
let set_inputs = module
.codegen()
.module()
.get_function("set_inputs")
.ok_or_else(|| {
anyhow!("Missing function after pre-autodiff optimization: set_inputs")
})?;

module.codegen_mut().compile_gradient(
set_u0,
&[
Expand Down
Loading