diff --git a/crates/wiggle/generate/src/codegen_settings.rs b/crates/wiggle/generate/src/codegen_settings.rs index 11391ed71f..b6fdf62f22 100644 --- a/crates/wiggle/generate/src/codegen_settings.rs +++ b/crates/wiggle/generate/src/codegen_settings.rs @@ -12,10 +12,13 @@ pub struct CodegenSettings { pub errors: ErrorTransform, pub async_: AsyncConf, pub wasmtime: bool, - // Disabling this feature makes it possible to remove all of the tracing - // code emitted in the Wiggle-generated code; this can be helpful while - // inspecting the code (e.g., with `cargo expand`). + /// Disabling this feature makes it possible to remove all of the tracing + /// code emitted in the Wiggle-generated code; this can be helpful while + /// inspecting the code (e.g., with `cargo expand`). pub tracing: TracingConf, + /// Determine whether the context structure will use `&mut self` (true) or + /// simply `&self`. + pub mutable: bool, } impl CodegenSettings { pub fn new( @@ -24,6 +27,7 @@ impl CodegenSettings { doc: &Document, wasmtime: bool, tracing: &TracingConf, + mutable: bool, ) -> Result { let errors = ErrorTransform::new(error_conf, doc)?; Ok(Self { @@ -31,6 +35,7 @@ impl CodegenSettings { async_: async_.clone(), wasmtime, tracing: tracing.clone(), + mutable, }) } pub fn get_async(&self, module: &Module, func: &InterfaceFunc) -> Asyncness { diff --git a/crates/wiggle/generate/src/config.rs b/crates/wiggle/generate/src/config.rs index 2e42a93ef5..3ee6045f99 100644 --- a/crates/wiggle/generate/src/config.rs +++ b/crates/wiggle/generate/src/config.rs @@ -16,6 +16,7 @@ pub struct Config { pub async_: AsyncConf, pub wasmtime: bool, pub tracing: TracingConf, + pub mutable: bool, } mod kw { @@ -25,6 +26,7 @@ mod kw { syn::custom_keyword!(errors); syn::custom_keyword!(target); syn::custom_keyword!(wasmtime); + syn::custom_keyword!(mutable); syn::custom_keyword!(tracing); syn::custom_keyword!(disable_for); syn::custom_keyword!(trappable); @@ -37,6 +39,7 @@ pub enum ConfigField { Async(AsyncConf), Wasmtime(bool), Tracing(TracingConf), + Mutable(bool), } impl Parse for ConfigField { @@ -76,6 +79,10 @@ impl Parse for ConfigField { input.parse::()?; input.parse::()?; Ok(ConfigField::Tracing(input.parse()?)) + } else if lookahead.peek(kw::mutable) { + input.parse::()?; + input.parse::()?; + Ok(ConfigField::Mutable(input.parse::()?.value)) } else { Err(lookahead.error()) } @@ -89,6 +96,7 @@ impl Config { let mut async_ = None; let mut wasmtime = None; let mut tracing = None; + let mut mutable = None; for f in fields { match f { ConfigField::Witx(c) => { @@ -121,6 +129,12 @@ impl Config { } tracing = Some(c); } + ConfigField::Mutable(c) => { + if mutable.is_some() { + return Err(Error::new(err_loc, "duplicate `mutable` field")); + } + mutable = Some(c); + } } } Ok(Config { @@ -131,6 +145,7 @@ impl Config { async_: async_.take().unwrap_or_default(), wasmtime: wasmtime.unwrap_or(true), tracing: tracing.unwrap_or_default(), + mutable: mutable.unwrap_or(true), }) } @@ -601,6 +616,12 @@ impl Parse for WasmtimeConfigField { blocking: true, functions: input.parse()?, }))) + } else if lookahead.peek(kw::mutable) { + input.parse::()?; + input.parse::()?; + Ok(WasmtimeConfigField::Core(ConfigField::Mutable( + input.parse::()?.value, + ))) } else { Err(lookahead.error()) } diff --git a/crates/wiggle/generate/src/funcs.rs b/crates/wiggle/generate/src/funcs.rs index 63177c3c73..0bf41f7ca7 100644 --- a/crates/wiggle/generate/src/funcs.rs +++ b/crates/wiggle/generate/src/funcs.rs @@ -77,6 +77,11 @@ fn _define_func( function = #func_name ); ); + let ctx_type = if settings.mutable { + quote!(&'a mut) + } else { + quote!(&'a) + }; if settings.get_async(&module, &func).is_sync() { let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) { quote!( @@ -91,8 +96,8 @@ fn _define_func( ( quote!( #[allow(unreachable_code)] // deals with warnings in noreturn functions - pub fn #ident( - ctx: &mut (impl #(#bounds)+*), + pub fn #ident<'a>( + ctx: #ctx_type (impl #(#bounds)+*), memory: &dyn wiggle::GuestMemory, #(#abi_params),* ) -> wiggle::anyhow::Result<#abi_ret> { @@ -122,7 +127,7 @@ fn _define_func( quote!( #[allow(unreachable_code)] // deals with warnings in noreturn functions pub fn #ident<'a>( - ctx: &'a mut (impl #(#bounds)+*), + ctx: #ctx_type (impl #(#bounds)+*), memory: &'a dyn wiggle::GuestMemory, #(#abi_params),* ) -> impl std::future::Future> + 'a { diff --git a/crates/wiggle/generate/src/module_trait.rs b/crates/wiggle/generate/src/module_trait.rs index c1dbfac21f..7a7b52fb0c 100644 --- a/crates/wiggle/generate/src/module_trait.rs +++ b/crates/wiggle/generate/src/module_trait.rs @@ -81,10 +81,15 @@ pub fn define_module_trait(m: &Module, settings: &CodegenSettings) -> TokenStrea quote!(async) }; - if is_anonymous { - quote!(#asyncness fn #funcname(&mut self, #(#args),*) -> #result; ) + let self_ = if settings.mutable { + quote!(&mut self) } else { - quote!(#asyncness fn #funcname<#lifetime>(&mut self, #(#args),*) -> #result;) + quote!(&self) + }; + if is_anonymous { + quote!(#asyncness fn #funcname(#self_, #(#args),*) -> #result; ) + } else { + quote!(#asyncness fn #funcname<#lifetime>(#self_, #(#args),*) -> #result;) } }); diff --git a/crates/wiggle/generate/src/wasmtime.rs b/crates/wiggle/generate/src/wasmtime.rs index fcab9da62e..68872a3d81 100644 --- a/crates/wiggle/generate/src/wasmtime.rs +++ b/crates/wiggle/generate/src/wasmtime.rs @@ -46,11 +46,16 @@ pub fn link_module( format_ident!("add_{}_to_linker", module_ident) }; + let u = if settings.mutable { + quote!(&mut U) + } else { + quote!(&U) + }; quote! { /// Adds all instance items to the specified `Linker`. pub fn #func_name( linker: &mut wiggle::wasmtime_crate::Linker, - get_cx: impl Fn(&mut T) -> &mut U + Send + Sync + Copy + 'static, + get_cx: impl Fn(&mut T) -> #u + Send + Sync + Copy + 'static, ) -> wiggle::anyhow::Result<()> where U: #ctx_bound #send_bound diff --git a/crates/wiggle/macro/src/lib.rs b/crates/wiggle/macro/src/lib.rs index 9ea2cf1048..4a0c469224 100644 --- a/crates/wiggle/macro/src/lib.rs +++ b/crates/wiggle/macro/src/lib.rs @@ -157,6 +157,7 @@ pub fn from_witx(args: TokenStream) -> TokenStream { &doc, config.wasmtime, &config.tracing, + config.mutable, ) .expect("validating codegen settings"); @@ -198,6 +199,7 @@ pub fn wasmtime_integration(args: TokenStream) -> TokenStream { &doc, true, &config.c.tracing, + config.c.mutable, ) .expect("validating codegen settings");