wiggle: choose between &mut self and &self (#5428)

Previously, all Wiggle-generated traits were generated with `&mut self`
signatures. With the addition of the `mutable` configuration option to
`from_witx!` and `wasmtime_integration!`, one can disable this, emitting
instead traits that use `&self` (i.e., `mutable: false`). This change is
helpful for implementing wasi-threads: WASI implementations with
interior mutability will now be able to communitcate this to their
Wiggle-generated code.

The other side of this change is the `get_cx` closure passed to Wiggle's
generated `add_to_linker` function. When `mutability` is set to `true`
(default), the `get_cx` function takes a `&mut` data structure from the
store and returns a corresponding `&mut` reference, usually to a field
of the passed-in structure. When `mutability: false`, the `get_cx`
closure will still take a `&mut` data structure but now will return a
`&` reference.
This commit is contained in:
Andrew Brown
2022-12-13 14:38:47 -08:00
committed by GitHub
parent df923f18ca
commit 3ce896f69d
6 changed files with 53 additions and 10 deletions

View File

@@ -12,10 +12,13 @@ pub struct CodegenSettings {
pub errors: ErrorTransform, pub errors: ErrorTransform,
pub async_: AsyncConf, pub async_: AsyncConf,
pub wasmtime: bool, pub wasmtime: bool,
// Disabling this feature makes it possible to remove all of the tracing /// Disabling this feature makes it possible to remove all of the tracing
// code emitted in the Wiggle-generated code; this can be helpful while /// code emitted in the Wiggle-generated code; this can be helpful while
// inspecting the code (e.g., with `cargo expand`). /// inspecting the code (e.g., with `cargo expand`).
pub tracing: TracingConf, pub tracing: TracingConf,
/// Determine whether the context structure will use `&mut self` (true) or
/// simply `&self`.
pub mutable: bool,
} }
impl CodegenSettings { impl CodegenSettings {
pub fn new( pub fn new(
@@ -24,6 +27,7 @@ impl CodegenSettings {
doc: &Document, doc: &Document,
wasmtime: bool, wasmtime: bool,
tracing: &TracingConf, tracing: &TracingConf,
mutable: bool,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let errors = ErrorTransform::new(error_conf, doc)?; let errors = ErrorTransform::new(error_conf, doc)?;
Ok(Self { Ok(Self {
@@ -31,6 +35,7 @@ impl CodegenSettings {
async_: async_.clone(), async_: async_.clone(),
wasmtime, wasmtime,
tracing: tracing.clone(), tracing: tracing.clone(),
mutable,
}) })
} }
pub fn get_async(&self, module: &Module, func: &InterfaceFunc) -> Asyncness { pub fn get_async(&self, module: &Module, func: &InterfaceFunc) -> Asyncness {

View File

@@ -16,6 +16,7 @@ pub struct Config {
pub async_: AsyncConf, pub async_: AsyncConf,
pub wasmtime: bool, pub wasmtime: bool,
pub tracing: TracingConf, pub tracing: TracingConf,
pub mutable: bool,
} }
mod kw { mod kw {
@@ -25,6 +26,7 @@ mod kw {
syn::custom_keyword!(errors); syn::custom_keyword!(errors);
syn::custom_keyword!(target); syn::custom_keyword!(target);
syn::custom_keyword!(wasmtime); syn::custom_keyword!(wasmtime);
syn::custom_keyword!(mutable);
syn::custom_keyword!(tracing); syn::custom_keyword!(tracing);
syn::custom_keyword!(disable_for); syn::custom_keyword!(disable_for);
syn::custom_keyword!(trappable); syn::custom_keyword!(trappable);
@@ -37,6 +39,7 @@ pub enum ConfigField {
Async(AsyncConf), Async(AsyncConf),
Wasmtime(bool), Wasmtime(bool),
Tracing(TracingConf), Tracing(TracingConf),
Mutable(bool),
} }
impl Parse for ConfigField { impl Parse for ConfigField {
@@ -76,6 +79,10 @@ impl Parse for ConfigField {
input.parse::<kw::tracing>()?; input.parse::<kw::tracing>()?;
input.parse::<Token![:]>()?; input.parse::<Token![:]>()?;
Ok(ConfigField::Tracing(input.parse()?)) Ok(ConfigField::Tracing(input.parse()?))
} else if lookahead.peek(kw::mutable) {
input.parse::<kw::mutable>()?;
input.parse::<Token![:]>()?;
Ok(ConfigField::Mutable(input.parse::<syn::LitBool>()?.value))
} else { } else {
Err(lookahead.error()) Err(lookahead.error())
} }
@@ -89,6 +96,7 @@ impl Config {
let mut async_ = None; let mut async_ = None;
let mut wasmtime = None; let mut wasmtime = None;
let mut tracing = None; let mut tracing = None;
let mut mutable = None;
for f in fields { for f in fields {
match f { match f {
ConfigField::Witx(c) => { ConfigField::Witx(c) => {
@@ -121,6 +129,12 @@ impl Config {
} }
tracing = Some(c); tracing = Some(c);
} }
ConfigField::Mutable(c) => {
if mutable.is_some() {
return Err(Error::new(err_loc, "duplicate `mutable` field"));
}
mutable = Some(c);
}
} }
} }
Ok(Config { Ok(Config {
@@ -131,6 +145,7 @@ impl Config {
async_: async_.take().unwrap_or_default(), async_: async_.take().unwrap_or_default(),
wasmtime: wasmtime.unwrap_or(true), wasmtime: wasmtime.unwrap_or(true),
tracing: tracing.unwrap_or_default(), tracing: tracing.unwrap_or_default(),
mutable: mutable.unwrap_or(true),
}) })
} }
@@ -601,6 +616,12 @@ impl Parse for WasmtimeConfigField {
blocking: true, blocking: true,
functions: input.parse()?, functions: input.parse()?,
}))) })))
} else if lookahead.peek(kw::mutable) {
input.parse::<kw::mutable>()?;
input.parse::<Token![:]>()?;
Ok(WasmtimeConfigField::Core(ConfigField::Mutable(
input.parse::<syn::LitBool>()?.value,
)))
} else { } else {
Err(lookahead.error()) Err(lookahead.error())
} }

