Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diffsl/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ mmap-rs = "0.7"
[target.'cfg(windows)'.dependencies.windows-sys]
version = "0.61"
features = [
"Win32_Foundation",
"Win32_System_LibraryLoader",
]

Expand Down
9 changes: 1 addition & 8 deletions diffsl/src/execution/external/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -373,14 +373,7 @@ macro_rules! impl_extern_symbols {
($ty:ty, $sym:path) => {
impl ExternSymbols for $ty {
fn insert_symbols(symbols: &mut HashMap<String, *const u8>) {
macro_rules! insert {
($($name:literal => $func:ident,)+) => {
use $sym as sym;
$(symbols.insert($name.to_string(), sym::$func as *const u8);)+
};
}

crate::execution::external_interface::for_each_external_symbol!(insert);
crate::execution::external_interface::insert_external_symbols!(symbols, $sym);
}
}
};
Expand Down
6 changes: 4 additions & 2 deletions diffsl/src/execution/external_dynamic/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ impl DynamicLibrary {
.chain(std::iter::once(0))
.collect::<Vec<_>>();
let handle = unsafe { LoadLibraryW(wide_path.as_ptr()) };
if handle == 0 {
if handle.is_null() {
return Err(anyhow!(
"Failed to load dynamic library {}: {}",
path.display(),
Expand Down Expand Up @@ -129,8 +129,10 @@ impl DynamicLibrary {
#[cfg(all(windows, not(target_arch = "wasm32")))]
impl Drop for DynamicLibrary {
fn drop(&mut self) {
use windows_sys::Win32::Foundation::FreeLibrary;

unsafe {
windows_sys::Win32::System::LibraryLoader::FreeLibrary(self.handle);
FreeLibrary(self.handle);
}
}
}
Expand Down
71 changes: 71 additions & 0 deletions diffsl/src/execution/external_interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,77 @@ macro_rules! for_each_external_symbol {
#[allow(unused_imports)]
pub(crate) use for_each_external_symbol;

#[allow(unused_macros)]
macro_rules! insert_external_symbols {
($symbols:expr, $sym:path) => {{
use $sym as sym;
$symbols.insert("barrier_init".to_string(), sym::barrier_init as *const u8);
$symbols.insert("set_constants".to_string(), sym::set_constants as *const u8);
$symbols.insert("set_u0".to_string(), sym::set_u0 as *const u8);
$symbols.insert("reset".to_string(), sym::reset as *const u8);
$symbols.insert("reset_grad".to_string(), sym::reset_grad as *const u8);
$symbols.insert("reset_rgrad".to_string(), sym::reset_rgrad as *const u8);
$symbols.insert("reset_sgrad".to_string(), sym::reset_sgrad as *const u8);
$symbols.insert("reset_srgrad".to_string(), sym::reset_srgrad as *const u8);
$symbols.insert("rhs".to_string(), sym::rhs as *const u8);
$symbols.insert("rhs_grad".to_string(), sym::rhs_grad as *const u8);
$symbols.insert("rhs_rgrad".to_string(), sym::rhs_rgrad as *const u8);
$symbols.insert("rhs_sgrad".to_string(), sym::rhs_sgrad as *const u8);
$symbols.insert("rhs_srgrad".to_string(), sym::rhs_srgrad as *const u8);
$symbols.insert("mass".to_string(), sym::mass as *const u8);
$symbols.insert("mass_rgrad".to_string(), sym::mass_rgrad as *const u8);
$symbols.insert("set_u0_grad".to_string(), sym::set_u0_grad as *const u8);
$symbols.insert("set_u0_rgrad".to_string(), sym::set_u0_rgrad as *const u8);
$symbols.insert("set_u0_sgrad".to_string(), sym::set_u0_sgrad as *const u8);
$symbols.insert("calc_out".to_string(), sym::calc_out as *const u8);
$symbols.insert("calc_out_grad".to_string(), sym::calc_out_grad as *const u8);
$symbols.insert(
"calc_out_rgrad".to_string(),
sym::calc_out_rgrad as *const u8,
);
$symbols.insert(
"calc_out_sgrad".to_string(),
sym::calc_out_sgrad as *const u8,
);
$symbols.insert(
"calc_out_srgrad".to_string(),
sym::calc_out_srgrad as *const u8,
);
$symbols.insert("calc_stop".to_string(), sym::calc_stop as *const u8);
$symbols.insert(
"calc_stop_grad".to_string(),
sym::calc_stop_grad as *const u8,
);
$symbols.insert(
"calc_stop_rgrad".to_string(),
sym::calc_stop_rgrad as *const u8,
);
$symbols.insert(
"calc_stop_sgrad".to_string(),
sym::calc_stop_sgrad as *const u8,
);
$symbols.insert(
"calc_stop_srgrad".to_string(),
sym::calc_stop_srgrad as *const u8,
);
$symbols.insert("set_id".to_string(), sym::set_id as *const u8);
$symbols.insert("get_dims".to_string(), sym::get_dims as *const u8);
$symbols.insert("set_inputs".to_string(), sym::set_inputs as *const u8);
$symbols.insert("get_inputs".to_string(), sym::get_inputs as *const u8);
$symbols.insert(
"set_inputs_grad".to_string(),
sym::set_inputs_grad as *const u8,
);
$symbols.insert(
"set_inputs_rgrad".to_string(),
sym::set_inputs_rgrad as *const u8,
);
}};
}

#[allow(unused_imports)]
pub(crate) use insert_external_symbols;

macro_rules! collect_external_symbol_names {
($($name:literal => $func:ident,)+) => {
pub(crate) const EXTERNAL_SYMBOL_NAMES: &[&str] = &[$($name),+];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
diffsl = { path = "../../../", features = ["external"] }
diffsl = { path = "../../../", features = ["external_f32"] }

[workspace]
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@ edition = "2021"
crate-type = ["cdylib"]

[dependencies]
diffsl = { path = "../../../", features = ["external"] }
diffsl = { path = "../../../", features = ["external_f64"] }

[workspace]
Loading