Skip to content

Commit b51dda1

Browse files
authored
Make Param be Copy and mark exceptions raised from Ruby (#158)
1 parent 9b7c036 commit b51dda1

14 files changed

Lines changed: 257 additions & 96 deletions

File tree

.github/workflows/memcheck.yml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,18 @@ on:
77
ruby-version:
88
description: "Ruby version to memcheck"
99
required: true
10-
default: "3.1"
10+
default: "3.2"
1111
type: choice
1212
options:
1313
- "head"
1414
- "3.2"
1515
- "3.1"
1616
- "3.0"
17+
debug:
18+
description: "Enable debug mode"
19+
required: false
20+
default: "false"
21+
type: boolean
1722
push:
1823
branches: ["*"]
1924
tags-ignore: ["v*"] # Skip Memcheck for releases
@@ -49,6 +54,8 @@ jobs:
4954
RSPEC_FORMATTER: "progress"
5055
RSPEC_FAILURE_EXIT_CODE: "0"
5156
GC_AT_EXIT: "1"
57+
DEBUG: ${{ inputs.debug || 'false' }}
58+
RB_SYS_CARGO_PROFILE: ${{ inputs.debug == 'true' && 'dev' || 'release' }}
5259
WASMTIME_TARGET: "x86_64-unknown-linux-gnu" # use generic target for memcheck
5360
run: |
5461
if ! bundle exec rake mem:check; then

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

bench/host_call.rb

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,24 @@
22

33
Bench.ips do |x|
44
engine = Wasmtime::Engine.new
5-
mod = Wasmtime::Module.new(engine, <<~WAT)
6-
(module
7-
(import "host" "succ" (func (param i32) (result i32)))
8-
(export "run" (func 0)))
9-
WAT
10-
linker = Wasmtime::Linker.new(engine)
11-
linker.func_new("host", "succ", [:i32], [:i32]) do |_caller, arg1|
12-
arg1.succ
13-
end
5+
[4, 16, 64, 128, 256].each do |n|
6+
result_type_wat = Array.new(n) { |_| :i32 }.join(" ")
7+
mod = Wasmtime::Module.new(engine, <<~WAT)
8+
(module
9+
(import "host" "succ" (func (param i32) (result #{result_type_wat})))
10+
(export "run" (func 0)))
11+
WAT
12+
linker = Wasmtime::Linker.new(engine)
13+
results = Array.new(n) { |_| :i32 }
14+
result_array = Array.new(n) { |i| i }
15+
linker.func_new("host", "succ", [:i32], results) do |_caller, arg1|
16+
result_array
17+
end
1418

15-
x.report("Call host func") do
16-
store = Wasmtime::Store.new(engine)
17-
linker.instantiate(store, mod).invoke("run", 101)
19+
x.report("Call host func (#{n} args)") do
20+
store = Wasmtime::Store.new(engine)
21+
linker.instantiate(store, mod).invoke("run", 101)
22+
end
1823
end
1924

2025
x.compare!

ext/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ anyhow = "*" # Use whatever Wasmtime uses
2727
wat = "1.0.59"
2828
tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "time", "net"], optional = true }
2929
async-timer = { version = "1.0.0-beta.8", features = ["tokio1"], optional = true }
30+
static_assertions = "1.1.0"
3031

3132
[build-dependencies]
3233
rb-sys-env = "0.1.2"

ext/src/ruby_api/convert.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ impl ToRubyValue for Val {
5959
}
6060
}
6161
pub trait ToWasmVal {
62-
fn to_wasm_val(&self, ty: &ValType) -> Result<Val, Error>;
62+
fn to_wasm_val(&self, ty: ValType) -> Result<Val, Error>;
6363
}
6464

6565
impl ToWasmVal for Value {
66-
fn to_wasm_val(&self, ty: &ValType) -> Result<Val, Error> {
66+
fn to_wasm_val(&self, ty: ValType) -> Result<Val, Error> {
6767
match ty {
6868
ValType::I32 => Ok(i32::try_convert(*self)?.into()),
6969
ValType::I64 => Ok(i64::try_convert(*self)?.into()),

ext/src/ruby_api/func.rs

Lines changed: 69 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,21 @@ impl From<&Func<'_>> for wasmtime::Extern {
208208
}
209209
}
210210

211+
macro_rules! caller_error {
212+
($store:expr, $caller:expr, $error:expr) => {{
213+
$store.set_last_error($error);
214+
$caller.expire();
215+
Err(anyhow::anyhow!(""))
216+
}};
217+
}
218+
219+
macro_rules! result_error {
220+
($store:expr, $caller:expr, $msg:expr) => {{
221+
let error = Error::new(result_error(), $msg);
222+
caller_error!($store, $caller, error)
223+
}};
224+
}
225+
211226
pub fn make_func_closure(
212227
ty: &wasmtime::FuncType,
213228
callable: Proc,
@@ -216,6 +231,11 @@ pub fn make_func_closure(
216231
let ty = ty.to_owned();
217232
let callable = ShareableProc(callable);
218233

234+
// The error handling here is a bit tricky. We want to return a Ruby exception,
235+
// but doing so directly can easily cause an early Ruby GC and segfault. So to
236+
// be safe, we store all Ruby errors on the store context so it can be marked.
237+
// We then return a generic error here. The caller will check for a stored error
238+
// and raise it if it exists.
219239
move |caller_impl: CallerImpl<'_, StoreData>, params: &[Val], results: &mut [Val]| {
220240
let wrapped_caller = Obj::wrap(Caller::new(caller_impl));
221241
let store_context = StoreContextValue::from(wrapped_caller);
@@ -232,49 +252,59 @@ pub fn make_func_closure(
232252

233253
let callable = callable.0;
234254

235-
let result = callable
236-
.call(unsafe { rparams.as_slice() })
237-
.and_then(|proc_result| {
238-
match results.len() {
239-
0 => Ok(()), // Ignore return value
240-
n => {
241-
// For len=1, accept both `val` and `[val]`
242-
let proc_result = RArray::to_ary(proc_result)?;
243-
if proc_result.len() != n {
244-
return Err(Error::new(
245-
result_error(),
246-
format!(
247-
"wrong number of results (given {}, expected {}) in {}",
248-
proc_result.len(),
249-
n,
250-
callable,
251-
),
252-
));
253-
}
254-
for (i, ((rb_val, wasm_val), ty)) in unsafe { proc_result.as_slice() }
255-
.iter()
256-
.zip(results.iter_mut())
257-
.zip(ty.results())
258-
.enumerate()
259-
{
260-
*wasm_val = rb_val.to_wasm_val(&ty).map_err(|e| {
261-
Error::new(
262-
result_error(),
263-
format!("{e} (result index {i} in {callable})"),
264-
)
265-
})?;
255+
match (callable.call(unsafe { rparams.as_slice() }), results.len()) {
256+
(Ok(_proc_result), 0) => {
257+
wrapped_caller.get().expire();
258+
Ok(())
259+
}
260+
(Ok(proc_result), n) => {
261+
// For len=1, accept both `val` and `[val]`
262+
let Ok(proc_result) = RArray::to_ary(proc_result) else {
263+
return result_error!(
264+
store_context,
265+
wrapped_caller.get(),
266+
format!("could not convert {} to results array", callable)
267+
);
268+
};
269+
270+
if proc_result.len() != results.len() {
271+
return result_error!(
272+
store_context,
273+
wrapped_caller.get(),
274+
format!(
275+
"wrong number of results (given {}, expected {}) in {}",
276+
proc_result.len(),
277+
n,
278+
callable
279+
)
280+
);
281+
}
282+
283+
for (i, ((rb_val, wasm_val), ty)) in unsafe { proc_result.as_slice() }
284+
.iter()
285+
.zip(results.iter_mut())
286+
.zip(ty.results())
287+
.enumerate()
288+
{
289+
match rb_val.to_wasm_val(ty) {
290+
Ok(val) => *wasm_val = val,
291+
Err(e) => {
292+
return result_error!(
293+
store_context,
294+
wrapped_caller.get(),
295+
format!("invalid result at index {i}: {e} in {callable}")
296+
);
266297
}
267-
Ok(())
268298
}
269299
}
270-
})
271-
.map_err(|e| anyhow::anyhow!(e));
272-
273-
// Drop the wasmtime::Caller so it does not outlive the Func call, if e.g. the user
274-
// assigned the Ruby Wasmtime::Caller instance to a global.
275-
wrapped_caller.get().expire();
276300

277-
result
301+
wrapped_caller.get().expire();
302+
Ok(())
303+
}
304+
(Err(e), _) => {
305+
caller_error!(store_context, wrapped_caller.get(), e)
306+
}
307+
}
278308
}
279309
}
280310

ext/src/ruby_api/global.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ impl<'a> Global<'a> {
5555
mutability: Mutability,
5656
) -> Result<Self, Error> {
5757
let wasm_type = value_type.to_val_type()?;
58-
let wasm_default = default.to_wasm_val(&wasm_type)?;
58+
let wasm_default = default.to_wasm_val(wasm_type.clone())?;
5959
let store = s.get();
6060
let inner = GlobalImpl::new(
6161
store.context_mut(),
@@ -116,7 +116,7 @@ impl<'a> Global<'a> {
116116
self.inner
117117
.set(
118118
self.store.context_mut()?,
119-
value.to_wasm_val(&self.value_type()?)?,
119+
value.to_wasm_val(self.value_type()?)?,
120120
)
121121
.map_err(|e| error!("{}", e))
122122
.and_then(|result| {

ext/src/ruby_api/params.rs

Lines changed: 76 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,41 @@
11
use super::convert::ToWasmVal;
22
use magnus::{exception::arg_error, Error, ExceptionClass, Value};
3+
use static_assertions::assert_eq_size;
34
use wasmtime::{FuncType, ValType};
45

5-
#[derive(Debug)]
6-
struct Param<'a> {
7-
index: usize,
8-
ty: ValType,
9-
val: &'a Value,
6+
#[derive(Debug, Copy, Clone)]
7+
#[repr(C)]
8+
struct Param {
9+
val: Value,
10+
index: u32,
11+
ty: ValTypeCopy,
1012
}
13+
// Keep `Param` small so copying it to the stack is cheap, typically anything
14+
// less than 3usize is good
15+
assert_eq_size!(Param, [u64; 2]);
1116

12-
impl<'a> Param<'a> {
13-
pub fn new(index: usize, ty: ValType, val: &'a Value) -> Self {
14-
Self { index, ty, val }
17+
impl Param {
18+
pub fn new(index: u32, ty: ValType, val: Value) -> Self {
19+
Self {
20+
index,
21+
ty: ty.into(),
22+
val,
23+
}
1524
}
1625

17-
fn to_wasmtime_val(&self) -> Result<wasmtime::Val, Error> {
18-
self.val.to_wasm_val(&self.ty).map_err(|error| match error {
19-
Error::Error(class, msg) => {
20-
Error::new(class, format!("{} (param index {}) ", msg, self.index))
21-
}
22-
Error::Exception(exception) => Error::new(
23-
ExceptionClass::from_value(exception.class().into()).unwrap_or_else(arg_error),
24-
format!("{} (param index {}) ", exception, self.index),
25-
),
26-
_ => error,
27-
})
26+
fn to_wasmtime_val(self) -> Result<wasmtime::Val, Error> {
27+
self.val
28+
.to_wasm_val(self.ty.into())
29+
.map_err(|error| match error {
30+
Error::Error(class, msg) => {
31+
Error::new(class, format!("{} (param at index {})", msg, self.index))
32+
}
33+
Error::Exception(exception) => Error::new(
34+
ExceptionClass::from_value(exception.class().into()).unwrap_or_else(arg_error),
35+
format!("{} (param at index {})", exception, self.index),
36+
),
37+
_ => error,
38+
})
2839
}
2940
}
3041

@@ -48,10 +59,55 @@ impl<'a> Params<'a> {
4859
pub fn to_vec(&self) -> Result<Vec<wasmtime::Val>, Error> {
4960
let mut vals = Vec::with_capacity(self.0.params().len());
5061
for (i, (param, value)) in self.0.params().zip(self.1.iter()).enumerate() {
51-
let param = Param::new(i, param.clone(), value);
62+
let i: u32 = i
63+
.try_into()
64+
.map_err(|_| Error::new(arg_error(), "too many params"))?;
65+
let param = Param::new(i, param.clone(), *value);
5266
vals.push(param.to_wasmtime_val()?);
5367
}
5468

5569
Ok(vals)
5670
}
5771
}
72+
73+
/// A [`wasmtime::ValType`] that is [`Copy`], so it can be stays on the stack
74+
///
75+
/// Note: this can be removed in Wasmtime 8.0 (see https://github.com/bytecodealliance/wasmtime/pull/6138)
76+
#[derive(Debug, Clone, Copy)]
77+
pub enum ValTypeCopy {
78+
I32,
79+
I64,
80+
F32,
81+
F64,
82+
V128,
83+
FuncRef,
84+
ExternRef,
85+
}
86+
87+
impl From<ValType> for ValTypeCopy {
88+
fn from(ty: ValType) -> Self {
89+
match ty {
90+
ValType::I32 => Self::I32,
91+
ValType::I64 => Self::I64,
92+
ValType::F32 => Self::F32,
93+
ValType::F64 => Self::F64,
94+
ValType::V128 => Self::V128,
95+
ValType::FuncRef => Self::FuncRef,
96+
ValType::ExternRef => Self::ExternRef,
97+
}
98+
}
99+
}
100+
101+
impl From<ValTypeCopy> for ValType {
102+
fn from(ty: ValTypeCopy) -> Self {
103+
match ty {
104+
ValTypeCopy::I32 => Self::I32,
105+
ValTypeCopy::I64 => Self::I64,
106+
ValTypeCopy::F32 => Self::F32,
107+
ValTypeCopy::F64 => Self::F64,
108+
ValTypeCopy::V128 => Self::V128,
109+
ValTypeCopy::FuncRef => Self::FuncRef,
110+
ValTypeCopy::ExternRef => Self::ExternRef,
111+
}
112+
}
113+
}

0 commit comments

Comments
 (0)