wasmtime_wiggle: support for async, and add an integration test
This commit is contained in:
@@ -17,10 +17,22 @@ witx = { version = "0.9", path = "../../wasi-common/WASI/tools/witx", optional =
|
||||
wiggle = { path = "..", version = "0.23.0" }
|
||||
wiggle-borrow = { path = "../borrow", version = "0.23.0" }
|
||||
|
||||
[dev-dependencies]
|
||||
anyhow = "1"
|
||||
proptest = "0.10"
|
||||
|
||||
[[test]]
|
||||
name = "atoms_async"
|
||||
path = "tests/atoms_async.rs"
|
||||
required-features = ["async", "wasmtime/wat"]
|
||||
|
||||
[badges]
|
||||
maintenance = { status = "actively-developed" }
|
||||
|
||||
[features]
|
||||
# Async support for wasmtime
|
||||
async = [ 'wasmtime/async', 'wasmtime-wiggle-macro/async' ]
|
||||
|
||||
# The wiggle proc-macro emits some code (inside `pub mod metadata`) guarded
|
||||
# by the `wiggle_metadata` feature flag. We use this feature flag so that
|
||||
# users of wiggle are not forced to take a direct dependency on the `witx`
|
||||
@@ -33,4 +45,4 @@ wiggle_metadata = ['witx', "wiggle/wiggle_metadata"]
|
||||
# the logs out of wiggle-generated libraries.
|
||||
tracing_log = [ "wiggle/tracing_log" ]
|
||||
|
||||
default = ["wiggle_metadata" ]
|
||||
default = ["wiggle_metadata", "async"]
|
||||
|
||||
@@ -24,3 +24,6 @@ proc-macro2 = "1.0"
|
||||
[badges]
|
||||
maintenance = { status = "actively-developed" }
|
||||
|
||||
[features]
|
||||
async = []
|
||||
default = []
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub use wiggle_generate::config::AsyncConf;
|
||||
use {
|
||||
proc_macro2::Span,
|
||||
std::collections::HashMap,
|
||||
@@ -7,15 +8,16 @@ use {
|
||||
punctuated::Punctuated,
|
||||
Error, Ident, Path, Result, Token,
|
||||
},
|
||||
wiggle_generate::config::{CtxConf, WitxConf},
|
||||
wiggle_generate::config::WitxConf,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub target: TargetConf,
|
||||
pub witx: WitxConf,
|
||||
pub ctx: CtxConf,
|
||||
pub modules: ModulesConf,
|
||||
#[cfg(feature = "async")]
|
||||
pub async_: AsyncConf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -24,6 +26,8 @@ pub enum ConfigField {
|
||||
Witx(WitxConf),
|
||||
Ctx(CtxConf),
|
||||
Modules(ModulesConf),
|
||||
#[cfg(feature = "async")]
|
||||
Async(AsyncConf),
|
||||
}
|
||||
|
||||
mod kw {
|
||||
@@ -35,6 +39,7 @@ mod kw {
|
||||
syn::custom_keyword!(name);
|
||||
syn::custom_keyword!(docs);
|
||||
syn::custom_keyword!(function_override);
|
||||
syn::custom_keyword!(async_);
|
||||
}
|
||||
|
||||
impl Parse for ConfigField {
|
||||
@@ -60,6 +65,20 @@ impl Parse for ConfigField {
|
||||
input.parse::<kw::modules>()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
Ok(ConfigField::Modules(input.parse()?))
|
||||
} else if lookahead.peek(kw::async_) {
|
||||
input.parse::<kw::async_>()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
#[cfg(feature = "async")]
|
||||
{
|
||||
Ok(ConfigField::Async(input.parse()?))
|
||||
}
|
||||
#[cfg(not(feature = "async"))]
|
||||
{
|
||||
Err(syn::Error::new(
|
||||
input.span(),
|
||||
"async_ not supported, enable cargo feature \"async\"",
|
||||
))
|
||||
}
|
||||
} else {
|
||||
Err(lookahead.error())
|
||||
}
|
||||
@@ -72,6 +91,8 @@ impl Config {
|
||||
let mut witx = None;
|
||||
let mut ctx = None;
|
||||
let mut modules = None;
|
||||
#[cfg(feature = "async")]
|
||||
let mut async_ = None;
|
||||
for f in fields {
|
||||
match f {
|
||||
ConfigField::Target(c) => {
|
||||
@@ -98,6 +119,13 @@ impl Config {
|
||||
}
|
||||
modules = Some(c);
|
||||
}
|
||||
#[cfg(feature = "async")]
|
||||
ConfigField::Async(c) => {
|
||||
if async_.is_some() {
|
||||
return Err(Error::new(err_loc, "duplicate `async_` field"));
|
||||
}
|
||||
async_ = Some(c);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(Config {
|
||||
@@ -105,6 +133,8 @@ impl Config {
|
||||
witx: witx.ok_or_else(|| Error::new(err_loc, "`witx` field required"))?,
|
||||
ctx: ctx.ok_or_else(|| Error::new(err_loc, "`ctx` field required"))?,
|
||||
modules: modules.ok_or_else(|| Error::new(err_loc, "`modules` field required"))?,
|
||||
#[cfg(feature = "async")]
|
||||
async_: async_.unwrap_or_default(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -128,6 +158,19 @@ impl Parse for Config {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CtxConf {
|
||||
pub name: syn::Type,
|
||||
}
|
||||
|
||||
impl Parse for CtxConf {
|
||||
fn parse(input: ParseStream) -> Result<Self> {
|
||||
Ok(CtxConf {
|
||||
name: input.parse()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TargetConf {
|
||||
pub path: Path,
|
||||
|
||||
@@ -6,7 +6,7 @@ use wiggle_generate::Names;
|
||||
|
||||
mod config;
|
||||
|
||||
use config::{ModuleConf, TargetConf};
|
||||
use config::{AsyncConf, ModuleConf, TargetConf};
|
||||
|
||||
/// Define the structs required to integrate a Wiggle implementation with Wasmtime.
|
||||
///
|
||||
@@ -46,13 +46,25 @@ use config::{ModuleConf, TargetConf};
|
||||
pub fn wasmtime_integration(args: TokenStream) -> TokenStream {
|
||||
let config = parse_macro_input!(args as config::Config);
|
||||
let doc = config.load_document();
|
||||
let names = Names::new(&config.ctx.name, quote!(wasmtime_wiggle));
|
||||
let names = Names::new(quote!(wasmtime_wiggle));
|
||||
|
||||
#[cfg(feature = "async")]
|
||||
let async_config = config.async_.clone();
|
||||
#[cfg(not(feature = "async"))]
|
||||
let async_config = AsyncConf::default();
|
||||
|
||||
let modules = config.modules.iter().map(|(name, module_conf)| {
|
||||
let module = doc
|
||||
.module(&witx::Id::new(name))
|
||||
.unwrap_or_else(|| panic!("witx document did not contain module named '{}'", name));
|
||||
generate_module(&module, &module_conf, &names, &config.target)
|
||||
generate_module(
|
||||
&module,
|
||||
&module_conf,
|
||||
&names,
|
||||
&config.target,
|
||||
&config.ctx.name,
|
||||
&async_config,
|
||||
)
|
||||
});
|
||||
quote!( #(#modules)* ).into()
|
||||
}
|
||||
@@ -62,6 +74,8 @@ fn generate_module(
|
||||
module_conf: &ModuleConf,
|
||||
names: &Names,
|
||||
target_conf: &TargetConf,
|
||||
ctx_type: &syn::Type,
|
||||
async_conf: &AsyncConf,
|
||||
) -> TokenStream2 {
|
||||
let fields = module.funcs().map(|f| {
|
||||
let name_ident = names.func(&f.name);
|
||||
@@ -88,9 +102,14 @@ fn generate_module(
|
||||
let module_id = names.module(&module.name);
|
||||
let target_module = quote! { #target_path::#module_id };
|
||||
|
||||
let ctor_externs = module
|
||||
.funcs()
|
||||
.map(|f| generate_func(&f, names, &target_module));
|
||||
let ctor_externs = module.funcs().map(|f| {
|
||||
generate_func(
|
||||
&f,
|
||||
names,
|
||||
&target_module,
|
||||
async_conf.is_async(module.name.as_str(), f.name.as_str()),
|
||||
)
|
||||
});
|
||||
|
||||
let type_name = module_conf.name.clone();
|
||||
let type_docs = module_conf
|
||||
@@ -107,8 +126,6 @@ contained in the `cx` parameter.",
|
||||
module_conf.name.to_string()
|
||||
);
|
||||
|
||||
let ctx_type = names.ctx_type();
|
||||
|
||||
quote! {
|
||||
#type_docs
|
||||
pub struct #type_name {
|
||||
@@ -150,6 +167,7 @@ fn generate_func(
|
||||
func: &witx::InterfaceFunc,
|
||||
names: &Names,
|
||||
target_module: &TokenStream2,
|
||||
is_async: bool,
|
||||
) -> TokenStream2 {
|
||||
let name_ident = names.func(&func.name);
|
||||
|
||||
@@ -172,11 +190,9 @@ fn generate_func(
|
||||
|
||||
let runtime = names.runtime_mod();
|
||||
|
||||
quote! {
|
||||
let my_cx = cx.clone();
|
||||
let #name_ident = wasmtime::Func::wrap(
|
||||
store,
|
||||
move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
|
||||
let await_ = if is_async { quote!(.await) } else { quote!() };
|
||||
|
||||
let closure_body = quote! {
|
||||
unsafe {
|
||||
let mem = match caller.get_export("memory") {
|
||||
Some(wasmtime::Extern::Memory(m)) => m,
|
||||
@@ -186,17 +202,40 @@ fn generate_func(
|
||||
};
|
||||
let mem = #runtime::WasmtimeGuestMemory::new(mem);
|
||||
let result = #target_module::#name_ident(
|
||||
&mut my_cx.borrow_mut(),
|
||||
&mut *my_cx.borrow_mut(),
|
||||
&mem,
|
||||
#(#arg_names),*
|
||||
);
|
||||
) #await_;
|
||||
match result {
|
||||
Ok(r) => Ok(r.into()),
|
||||
Err(wasmtime_wiggle::Trap::String(err)) => Err(wasmtime::Trap::new(err)),
|
||||
Err(wasmtime_wiggle::Trap::I32Exit(err)) => Err(wasmtime::Trap::i32_exit(err)),
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
if is_async {
|
||||
let wrapper = quote::format_ident!("wrap{}_async", params.len());
|
||||
quote! {
|
||||
let #name_ident = wasmtime::Func::#wrapper(
|
||||
store,
|
||||
cx.clone(),
|
||||
move |caller: wasmtime::Caller<'_>, my_cx: &Rc<RefCell<_>> #(,#arg_decls)*|
|
||||
-> Box<dyn std::future::Future<Output = Result<#ret_ty, wasmtime::Trap>>>
|
||||
{
|
||||
Box::new(async move { #closure_body })
|
||||
}
|
||||
);
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
let my_cx = cx.clone();
|
||||
let #name_ident = wasmtime::Func::wrap(
|
||||
store,
|
||||
move |caller: wasmtime::Caller<'_> #(,#arg_decls)*| -> Result<#ret_ty, wasmtime::Trap> {
|
||||
#closure_body
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
25
crates/wiggle/wasmtime/tests/atoms.witx
Normal file
25
crates/wiggle/wasmtime/tests/atoms.witx
Normal file
@@ -0,0 +1,25 @@
|
||||
|
||||
(typename $errno
|
||||
(enum (@witx tag u32)
|
||||
;;; Success
|
||||
$ok
|
||||
;;; Invalid argument
|
||||
$invalid_arg
|
||||
;;; I really don't want to
|
||||
$dont_want_to
|
||||
;;; I am physically unable to
|
||||
$physically_unable
|
||||
;;; Well, that's a picket line alright!
|
||||
$picket_line))
|
||||
|
||||
(typename $alias_to_float f32)
|
||||
|
||||
(module $atoms
|
||||
(@interface func (export "int_float_args")
|
||||
(param $an_int u32)
|
||||
(param $an_float f32)
|
||||
(result $error (expected (error $errno))))
|
||||
(@interface func (export "double_int_return_float")
|
||||
(param $an_int u32)
|
||||
(result $error (expected $alias_to_float (error $errno))))
|
||||
)
|
||||
175
crates/wiggle/wasmtime/tests/atoms_async.rs
Normal file
175
crates/wiggle/wasmtime/tests/atoms_async.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use std::cell::RefCell;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::rc::Rc;
|
||||
use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};
|
||||
|
||||
wasmtime_wiggle::from_witx!({
|
||||
witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"],
|
||||
async_: {
|
||||
atoms::{double_int_return_float}
|
||||
}
|
||||
});
|
||||
|
||||
wasmtime_wiggle::wasmtime_integration!({
|
||||
target: crate,
|
||||
witx: ["$CARGO_MANIFEST_DIR/tests/atoms.witx"],
|
||||
ctx: Ctx,
|
||||
modules: { atoms => { name: Atoms } },
|
||||
async_: {
|
||||
atoms::double_int_return_float
|
||||
}
|
||||
});
|
||||
|
||||
pub struct Ctx;
|
||||
impl wiggle::GuestErrorType for types::Errno {
|
||||
fn success() -> Self {
|
||||
types::Errno::Ok
|
||||
}
|
||||
}
|
||||
|
||||
#[wasmtime_wiggle::async_trait(?Send)]
|
||||
impl atoms::Atoms for Ctx {
|
||||
fn int_float_args(&self, an_int: u32, an_float: f32) -> Result<(), types::Errno> {
|
||||
println!("INT FLOAT ARGS: {} {}", an_int, an_float);
|
||||
Ok(())
|
||||
}
|
||||
async fn double_int_return_float(
|
||||
&self,
|
||||
an_int: u32,
|
||||
) -> Result<types::AliasToFloat, types::Errno> {
|
||||
Ok((an_int as f32) * 2.0)
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_host_func() {
|
||||
let store = async_store();
|
||||
|
||||
let ctx = Rc::new(RefCell::new(Ctx));
|
||||
let atoms = Atoms::new(&store, ctx.clone());
|
||||
|
||||
let shim_mod = shim_module(&store);
|
||||
let mut linker = wasmtime::Linker::new(&store);
|
||||
atoms.add_to_linker(&mut linker).unwrap();
|
||||
let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap();
|
||||
|
||||
let results = run(shim_inst
|
||||
.get_func("int_float_args_shim")
|
||||
.unwrap()
|
||||
.call_async(&[0i32.into(), 123.45f32.into()]))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1, "one return value");
|
||||
assert_eq!(
|
||||
results[0].unwrap_i32(),
|
||||
types::Errno::Ok as i32,
|
||||
"int_float_args errno"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_async_host_func() {
|
||||
let store = async_store();
|
||||
|
||||
let ctx = Rc::new(RefCell::new(Ctx));
|
||||
let atoms = Atoms::new(&store, ctx.clone());
|
||||
|
||||
let shim_mod = shim_module(&store);
|
||||
let mut linker = wasmtime::Linker::new(&store);
|
||||
atoms.add_to_linker(&mut linker).unwrap();
|
||||
let shim_inst = run(linker.instantiate_async(&shim_mod)).unwrap();
|
||||
|
||||
let input: i32 = 123;
|
||||
let result_location: i32 = 0;
|
||||
|
||||
let results = run(shim_inst
|
||||
.get_func("double_int_return_float_shim")
|
||||
.unwrap()
|
||||
.call_async(&[input.into(), result_location.into()]))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(results.len(), 1, "one return value");
|
||||
assert_eq!(
|
||||
results[0].unwrap_i32(),
|
||||
types::Errno::Ok as i32,
|
||||
"double_int_return_float errno"
|
||||
);
|
||||
|
||||
// The actual result is in memory:
|
||||
let mem = shim_inst.get_memory("memory").unwrap();
|
||||
let mut result_bytes: [u8; 4] = [0, 0, 0, 0];
|
||||
mem.read(result_location as usize, &mut result_bytes)
|
||||
.unwrap();
|
||||
let result = f32::from_le_bytes(result_bytes);
|
||||
assert_eq!((input * 2) as f32, result);
|
||||
}
|
||||
|
||||
fn run<F: Future>(future: F) -> F::Output {
|
||||
let mut f = Pin::from(Box::new(future));
|
||||
let waker = dummy_waker();
|
||||
let mut cx = Context::from_waker(&waker);
|
||||
loop {
|
||||
match f.as_mut().poll(&mut cx) {
|
||||
Poll::Ready(val) => break val,
|
||||
Poll::Pending => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn dummy_waker() -> Waker {
|
||||
return unsafe { Waker::from_raw(clone(5 as *const _)) };
|
||||
|
||||
unsafe fn clone(ptr: *const ()) -> RawWaker {
|
||||
assert_eq!(ptr as usize, 5);
|
||||
const VTABLE: RawWakerVTable = RawWakerVTable::new(clone, wake, wake_by_ref, drop);
|
||||
RawWaker::new(ptr, &VTABLE)
|
||||
}
|
||||
|
||||
unsafe fn wake(ptr: *const ()) {
|
||||
assert_eq!(ptr as usize, 5);
|
||||
}
|
||||
|
||||
unsafe fn wake_by_ref(ptr: *const ()) {
|
||||
assert_eq!(ptr as usize, 5);
|
||||
}
|
||||
|
||||
unsafe fn drop(ptr: *const ()) {
|
||||
assert_eq!(ptr as usize, 5);
|
||||
}
|
||||
}
|
||||
fn async_store() -> wasmtime::Store {
|
||||
let engine = wasmtime::Engine::default();
|
||||
wasmtime::Store::new_async(&engine)
|
||||
}
|
||||
|
||||
// Wiggle expects the caller to have an exported memory. Wasmtime can only
|
||||
// provide this if the caller is a WebAssembly module, so we need to write
|
||||
// a shim module:
|
||||
fn shim_module(store: &wasmtime::Store) -> wasmtime::Module {
|
||||
wasmtime::Module::new(
|
||||
store.engine(),
|
||||
r#"
|
||||
(module
|
||||
(memory 1)
|
||||
(export "memory" (memory 0))
|
||||
(import "atoms" "int_float_args" (func $int_float_args (param i32 f32) (result i32)))
|
||||
(import "atoms" "double_int_return_float" (func $double_int_return_float (param i32 i32) (result i32)))
|
||||
|
||||
(func $int_float_args_shim (param i32 f32) (result i32)
|
||||
local.get 0
|
||||
local.get 1
|
||||
call $int_float_args
|
||||
)
|
||||
(func $double_int_return_float_shim (param i32 i32) (result i32)
|
||||
local.get 0
|
||||
local.get 1
|
||||
call $double_int_return_float
|
||||
)
|
||||
(export "int_float_args_shim" (func $int_float_args_shim))
|
||||
(export "double_int_return_float_shim" (func $double_int_return_float_shim))
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
Reference in New Issue
Block a user