Implement the post-return attribute (#4297)

This commit implements the `post-return` feature of the canonical ABI in
the component model. This attribute is an optionally-specified function
which is to be executed after the return value has been processed by the
caller to optionally clean-up the return value. This enables, for
example, returning an allocated string and the host then knows how to
clean it up to prevent memory leaks in the original module.

The API exposed in this PR changes the prior `TypedFunc::call` API in
behavior but not in its signature. Previously the `TypedFunc::call`
method would set the `may_enter` flag on the way out, but now that
operation is deferred until a new `TypedFunc::post_return` method is
called. This means that once a method on an instance is invoked then
nothing else can be done on the instance until the `post_return` method
is called. Note that the method must be called irrespective of whether
the `post-return` canonical ABI option was specified or not. Internally
wasm will be invoked if necessary.

This is a pretty wonky and unergonomic API to work with. For now I
couldn't think of a better alternative that improved on the ergonomics.
In the theory that the raw Wasmtime bindings for a component may not be
used all that heavily (instead `wit-bindgen` would largely be used) I'm
hoping that this isn't too much of an issue in the future.

cc #4185
This commit is contained in:
Alex Crichton
2022-06-23 14:36:21 -05:00
committed by GitHub
parent fa36e86f2c
commit 3339dd1f01
12 changed files with 787 additions and 112 deletions

View File

@@ -3,7 +3,7 @@ use anyhow::Result;
use std::rc::Rc;
use std::sync::Arc;
use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut, Trap, TrapCode};
use wasmtime::{AsContextMut, Store, StoreContextMut, Trap, TrapCode};
const CANON_32BIT_NAN: u32 = 0b01111111110000000000000000000000;
const CANON_64BIT_NAN: u64 = 0b0111111111111000000000000000000000000000000000000000000000000000;
@@ -32,7 +32,7 @@ fn thunks() -> Result<()> {
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
instance
.get_typed_func::<(), (), _>(&mut store, "thunk")?
.call(&mut store, ())?;
.call_and_post_return(&mut store, ())?;
let err = instance
.get_typed_func::<(), (), _>(&mut store, "thunk-trap")?
.call(&mut store, ())
@@ -193,28 +193,28 @@ fn integers() -> Result<()> {
// Passing in 100 is valid for all primitives
instance
.get_typed_func::<(u8,), (), _>(&mut store, "take-u8")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i8,), (), _>(&mut store, "take-s8")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u16,), (), _>(&mut store, "take-u16")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i16,), (), _>(&mut store, "take-s16")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i32,), (), _>(&mut store, "take-s32")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(u64,), (), _>(&mut store, "take-u64")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
instance
.get_typed_func::<(i64,), (), _>(&mut store, "take-s64")?
.call(&mut store, (100,))?;
.call_and_post_return(&mut store, (100,))?;
// This specific wasm instance traps if any value other than 100 is passed
instance
@@ -262,49 +262,49 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "ret-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "ret-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "ret-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "ret-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "ret-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "ret-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), u64, _>(&mut store, "ret-u64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
assert_eq!(
instance
.get_typed_func::<(), i64, _>(&mut store, "ret-s64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0
);
@@ -312,49 +312,49 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "retm1-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xff
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "retm1-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "retm1-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffff
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "retm1-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "retm1-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffffffff
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "retm1-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
assert_eq!(
instance
.get_typed_func::<(), u64, _>(&mut store, "retm1-u64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
0xffffffff_ffffffff
);
assert_eq!(
instance
.get_typed_func::<(), i64, _>(&mut store, "retm1-s64")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
-1
);
@@ -363,43 +363,59 @@ fn integers() -> Result<()> {
assert_eq!(
instance
.get_typed_func::<(), u8, _>(&mut store, "retbig-u8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as u8,
);
assert_eq!(
instance
.get_typed_func::<(), i8, _>(&mut store, "retbig-s8")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i8,
);
assert_eq!(
instance
.get_typed_func::<(), u16, _>(&mut store, "retbig-u16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as u16,
);
assert_eq!(
instance
.get_typed_func::<(), i16, _>(&mut store, "retbig-s16")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i16,
);
assert_eq!(
instance
.get_typed_func::<(), u32, _>(&mut store, "retbig-u32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret,
);
assert_eq!(
instance
.get_typed_func::<(), i32, _>(&mut store, "retbig-s32")?
.call(&mut store, ())?,
.call_and_post_return(&mut store, ())?,
ret as i32,
);
Ok(())
}
trait TypedFuncExt<P, R> {
fn call_and_post_return(&self, store: impl AsContextMut, params: P) -> Result<R>;
}
impl<P, R> TypedFuncExt<P, R> for TypedFunc<P, R>
where
P: ComponentParams + Lower,
R: Lift,
{
fn call_and_post_return(&self, mut store: impl AsContextMut, params: P) -> Result<R> {
let result = self.call(&mut store, params)?;
self.post_return(&mut store)?;
Ok(result)
}
}
#[test]
fn type_layers() -> Result<()> {
let component = r#"
@@ -425,19 +441,19 @@ fn type_layers() -> Result<()> {
instance
.get_typed_func::<(Box<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Box::new(2),))?;
.call_and_post_return(&mut store, (Box::new(2),))?;
instance
.get_typed_func::<(&u32,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&2,))?;
.call_and_post_return(&mut store, (&2,))?;
instance
.get_typed_func::<(Rc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Rc::new(2),))?;
.call_and_post_return(&mut store, (Rc::new(2),))?;
instance
.get_typed_func::<(Arc<u32>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (Arc::new(2),))?;
.call_and_post_return(&mut store, (Arc::new(2),))?;
instance
.get_typed_func::<(&Box<Arc<Rc<u32>>>,), (), _>(&mut store, "take-u32")?
.call(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?;
.call_and_post_return(&mut store, (&Box::new(Arc::new(Rc::new(2))),))?;
Ok(())
}
@@ -491,9 +507,13 @@ fn floats() -> Result<()> {
let u64_to_f64 = instance.get_typed_func::<(u64,), f64, _>(&mut store, "u64-to-f64")?;
assert_eq!(f32_to_u32.call(&mut store, (1.0,))?, 1.0f32.to_bits());
f32_to_u32.post_return(&mut store)?;
assert_eq!(f64_to_u64.call(&mut store, (2.0,))?, 2.0f64.to_bits());
f64_to_u64.post_return(&mut store)?;
assert_eq!(u32_to_f32.call(&mut store, (3.0f32.to_bits(),))?, 3.0);
u32_to_f32.post_return(&mut store)?;
assert_eq!(u64_to_f64.call(&mut store, (4.0f64.to_bits(),))?, 4.0);
u64_to_f64.post_return(&mut store)?;
assert_eq!(
u32_to_f32
@@ -501,21 +521,25 @@ fn floats() -> Result<()> {
.to_bits(),
CANON_32BIT_NAN
);
u32_to_f32.post_return(&mut store)?;
assert_eq!(
u64_to_f64
.call(&mut store, (CANON_64BIT_NAN | 1,))?
.to_bits(),
CANON_64BIT_NAN
);
u64_to_f64.post_return(&mut store)?;
assert_eq!(
f32_to_u32.call(&mut store, (f32::from_bits(CANON_32BIT_NAN | 1),))?,
CANON_32BIT_NAN
);
f32_to_u32.post_return(&mut store)?;
assert_eq!(
f64_to_u64.call(&mut store, (f64::from_bits(CANON_64BIT_NAN | 1),))?,
CANON_64BIT_NAN
);
f64_to_u64.post_return(&mut store)?;
Ok(())
}
@@ -546,10 +570,15 @@ fn bools() -> Result<()> {
let bool_to_u32 = instance.get_typed_func::<(bool,), u32, _>(&mut store, "bool-to-u32")?;
assert_eq!(bool_to_u32.call(&mut store, (false,))?, 0);
bool_to_u32.post_return(&mut store)?;
assert_eq!(bool_to_u32.call(&mut store, (true,))?, 1);
bool_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (0,))?, false);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (1,))?, true);
u32_to_bool.post_return(&mut store)?;
assert_eq!(u32_to_bool.call(&mut store, (2,))?, true);
u32_to_bool.post_return(&mut store)?;
Ok(())
}
@@ -581,7 +610,9 @@ fn chars() -> Result<()> {
let mut roundtrip = |x: char| -> Result<()> {
assert_eq!(char_to_u32.call(&mut store, (x,))?, x as u32);
char_to_u32.post_return(&mut store)?;
assert_eq!(u32_to_char.call(&mut store, (x as u32,))?, x);
u32_to_char.post_return(&mut store)?;
Ok(())
};
@@ -644,7 +675,7 @@ fn tuple_result() -> Result<()> {
let input = (-1, 100, 3.0, 100.0);
let output = instance
.get_typed_func::<(i8, u16, f32, f64), (i8, u16, f32, f64), _>(&mut store, "tuple")?
.call(&mut store, input)?;
.call_and_post_return(&mut store, input)?;
assert_eq!(input, output);
let invalid_func =
@@ -735,16 +766,20 @@ fn strings() -> Result<()> {
let mut roundtrip = |x: &str| -> Result<()> {
let ret = list8_to_str.call(&mut store, (x.as_bytes(),))?;
assert_eq!(ret.to_str(&store)?, x);
list8_to_str.post_return(&mut store)?;
let utf16 = x.encode_utf16().collect::<Vec<_>>();
let ret = list16_to_str.call(&mut store, (&utf16[..],))?;
assert_eq!(ret.to_str(&store)?, x);
list16_to_str.post_return(&mut store)?;
let ret = str_to_list8.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, x.as_bytes());
str_to_list8.post_return(&mut store)?;
let ret = str_to_list16.call(&mut store, (x,))?;
assert_eq!(ret.iter(&store).collect::<Result<Vec<_>>>()?, utf16,);
str_to_list16.post_return(&mut store)?;
Ok(())
};
@@ -758,22 +793,27 @@ fn strings() -> Result<()> {
let ret = list8_to_str.call(&mut store, (b"\xff",))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list8_to_str.call(&mut store, (b"hello there \xff invalid",))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("invalid utf-8"), "{}", err);
list8_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xdfff],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
let ret = list16_to_str.call(&mut store, (&[0xd800, 0xff00],))?;
let err = ret.to_str(&store).unwrap_err();
assert!(err.to_string().contains("unpaired surrogate"), "{}", err);
list16_to_str.post_return(&mut store)?;
Ok(())
}
@@ -1123,10 +1163,10 @@ fn some_traps() -> Result<()> {
instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[],))?;
.call_and_post_return(&mut store, (&[],))?;
instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4],))?;
.call_and_post_return(&mut store, (&[1, 2, 3, 4],))?;
let err = instance
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-end-oob")?
.call(&mut store, (&[1, 2, 3, 4, 5],))
@@ -1134,10 +1174,10 @@ fn some_traps() -> Result<()> {
assert_oob(&err);
instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("",))?;
.call_and_post_return(&mut store, ("",))?;
instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcd",))?;
.call_and_post_return(&mut store, ("abcd",))?;
let err = instance
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-end-oob")?
.call(&mut store, ("abcde",))
@@ -1216,12 +1256,15 @@ fn char_bool_memory() -> Result<()> {
let ret = func.call(&mut store, (0, 'a' as u32))?;
assert_eq!(ret, (false, 'a'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (1, '🍰' as u32))?;
assert_eq!(ret, (true, '🍰'));
func.post_return(&mut store)?;
let ret = func.call(&mut store, (2, 'a' as u32))?;
assert_eq!(ret, (true, 'a'));
func.post_return(&mut store)?;
assert!(func.call(&mut store, (0, 0xd800)).is_err());
@@ -1437,22 +1480,30 @@ fn option() -> Result<()> {
let option_unit_to_u32 =
instance.get_typed_func::<(Option<()>,), u32, _>(&mut store, "option-unit-to-u32")?;
assert_eq!(option_unit_to_u32.call(&mut store, (None,))?, 0);
option_unit_to_u32.post_return(&mut store)?;
assert_eq!(option_unit_to_u32.call(&mut store, (Some(()),))?, 1);
option_unit_to_u32.post_return(&mut store)?;
let option_u8_to_tuple = instance
.get_typed_func::<(Option<u8>,), (u32, u32), _>(&mut store, "option-u8-to-tuple")?;
assert_eq!(option_u8_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u8_to_tuple.post_return(&mut store)?;
assert_eq!(option_u8_to_tuple.call(&mut store, (Some(100),))?, (1, 100));
option_u8_to_tuple.post_return(&mut store)?;
let option_u32_to_tuple = instance
.get_typed_func::<(Option<u32>,), (u32, u32), _>(&mut store, "option-u32-to-tuple")?;
assert_eq!(option_u32_to_tuple.call(&mut store, (None,))?, (0, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!(option_u32_to_tuple.call(&mut store, (Some(0),))?, (1, 0));
option_u32_to_tuple.post_return(&mut store)?;
assert_eq!(
option_u32_to_tuple.call(&mut store, (Some(100),))?,
(1, 100)
);
option_u32_to_tuple.post_return(&mut store)?;
let option_string_to_tuple = instance.get_typed_func::<(Option<&str>,), (u32, WasmStr), _>(
&mut store,
@@ -1461,45 +1512,59 @@ fn option() -> Result<()> {
let (a, b) = option_string_to_tuple.call(&mut store, (None,))?;
assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some(""),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "");
option_string_to_tuple.post_return(&mut store)?;
let (a, b) = option_string_to_tuple.call(&mut store, (Some("hello"),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "hello");
option_string_to_tuple.post_return(&mut store)?;
let to_option_unit =
instance.get_typed_func::<(u32,), Option<()>, _>(&mut store, "to-option-unit")?;
assert_eq!(to_option_unit.call(&mut store, (0,))?, None);
to_option_unit.post_return(&mut store)?;
assert_eq!(to_option_unit.call(&mut store, (1,))?, Some(()));
to_option_unit.post_return(&mut store)?;
let err = to_option_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid option"), "{}", err);
let to_option_u8 =
instance.get_typed_func::<(u32, u32), Option<u8>, _>(&mut store, "to-option-u8")?;
assert_eq!(to_option_u8.call(&mut store, (0x00_00, 0))?, None);
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0x00_01, 0))?, Some(0));
to_option_u8.post_return(&mut store)?;
assert_eq!(to_option_u8.call(&mut store, (0xfd_01, 0))?, Some(0xfd));
to_option_u8.post_return(&mut store)?;
assert!(to_option_u8.call(&mut store, (0x00_02, 0)).is_err());
let to_option_u32 =
instance.get_typed_func::<(u32, u32), Option<u32>, _>(&mut store, "to-option-u32")?;
assert_eq!(to_option_u32.call(&mut store, (0, 0))?, None);
to_option_u32.post_return(&mut store)?;
assert_eq!(to_option_u32.call(&mut store, (1, 0))?, Some(0));
to_option_u32.post_return(&mut store)?;
assert_eq!(
to_option_u32.call(&mut store, (1, 0x1234fead))?,
Some(0x1234fead)
);
to_option_u32.post_return(&mut store)?;
assert!(to_option_u32.call(&mut store, (2, 0)).is_err());
let to_option_string = instance
.get_typed_func::<(u32, &str), Option<WasmStr>, _>(&mut store, "to-option-string")?;
let ret = to_option_string.call(&mut store, (0, ""))?;
assert!(ret.is_none());
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, ""))?;
assert_eq!(ret.unwrap().to_str(&store)?, "");
to_option_string.post_return(&mut store)?;
let ret = to_option_string.call(&mut store, (1, "cheesecake"))?;
assert_eq!(ret.unwrap().to_str(&store)?, "cheesecake");
to_option_string.post_return(&mut store)?;
assert!(to_option_string.call(&mut store, (2, "")).is_err());
Ok(())
@@ -1592,15 +1657,19 @@ fn expected() -> Result<()> {
let take_expected_unit =
instance.get_typed_func::<(Result<(), ()>,), u32, _>(&mut store, "take-expected-unit")?;
assert_eq!(take_expected_unit.call(&mut store, (Ok(()),))?, 0);
take_expected_unit.post_return(&mut store)?;
assert_eq!(take_expected_unit.call(&mut store, (Err(()),))?, 1);
take_expected_unit.post_return(&mut store)?;
let take_expected_u8_f32 = instance
.get_typed_func::<(Result<u8, f32>,), (u32, u32), _>(&mut store, "take-expected-u8-f32")?;
assert_eq!(take_expected_u8_f32.call(&mut store, (Ok(1),))?, (0, 1));
take_expected_u8_f32.post_return(&mut store)?;
assert_eq!(
take_expected_u8_f32.call(&mut store, (Err(2.0),))?,
(1, 2.0f32.to_bits())
);
take_expected_u8_f32.post_return(&mut store)?;
let take_expected_string = instance
.get_typed_func::<(Result<&str, &[u8]>,), (u32, WasmStr), _>(
@@ -1610,27 +1679,35 @@ fn expected() -> Result<()> {
let (a, b) = take_expected_string.call(&mut store, (Ok("hello"),))?;
assert_eq!(a, 0);
assert_eq!(b.to_str(&store)?, "hello");
take_expected_string.post_return(&mut store)?;
let (a, b) = take_expected_string.call(&mut store, (Err(b"goodbye"),))?;
assert_eq!(a, 1);
assert_eq!(b.to_str(&store)?, "goodbye");
take_expected_string.post_return(&mut store)?;
let to_expected_unit =
instance.get_typed_func::<(u32,), Result<(), ()>, _>(&mut store, "to-expected-unit")?;
assert_eq!(to_expected_unit.call(&mut store, (0,))?, Ok(()));
to_expected_unit.post_return(&mut store)?;
assert_eq!(to_expected_unit.call(&mut store, (1,))?, Err(()));
to_expected_unit.post_return(&mut store)?;
let err = to_expected_unit.call(&mut store, (2,)).unwrap_err();
assert!(err.to_string().contains("invalid expected"), "{}", err);
let to_expected_s16_f32 = instance
.get_typed_func::<(u32, u32), Result<i16, f32>, _>(&mut store, "to-expected-s16-f32")?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 0))?, Ok(0));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!(to_expected_s16_f32.call(&mut store, (0, 100))?, Ok(100));
to_expected_s16_f32.post_return(&mut store)?;
assert_eq!(
to_expected_s16_f32.call(&mut store, (1, 1.0f32.to_bits()))?,
Err(1.0)
);
to_expected_s16_f32.post_return(&mut store)?;
let ret = to_expected_s16_f32.call(&mut store, (1, CANON_32BIT_NAN | 1))?;
assert_eq!(ret.unwrap_err().to_bits(), CANON_32BIT_NAN);
to_expected_s16_f32.post_return(&mut store)?;
assert!(to_expected_s16_f32.call(&mut store, (2, 0)).is_err());
Ok(())

View File

@@ -0,0 +1,259 @@
use anyhow::Result;
use wasmtime::component::*;
use wasmtime::{Store, StoreContextMut};
#[test]
fn invalid_api() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "thunk1"))
(func (export "thunk2"))
)
(core instance $i (instantiate $m))
(func (export "thunk1")
(canon lift (core func $i "thunk1"))
)
(func (export "thunk2")
(canon lift (core func $i "thunk2"))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, ());
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let thunk1 = instance.get_typed_func::<(), (), _>(&mut store, "thunk1")?;
let thunk2 = instance.get_typed_func::<(), (), _>(&mut store, "thunk2")?;
// Ensure that we can't call `post_return` before doing anything
let msg = "post_return can only be called after a function has previously been called";
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
// Schedule a "needs post return"
thunk1.call(&mut store, ())?;
// Ensure that we can't reenter the instance through either this function or
// another one.
let err = thunk1.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
let err = thunk2.call(&mut store, ()).unwrap_err();
assert!(
err.to_string()
.contains("cannot reenter component instance"),
"{}",
err
);
// Calling post-return on the wrong function should panic
assert_panics(
|| drop(thunk2.post_return(&mut store)),
"calling post_return on wrong function",
);
// Actually execute the post-return
thunk1.post_return(&mut store)?;
// And now post-return should be invalid again.
assert_panics(|| drop(thunk1.post_return(&mut store)), msg);
assert_panics(|| drop(thunk2.post_return(&mut store)), msg);
Ok(())
}
#[track_caller]
fn assert_panics(f: impl FnOnce(), msg: &str) {
match std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)) {
Ok(()) => panic!("expected closure to panic"),
Err(e) => match e.downcast::<String>() {
Ok(s) => {
assert!(s.contains(msg), "bad panic: {}", s);
}
Err(e) => match e.downcast::<&'static str>() {
Ok(s) => assert!(s.contains(msg), "bad panic: {}", s),
Err(_) => panic!("bad panic"),
},
},
}
}
#[test]
fn invoke_post_return() -> Result<()> {
let component = r#"
(component
(import "f" (func $f))
(core func $f_lower
(canon lower (func $f))
)
(core module $m
(import "" "" (func $f))
(func (export "thunk"))
(func $post_return
call $f)
(export "post-return" (func $post_return))
)
(core instance $i (instantiate $m
(with "" (instance
(export "" (func $f_lower))
))
))
(func (export "thunk")
(canon lift
(core func $i "thunk")
(post-return (func $i "post-return"))
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let mut linker = Linker::new(&engine);
linker
.root()
.func_wrap("f", |mut store: StoreContextMut<'_, bool>| -> Result<()> {
assert!(!*store.data());
*store.data_mut() = true;
Ok(())
})?;
let instance = linker.instantiate(&mut store, &component)?;
let thunk = instance.get_typed_func::<(), (), _>(&mut store, "thunk")?;
assert!(!*store.data());
thunk.call(&mut store, ())?;
assert!(!*store.data());
thunk.post_return(&mut store)?;
assert!(*store.data());
Ok(())
}
#[test]
fn post_return_all_types() -> Result<()> {
let component = r#"
(component
(core module $m
(func (export "i32") (result i32)
i32.const 1)
(func (export "i64") (result i64)
i64.const 2)
(func (export "f32") (result f32)
f32.const 3)
(func (export "f64") (result f64)
f64.const 4)
(func (export "post-i32") (param i32)
local.get 0
i32.const 1
i32.ne
if unreachable end)
(func (export "post-i64") (param i64)
local.get 0
i64.const 2
i64.ne
if unreachable end)
(func (export "post-f32") (param f32)
local.get 0
f32.const 3
f32.ne
if unreachable end)
(func (export "post-f64") (param f64)
local.get 0
f64.const 4
f64.ne
if unreachable end)
)
(core instance $i (instantiate $m))
(func (export "i32") (result u32)
(canon lift (core func $i "i32") (post-return (func $i "post-i32")))
)
(func (export "i64") (result u64)
(canon lift (core func $i "i64") (post-return (func $i "post-i64")))
)
(func (export "f32") (result float32)
(canon lift (core func $i "f32") (post-return (func $i "post-f32")))
)
(func (export "f64") (result float64)
(canon lift (core func $i "f64") (post-return (func $i "post-f64")))
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let i32 = instance.get_typed_func::<(), u32, _>(&mut store, "i32")?;
let i64 = instance.get_typed_func::<(), u64, _>(&mut store, "i64")?;
let f32 = instance.get_typed_func::<(), f32, _>(&mut store, "f32")?;
let f64 = instance.get_typed_func::<(), f64, _>(&mut store, "f64")?;
assert_eq!(i32.call(&mut store, ())?, 1);
i32.post_return(&mut store)?;
assert_eq!(i64.call(&mut store, ())?, 2);
i64.post_return(&mut store)?;
assert_eq!(f32.call(&mut store, ())?, 3.);
f32.post_return(&mut store)?;
assert_eq!(f64.call(&mut store, ())?, 4.);
f64.post_return(&mut store)?;
Ok(())
}
#[test]
fn post_return_string() -> Result<()> {
let component = r#"
(component
(core module $m
(memory (export "memory") 1)
(func (export "get") (result i32)
(i32.store offset=0 (i32.const 8) (i32.const 100))
(i32.store offset=4 (i32.const 8) (i32.const 11))
i32.const 8
)
(func (export "post") (param i32)
local.get 0
i32.const 8
i32.ne
if unreachable end)
(data (i32.const 100) "hello world")
)
(core instance $i (instantiate $m))
(func (export "get") (result string)
(canon lift
(core func $i "get")
(post-return (func $i "post"))
(memory $i "memory")
)
)
)
"#;
let engine = super::engine();
let component = Component::new(&engine, component)?;
let mut store = Store::new(&engine, false);
let instance = Linker::new(&engine).instantiate(&mut store, &component)?;
let get = instance.get_typed_func::<(), WasmStr, _>(&mut store, "get")?;
let s = get.call(&mut store, ())?;
assert_eq!(s.to_str(&store)?, "hello world");
get.post_return(&mut store)?;
Ok(())
}