Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ext/src/helpers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ mod static_id;
mod symbol_enum;
mod tmplock;

pub use nogvl::nogvl;
pub use nogvl::{nogvl, with_gvl};
pub use output_limited_buffer::OutputLimitedBuffer;
pub use static_id::StaticId;
pub use symbol_enum::SymbolEnum;
Expand Down
65 changes: 49 additions & 16 deletions ext/src/helpers/nogvl.rs
Original file line number Diff line number Diff line change
@@ -1,29 +1,62 @@
use std::{ffi::c_void, mem::MaybeUninit, ptr::null_mut};
use std::{
ffi::{c_int, c_void},
mem::MaybeUninit,
panic::{self, AssertUnwindSafe},
ptr::null_mut,
thread,
};

use rb_sys::rb_thread_call_without_gvl;
use rb_sys::{rb_thread_call_with_gvl, rb_thread_call_without_gvl};

unsafe extern "C" fn call_without_gvl<F, R>(arg: *mut c_void) -> *mut c_void
extern "C" {
fn ruby_thread_has_gvl_p() -> c_int;
}

#[inline]
fn has_gvl() -> bool {
unsafe { ruby_thread_has_gvl_p() != 0 }
}

unsafe extern "C" fn call_trampoline<F, R>(arg: *mut c_void) -> *mut c_void
where
F: FnMut() -> R,
R: Sized,
F: FnOnce() -> R,
{
let arg = arg as *mut (&mut F, &mut MaybeUninit<R>);
let (func, result) = unsafe { &mut *arg };
result.write(func());

let data = unsafe { &mut *(arg as *mut (Option<F>, MaybeUninit<thread::Result<R>>)) };
let func = data.0.take().expect("closure called more than once");
data.1.write(panic::catch_unwind(AssertUnwindSafe(func)));
Comment on lines +24 to +26
null_mut()
}

pub fn nogvl<F, R>(mut func: F) -> R
pub fn nogvl<F, R>(func: F) -> R
where
F: FnOnce() -> R,
{
let mut data: (Option<F>, MaybeUninit<thread::Result<R>>) = (Some(func), MaybeUninit::uninit());
let arg = &mut data as *mut _ as *mut c_void;

unsafe {
rb_thread_call_without_gvl(Some(call_trampoline::<F, R>), arg, None, null_mut());
data.1
.assume_init()
.unwrap_or_else(|e| panic::resume_unwind(e))
}
}

pub fn with_gvl<F, R>(func: F) -> R
where
F: FnMut() -> R,
R: Sized,
F: FnOnce() -> R,
{
let result = MaybeUninit::uninit();
let arg_ptr = &(&mut func, &result) as *const _ as *mut c_void;
if has_gvl() {
return func();
}

let mut data: (Option<F>, MaybeUninit<thread::Result<R>>) = (Some(func), MaybeUninit::uninit());
let arg = &mut data as *mut _ as *mut c_void;

unsafe {
rb_thread_call_without_gvl(Some(call_without_gvl::<F, R>), arg_ptr, None, null_mut());
result.assume_init()
rb_thread_call_with_gvl(Some(call_trampoline::<F, R>), arg);
data.1
.assume_init()
.unwrap_or_else(|e| panic::resume_unwind(e))
}
}
27 changes: 21 additions & 6 deletions ext/src/ruby_api/externals.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,16 @@ use super::{
store::StoreContextValue,
table::{Table, TableType},
};
use crate::{conversion_err, not_implemented};
use crate::{conversion_err, define_rb_intern, not_implemented};
use magnus::{
class, gc::Marker, method, prelude::*, rb_sys::AsRawValue, typed_data::Obj, DataTypeFunctions,
Error, Module, RClass, Ruby, TypedData, Value,
class, gc::Marker, method, prelude::*, rb_sys::AsRawValue, scan_args, typed_data::Obj,
DataTypeFunctions, Error, Module, RClass, Ruby, TypedData, Value,
};

define_rb_intern!(
GVL => "gvl",
);

