diff --git a/crates/api/src/trampoline/func.rs b/crates/api/src/trampoline/func.rs index d6435c8d79..446e6f8591 100644 --- a/crates/api/src/trampoline/func.rs +++ b/crates/api/src/trampoline/func.rs @@ -129,8 +129,15 @@ unsafe extern "C" fn stub_fn( .downcast_ref::() .expect("state"); state.func.call(&args, &mut returns)?; + + let module = instance.module_ref(); + let signature = &module.signatures[module.functions[FuncIndex::new(call_id as usize)]]; for (i, ret) in returns.iter_mut().enumerate() { - // TODO check signature.returns[i].value_type ? + if ret.ty().get_wasmtime_type() != Some(signature.returns[i].value_type) { + return Err(Trap::new( + "`Callable` attempted to return an incompatible value", + )); + } ret.write_value_to(values_vec.add(i)); } Ok(()) diff --git a/crates/api/tests/import_calling_export.rs b/crates/api/tests/import_calling_export.rs index b13f868dbc..c755e9bb58 100644 --- a/crates/api/tests/import_calling_export.rs +++ b/crates/api/tests/import_calling_export.rs @@ -64,3 +64,53 @@ fn test_import_calling_export() { run_func.call(&[]).expect("expected function not to trap"); } + +#[test] +fn test_returns_incorrect_type() { + const WAT: &str = r#" + (module + (import "env" "evil" (func $evil (result i32))) + (func (export "run") (result i32) + (call $evil) + ) + ) + "#; + + struct EvilCallback; + + impl Callable for EvilCallback { + fn call(&self, _params: &[Val], results: &mut [Val]) -> Result<(), Trap> { + // Evil! Returns I64 here instead of promised in the signature I32. + results[0] = Val::I64(228); + Ok(()) + } + } + + let store = Store::default(); + let module = Module::new(&store, WAT).expect("failed to create module"); + + let callback = Rc::new(EvilCallback); + + let callback_func = Func::new( + &store, + FuncType::new(Box::new([]), Box::new([ValType::I32])), + callback.clone(), + ); + + let imports = vec![callback_func.into()]; + let instance = + Instance::new(&module, imports.as_slice()).expect("failed to instantiate module"); + + let exports = instance.exports(); + assert!(!exports.is_empty()); + + let run_func = exports[0] + .func() + .expect("expected a run func in the module"); + + let trap = run_func.call(&[]).expect_err("the execution should fail"); + assert_eq!( + trap.message(), + "`Callable` attempted to return an incompatible value" + ); +}