diff --git a/crates/wasi/src/lib.rs b/crates/wasi/src/lib.rs index 5d29cce913..beb7b753ff 100644 --- a/crates/wasi/src/lib.rs +++ b/crates/wasi/src/lib.rs @@ -27,6 +27,11 @@ can be used to do name-based resolution." }, // Error to return when caller module is missing memory export: missing_memory: { wasi_common::wasi::Errno::Inval }, + // Don't use the wiggle generated code to implement proc_exit, we need to hook directly into + // the runtime there: + function_override: { + wasi_snapshot_preview1:proc_exit => wasi_proc_exit + } }); pub fn is_wasi_module(name: &str) -> bool { diff --git a/crates/wiggle/wasmtime/macro/src/config.rs b/crates/wiggle/wasmtime/macro/src/config.rs index b3d44d4492..837a7ba4d9 100644 --- a/crates/wiggle/wasmtime/macro/src/config.rs +++ b/crates/wiggle/wasmtime/macro/src/config.rs @@ -16,6 +16,7 @@ pub struct Config { pub ctx: CtxConf, pub instance: InstanceConf, pub missing_memory: MissingMemoryConf, + pub function_override: FunctionOverrideConf, } #[derive(Debug, Clone)] @@ -25,6 +26,7 @@ pub enum ConfigField { Ctx(CtxConf), Instance(InstanceConf), MissingMemory(MissingMemoryConf), + FunctionOverride(FunctionOverrideConf), } mod kw { @@ -36,6 +38,7 @@ mod kw { syn::custom_keyword!(name); syn::custom_keyword!(docs); syn::custom_keyword!(missing_memory); + syn::custom_keyword!(function_override); } impl Parse for ConfigField { @@ -65,6 +68,10 @@ impl Parse for ConfigField { input.parse::()?; input.parse::()?; Ok(ConfigField::MissingMemory(input.parse()?)) + } else if lookahead.peek(kw::function_override) { + input.parse::()?; + input.parse::()?; + Ok(ConfigField::FunctionOverride(input.parse()?)) } else { Err(lookahead.error()) } @@ -78,6 +85,7 @@ impl Config { let mut ctx = None; let mut instance = None; let mut missing_memory = None; + let mut function_override = None; for f in fields { match f { ConfigField::Target(c) => { @@ -110,6 +118,12 @@ impl Config { } missing_memory = Some(c); } + ConfigField::FunctionOverride(c) => { + if function_override.is_some() { + return Err(Error::new(err_loc, "duplicate `function_override` field")); + } + function_override = Some(c); + } } } Ok(Config { @@ -128,6 +142,7 @@ impl Config { missing_memory: missing_memory .take() .ok_or_else(|| Error::new(err_loc, "`missing_memory` field required"))?, + function_override: function_override.take().unwrap_or_default(), }) } @@ -245,3 +260,49 @@ impl Parse for MissingMemoryConf { }) } } + +#[derive(Debug, Clone, Default)] +pub struct FunctionOverrideConf { + pub funcs: Vec, +} +impl FunctionOverrideConf { + pub fn find(&self, module: &str, field: &str) -> Option<&Ident> { + self.funcs + .iter() + .find(|f| f.module == module && f.field == field) + .map(|f| &f.replacement) + } +} + +impl Parse for FunctionOverrideConf { + fn parse(input: ParseStream) -> Result { + let contents; + let _lbrace = braced!(contents in input); + let fields: Punctuated = + contents.parse_terminated(FunctionOverrideField::parse)?; + Ok(FunctionOverrideConf { + funcs: fields.into_iter().collect(), + }) + } +} + +#[derive(Debug, Clone)] +pub struct FunctionOverrideField { + pub module: String, + pub field: String, + pub replacement: Ident, +} +impl Parse for FunctionOverrideField { + fn parse(input: ParseStream) -> Result { + let module = input.parse::()?.to_string(); + input.parse::()?; + let field = input.parse::()?.to_string(); + input.parse::]>()?; + let replacement = input.parse::()?; + Ok(FunctionOverrideField { + module, + field, + replacement, + }) + } +} diff --git a/crates/wiggle/wasmtime/macro/src/lib.rs b/crates/wiggle/wasmtime/macro/src/lib.rs index 543d8bbe18..542e166497 100644 --- a/crates/wiggle/wasmtime/macro/src/lib.rs +++ b/crates/wiggle/wasmtime/macro/src/lib.rs @@ -6,7 +6,7 @@ use wiggle_generate::Names; mod config; -use config::{InstanceConf, MissingMemoryConf, TargetConf}; +use config::{FunctionOverrideConf, InstanceConf, MissingMemoryConf, TargetConf}; #[proc_macro] pub fn define_wasmtime_integration(args: TokenStream) -> TokenStream { @@ -23,6 +23,7 @@ pub fn define_wasmtime_integration(args: TokenStream) -> TokenStream { &config.target, &config.instance, &config.missing_memory, + &config.function_override, ) .into() } @@ -40,6 +41,7 @@ fn generate( target_conf: &TargetConf, instance_conf: &InstanceConf, missing_mem_conf: &MissingMemoryConf, + func_override_conf: &FunctionOverrideConf, ) -> TokenStream2 { let mut fields = Vec::new(); let mut get_exports = Vec::new(); @@ -63,12 +65,10 @@ fn generate( linker_add.push(quote! { linker.define(#module_name, #name, self.#name_ident.clone())?; }); - // `proc_exit` is special; it's essentially an unwinding primitive, - // so we implement it in the runtime rather than use the implementation - // in wasi-common. - if name == "proc_exit" { + + if let Some(func_override) = func_override_conf.find(module_name, name) { ctor_externs.push(quote! { - let #name_ident = wasmtime::Func::wrap(store, crate::wasi_proc_exit); + let #name_ident = wasmtime::Func::wrap(store, #func_override); }); continue; }