#[derive(TypedData)]
#[magnus(
class = "Wasmtime::ExternType",
Expand Down Expand Up @@ -128,10 +132,21 @@ impl Extern<'_> {
/// @yard
/// Returns the exported function or raises a `{ConversionError}` when the export is not a
/// function.
///
/// @def to_func(gvl: true)
/// @param gvl [Boolean] When +false+, releases the GVL during the call so other Ruby threads run in parallel (each thread must use its own {Store}). Defaults to +true+.
Comment thread
omohokcoj marked this conversation as resolved.
///
/// Failing to respect the {Store}-per-thread requirement, when using `gvl: false` is highly unsafe and will result in undefined behavior.
/// @return [Func] The exported function.
pub fn to_func(ruby: &Ruby, rb_self: Obj<Self>) -> Result<Value, Error> {
pub fn to_func(ruby: &Ruby, rb_self: Obj<Self>, args: &[Value]) -> Result<Value, Error> {
let args = scan_args::scan_args::<(), (), (), (), _, ()>(args)?;
let kw = scan_args::get_kwargs::<_, (), (Option<bool>,), ()>(args.keywords, &[], &[*GVL])?;

match *rb_self {
Extern::Func(f) => Ok(f.as_value()),
Extern::Func(f) => match kw.optional.0 {
Some(false) => Ok(ruby.obj_wrap(f.without_gvl()).as_value()),
_ => Ok(f.as_value()),
},
_ => conversion_err!(Self::inner_class(rb_self), Func::class(ruby)),
}
}
Expand Down Expand Up @@ -252,7 +267,7 @@ pub fn init(ruby: &Ruby) -> Result<(), Error> {

let class = root().define_class("Extern", ruby.class_object())?;

class.define_method("to_func", method!(Extern::to_func, 0))?;
class.define_method("to_func", method!(Extern::to_func, -1))?;
class.define_method("to_global", method!(Extern::to_global, 0))?;
class.define_method("to_memory", method!(Extern::to_memory, 0))?;
class.define_method("to_table", method!(Extern::to_table, 0))?;
Expand Down
171 changes: 97 additions & 74 deletions ext/src/ruby_api/func.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ use super::{
root,
store::{Store, StoreContextValue, StoreData},
};
use crate::{error, Caller};
use crate::{
error,
helpers::{nogvl, with_gvl},
Caller,
};
use magnus::{
block::Proc, class, function, gc::Marker, method, prelude::*, scan_args::scan_args,
typed_data::Obj, value::Opaque, DataTypeFunctions, Error, IntoValue, Object, RArray, Ruby,
Expand Down Expand Up @@ -81,6 +85,7 @@ impl From<&FuncType> for wasmtime::ExternType {
pub struct Func<'a> {
store: StoreContextValue<'a>,
inner: FuncImpl,
gvl: bool,
}

impl DataTypeFunctions for Func<'_> {
Expand Down Expand Up @@ -141,14 +146,23 @@ impl<'a> Func<'a> {
let func_closure = make_func_closure(&ty, callable.into());
let inner = wasmtime::Func::new(context, ty, func_closure);

Ok(Self {
store: store.into(),
inner,
})
Ok(Self::from_inner(store.into(), inner))
}

pub fn from_inner(store: StoreContextValue<'a>, inner: FuncImpl) -> Self {
Self { store, inner }
Self {
store,
inner,
gvl: true,
}
}

pub fn without_gvl(&self) -> Self {
Self {
store: self.store,
inner: self.inner,
gvl: false,
}
}

pub fn get(&self) -> FuncImpl {
Expand Down Expand Up @@ -176,7 +190,7 @@ impl<'a> Func<'a> {
/// func.call(1, 2) # => [2, 3]
pub fn call(&self, args: &[Value]) -> Result<Value, Error> {
let ruby = Ruby::get().unwrap();
Self::invoke(&ruby, &self.store, &self.inner, args)
Self::invoke(&ruby, &self.store, &self.inner, self.gvl, args)
}

pub fn inner(&self) -> &FuncImpl {
Expand Down Expand Up @@ -211,15 +225,20 @@ impl<'a> Func<'a> {
ruby: &Ruby,
store: &StoreContextValue,
func: &wasmtime::Func,
gvl: bool,
args: &[Value],
) -> Result<Value, Error> {
let mut context = store.context_mut()?;
let func_ty = func.ty(&mut context);
let params = Params::new(ruby, &func_ty, args)?.to_vec(ruby, store)?;
let mut results = vec![Val::null_func_ref(); func_ty.results().len()];

func.call(context, &params, &mut results)
.map_err(|e| store.handle_wasm_error(ruby, e))?;
let call_result = if gvl {
func.call(&mut context, &params, &mut results)
} else {
nogvl(|| func.call(&mut context, &params, &mut results))
};
call_result.map_err(|e| store.handle_wasm_error(ruby, e))?;

// Check for any errors stored during execution (e.g., from socket checks)
if let Some(error) = store.take_last_error()? {
Expand Down Expand Up @@ -275,79 +294,83 @@ pub fn make_func_closure(
// We then return a generic error here. The caller will check for a stored error
// and raise it if it exists.
move |caller_impl: CallerImpl<'_, StoreData>, params: &[Val], results: &mut [Val]| {
let ruby = Ruby::get().unwrap();
let wrapped_caller = ruby.obj_wrap(Caller::new(caller_impl));
let store_context = StoreContextValue::from(wrapped_caller);

let rparams = ruby.ary_new_capa(params.len() + 1);
rparams
.push(wrapped_caller.as_value())
.map_err(|e| wasmtime::Error::msg(format!("failed to push caller: {e}")))?;

for (i, param) in params.iter().enumerate() {
let val = param
.to_ruby_value(&ruby, &store_context)
.map_err(|e| wasmtime::Error::msg(format!("invalid argument at index {i}: {e}")))?;
rparams.push(val).map_err(|e| {
wasmtime::Error::msg(format!("failed to push argument at index {i}: {e}"))
})?;
}
// Borrow `ty` so the `move` closure captures a reference, not the owned value.
let ty = &ty;
with_gvl(move || {
Comment on lines +297 to +299
let ruby = Ruby::get().unwrap();
let wrapped_caller = ruby.obj_wrap(Caller::new(caller_impl));
let store_context = StoreContextValue::from(wrapped_caller);

let rparams = ruby.ary_new_capa(params.len() + 1);
rparams
.push(wrapped_caller.as_value())
.map_err(|e| wasmtime::Error::msg(format!("failed to push caller: {e}")))?;

for (i, param) in params.iter().enumerate() {
let val = param.to_ruby_value(&ruby, &store_context).map_err(|e| {
wasmtime::Error::msg(format!("invalid argument at index {i}: {e}"))
})?;
rparams.push(val).map_err(|e| {
wasmtime::Error::msg(format!("failed to push argument at index {i}: {e}"))
})?;
}

let callable = ruby.get_inner(callable);
let callable = ruby.get_inner(callable);

match (callable.call(rparams), results.len()) {
(Ok(_proc_result), 0) => {
wrapped_caller.expire();
Ok(())
}
(Ok(proc_result), n) => {
// For len=1, accept both `val` and `[val]`
let Ok(proc_result) = RArray::to_ary(proc_result) else {
return result_error!(
store_context,
wrapped_caller,
format!("could not convert {} to results array", callable)
);
};

if proc_result.len() != results.len() {
return result_error!(
store_context,
wrapped_caller,
format!(
"wrong number of results (given {}, expected {}) in {}",
proc_result.len(),
n,
callable
)
);
match (callable.call(rparams), results.len()) {
(Ok(_proc_result), 0) => {
wrapped_caller.expire();
Ok(())
}
(Ok(proc_result), n) => {
// For len=1, accept both `val` and `[val]`
let Ok(proc_result) = RArray::to_ary(proc_result) else {
return result_error!(
store_context,
wrapped_caller,
format!("could not convert {} to results array", callable)
);
};

if proc_result.len() != results.len() {
return result_error!(
store_context,
wrapped_caller,
format!(
"wrong number of results (given {}, expected {}) in {}",
proc_result.len(),
n,
callable
)
);
}

for (i, ((rb_val, wasm_val), ty)) in unsafe { proc_result.as_slice() }
.iter()
.zip(results.iter_mut())
.zip(ty.results())
.enumerate()
{
match rb_val.to_wasm_val(&store_context, ty) {
Ok(val) => *wasm_val = val,
Err(e) => {
return result_error!(
store_context,
wrapped_caller,
format!("invalid result at index {i}: {e} in {callable}")
);
for (i, ((rb_val, wasm_val), ty)) in unsafe { proc_result.as_slice() }
.iter()
.zip(results.iter_mut())
.zip(ty.results())
.enumerate()
{
match rb_val.to_wasm_val(&store_context, ty) {
Ok(val) => *wasm_val = val,
Err(e) => {
return result_error!(
store_context,
wrapped_caller,
format!("invalid result at index {i}: {e} in {callable}")
);
}
}
}
}

wrapped_caller.expire();
Ok(())
}
(Err(e), _) => {
caller_error!(store_context, wrapped_caller, e)
wrapped_caller.expire();
Ok(())
}
(Err(e), _) => {
caller_error!(store_context, wrapped_caller, e)
}
}
}
})
}
}

Expand Down
2 changes: 1 addition & 1 deletion ext/src/ruby_api/instance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ impl Instance {
})?)?;

let func = rb_self.get_func(rb_self.store.context_mut(), unsafe { name.as_str()? })?;
Func::invoke(ruby, &rb_self.store.into(), &func, &args[1..])
Func::invoke(ruby, &rb_self.store.into(), &func, true, &args[1..])
}

fn get_func(
Expand Down
Loading
Loading