View File

@@ -77,6 +77,11 @@ fn _define_func(
function = #func_name function = #func_name
); );
); );
let ctx_type = if settings.mutable {
quote!(&'a mut)
} else {
quote!(&'a)
};
if settings.get_async(&module, &func).is_sync() { if settings.get_async(&module, &func).is_sync() {
let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) { let traced_body = if settings.tracing.enabled_for(&mod_name, &func_name) {
quote!( quote!(
@@ -91,8 +96,8 @@ fn _define_func(
( (
quote!( quote!(
#[allow(unreachable_code)] // deals with warnings in noreturn functions #[allow(unreachable_code)] // deals with warnings in noreturn functions
pub fn #ident( pub fn #ident<'a>(
ctx: &mut (impl #(#bounds)+*), ctx: #ctx_type (impl #(#bounds)+*),
memory: &dyn wiggle::GuestMemory, memory: &dyn wiggle::GuestMemory,
#(#abi_params),* #(#abi_params),*
) -> wiggle::anyhow::Result<#abi_ret> { ) -> wiggle::anyhow::Result<#abi_ret> {
@@ -122,7 +127,7 @@ fn _define_func(
quote!( quote!(
#[allow(unreachable_code)] // deals with warnings in noreturn functions #[allow(unreachable_code)] // deals with warnings in noreturn functions
pub fn #ident<'a>( pub fn #ident<'a>(
ctx: &'a mut (impl #(#bounds)+*), ctx: #ctx_type (impl #(#bounds)+*),
memory: &'a dyn wiggle::GuestMemory, memory: &'a dyn wiggle::GuestMemory,
#(#abi_params),* #(#abi_params),*
) -> impl std::future::Future<Output = wiggle::anyhow::Result<#abi_ret>> + 'a { ) -> impl std::future::Future<Output = wiggle::anyhow::Result<#abi_ret>> + 'a {

View File

@@ -81,10 +81,15 @@ pub fn define_module_trait(m: &Module, settings: &CodegenSettings) -> TokenStrea
quote!(async) quote!(async)
}; };
if is_anonymous { let self_ = if settings.mutable {
quote!(#asyncness fn #funcname(&mut self, #(#args),*) -> #result; ) quote!(&mut self)
} else { } else {
quote!(#asyncness fn #funcname<#lifetime>(&mut self, #(#args),*) -> #result;) quote!(&self)
};
if is_anonymous {
quote!(#asyncness fn #funcname(#self_, #(#args),*) -> #result; )
} else {
quote!(#asyncness fn #funcname<#lifetime>(#self_, #(#args),*) -> #result;)
} }
}); });

View File

@@ -46,11 +46,16 @@ pub fn link_module(
format_ident!("add_{}_to_linker", module_ident) format_ident!("add_{}_to_linker", module_ident)
}; };
let u = if settings.mutable {
quote!(&mut U)
} else {
quote!(&U)
};
quote! { quote! {
/// Adds all instance items to the specified `Linker`. /// Adds all instance items to the specified `Linker`.
pub fn #func_name<T, U>( pub fn #func_name<T, U>(
linker: &mut wiggle::wasmtime_crate::Linker<T>, linker: &mut wiggle::wasmtime_crate::Linker<T>,
get_cx: impl Fn(&mut T) -> &mut U + Send + Sync + Copy + 'static, get_cx: impl Fn(&mut T) -> #u + Send + Sync + Copy + 'static,
) -> wiggle::anyhow::Result<()> ) -> wiggle::anyhow::Result<()>
where where
U: #ctx_bound #send_bound U: #ctx_bound #send_bound

View File

@@ -157,6 +157,7 @@ pub fn from_witx(args: TokenStream) -> TokenStream {
&doc, &doc,
config.wasmtime, config.wasmtime,
&config.tracing, &config.tracing,
config.mutable,
) )
.expect("validating codegen settings"); .expect("validating codegen settings");
@@ -198,6 +199,7 @@ pub fn wasmtime_integration(args: TokenStream) -> TokenStream {
&doc, &doc,
true, true,
&config.c.tracing, &config.c.tracing,
config.c.mutable,
) )
.expect("validating codegen settings"); .expect("validating codegen settings");