diff --git a/crates/component-macro/src/bindgen.rs b/crates/component-macro/src/bindgen.rs index 76e06c80d7..2a813ef5f7 100644 --- a/crates/component-macro/src/bindgen.rs +++ b/crates/component-macro/src/bindgen.rs @@ -1,4 +1,5 @@ use proc_macro2::{Span, TokenStream}; +use std::collections::HashMap; use std::path::{Path, PathBuf}; use syn::parse::{Error, Parse, ParseStream, Result}; use syn::punctuated::Punctuated; @@ -77,6 +78,8 @@ impl Parse for Config { Opt::Async(val) => opts.async_ = val, Opt::TrappableErrorType(val) => opts.trappable_error_type = val, Opt::DuplicateIfNecessary(val) => opts.duplicate_if_necessary = val, + Opt::OnlyInterfaces(val) => opts.only_interfaces = val, + Opt::With(val) => opts.with.extend(val), } } } else { @@ -133,6 +136,8 @@ mod kw { syn::custom_keyword!(trappable_error_type); syn::custom_keyword!(world); syn::custom_keyword!(duplicate_if_necessary); + syn::custom_keyword!(only_interfaces); + syn::custom_keyword!(with); } enum Opt { @@ -143,6 +148,8 @@ enum Opt { Async(bool), TrappableErrorType(Vec), DuplicateIfNecessary(bool), + OnlyInterfaces(bool), + With(HashMap), } impl Parse for Opt { @@ -191,6 +198,18 @@ impl Parse for Opt { }) .collect(), )) + } else if l.peek(kw::only_interfaces) { + input.parse::()?; + input.parse::()?; + Ok(Opt::OnlyInterfaces(input.parse::()?.value)) + } else if l.peek(kw::with) { + input.parse::()?; + input.parse::()?; + let contents; + let _lbrace = braced!(contents in input); + let fields: Punctuated<(String, String), Token![,]> = + contents.parse_terminated(with_field_parse)?; + Ok(Opt::With(HashMap::from_iter(fields.into_iter()))) } else { Err(l.error()) } @@ -219,3 +238,46 @@ fn trappable_error_field_parse(input: ParseStream<'_>) -> Result<(String, String let rust_type = input.parse::()?.to_string(); Ok((interface, type_, rust_type)) } + +fn with_field_parse(input: ParseStream<'_>) -> Result<(String, String)> { + let interface = input.parse::()?.value(); + input.parse::()?; + let start = input.span(); + let path = input.parse::()?; + + // It's not possible for the segments of a path to be empty + let span = start + .join(path.segments.last().unwrap().ident.span()) + .unwrap_or(start); + + let mut buf = String::new(); + let append = |buf: &mut String, segment: syn::PathSegment| -> Result<()> { + if segment.arguments != syn::PathArguments::None { + return Err(Error::new( + span, + "Module path must not contain angles or parens", + )); + } + + buf.push_str(&segment.ident.to_string()); + + Ok(()) + }; + + if path.leading_colon.is_some() { + buf.push_str("::"); + } + + let mut segments = path.segments.into_iter(); + + if let Some(segment) = segments.next() { + append(&mut buf, segment)?; + } + + for segment in segments { + buf.push_str("::"); + append(&mut buf, segment)?; + } + + Ok((interface, buf)) +} diff --git a/crates/component-macro/tests/codegen.rs b/crates/component-macro/tests/codegen.rs index 0384ab4c14..3bcead0489 100644 --- a/crates/component-macro/tests/codegen.rs +++ b/crates/component-macro/tests/codegen.rs @@ -21,9 +21,17 @@ macro_rules! gentest { duplicate_if_necessary: true, }); } + mod interfaces_only { + wasmtime::component::bindgen!({ + path: $path, + world: $name, + only_interfaces: true, + }); + } } // ... }; + } component_macro_test_helpers::foreach!(gentest); diff --git a/crates/wasmtime/src/component/mod.rs b/crates/wasmtime/src/component/mod.rs index e5cc6d4969..3dd7b740dd 100644 --- a/crates/wasmtime/src/component/mod.rs +++ b/crates/wasmtime/src/component/mod.rs @@ -292,6 +292,20 @@ pub(crate) use self::store::ComponentStoreData; /// interface::ErrorType: RustErrorType, /// }, /// +/// // Restrict the code generated to what's needed for the imported +/// // interfaces of the world file provided. This option is most useful +/// // in conjunction with the `with` option that permits remapping of +/// // interface names in generated code. +/// only_interfaces: true, +/// +/// // Remap interface names to module names, imported from elswhere. +/// // Using this option will prevent any code from being generated +/// // for the names mentioned in the mapping, assuming instead that the +/// // names mentioned come from a previous use of the `bindgen!` macro +/// // with `only_interfaces: true`. +/// with: { +/// "a": somewhere::else::a, +/// }, /// }); /// ``` /// diff --git a/crates/wit-bindgen/src/lib.rs b/crates/wit-bindgen/src/lib.rs index 3d516d7d18..f564a5a956 100644 --- a/crates/wit-bindgen/src/lib.rs +++ b/crates/wit-bindgen/src/lib.rs @@ -25,6 +25,15 @@ mod source; mod types; use source::Source; +struct InterfaceName { + /// True when this interface name has been remapped through the use of `with` in the `bindgen!` + /// macro invocation. + remapped: bool, + + /// The string name for this interface. + name: String, +} + #[derive(Default)] struct Wasmtime { src: Source, @@ -33,7 +42,7 @@ struct Wasmtime { exports: Exports, types: Types, sizes: SizeAlign, - interface_names: HashMap, + interface_names: HashMap, } enum Import { @@ -66,6 +75,13 @@ pub struct Opts { /// WIT type if necessary, for example if it's used as both an import and an /// export. pub duplicate_if_necessary: bool, + + /// Whether or not to generate code for only the interfaces of this wit file or not. + pub only_interfaces: bool, + + /// Remapping of interface names to rust module names. + /// TODO: is there a better type to use for the value of this map? + pub with: HashMap, } #[derive(Debug, Clone)] @@ -95,14 +111,37 @@ impl Opts { } impl Wasmtime { + fn name_interface(&mut self, id: InterfaceId, name: String) -> bool { + let entry = if let Some(remapped_name) = self.opts.with.get(&name) { + InterfaceName { + remapped: true, + name: remapped_name.clone(), + } + } else { + InterfaceName { + remapped: false, + name, + } + }; + + let remapped = entry.remapped; + self.interface_names.insert(id, entry); + + remapped + } + fn generate(&mut self, resolve: &Resolve, id: WorldId) -> String { self.types.analyze(resolve, id); let world = &resolve.worlds[id]; for (name, import) in world.imports.iter() { - self.import(resolve, name, import); + if !self.opts.only_interfaces || matches!(import, WorldItem::Interface(_)) { + self.import(resolve, name, import); + } } for (name, export) in world.exports.iter() { - self.export(resolve, name, export); + if !self.opts.only_interfaces || matches!(export, WorldItem::Interface(_)) { + self.export(resolve, name, export); + } } self.finish(resolve, id) } @@ -119,7 +158,9 @@ impl Wasmtime { Import::Function { sig, add_to_linker } } WorldItem::Interface(id) => { - gen.gen.interface_names.insert(*id, snake.clone()); + if gen.gen.name_interface(*id, snake.clone()) { + return; + } gen.current_interface = Some(*id); gen.types(*id); gen.generate_trappable_error_types(TypeOwner::Interface(*id)); @@ -166,7 +207,7 @@ impl Wasmtime { } WorldItem::Type(_) => unreachable!(), WorldItem::Interface(id) => { - gen.gen.interface_names.insert(*id, snake.clone()); + gen.gen.name_interface(*id, snake.clone()); gen.current_interface = Some(*id); gen.types(*id); gen.generate_trappable_error_types(TypeOwner::Interface(*id)); @@ -246,7 +287,7 @@ impl Wasmtime { assert!(prev.is_none()); } - fn finish(&mut self, resolve: &Resolve, world: WorldId) -> String { + fn build_struct(&mut self, resolve: &Resolve, world: WorldId) { let camel = to_rust_upper_camel_case(&resolve.worlds[world].name); uwriteln!(self.src, "pub struct {camel} {{"); for (name, (ty, _)) in self.exports.fields.iter() { @@ -327,6 +368,12 @@ impl Wasmtime { uwriteln!(self.src, "}}"); // close `impl {camel}` uwriteln!(self.src, "}};"); // close `const _: () = ... + } + + fn finish(&mut self, resolve: &Resolve, world: WorldId) -> String { + if !self.opts.only_interfaces { + self.build_struct(resolve, world) + } let mut src = mem::take(&mut self.src); if self.opts.rustfmt { @@ -1398,8 +1445,8 @@ impl<'a> RustGenerator<'a> for InterfaceGenerator<'a> { match self.current_interface { Some(id) if id == interface => None, _ => { - let name = &self.gen.interface_names[&interface]; - Some(if self.current_interface.is_some() { + let InterfaceName { remapped, name } = &self.gen.interface_names[&interface]; + Some(if self.current_interface.is_some() && !remapped { format!("super::{name}") } else { name.clone() diff --git a/tests/all/component_model/bindgen/results.rs b/tests/all/component_model/bindgen/results.rs index 0626af86f8..19376a64a2 100644 --- a/tests/all/component_model/bindgen/results.rs +++ b/tests/all/component_model/bindgen/results.rs @@ -633,3 +633,121 @@ mod variant_error { Ok(()) } } + +mod with_remapping { + use super::*; + + mod interfaces { + wasmtime::component::bindgen!({ + inline: " + default world result-playground { + import imports: interface { + empty-error: func(a: float64) -> result + } + + export empty-error: func(a: float64) -> result + }", + only_interfaces: true, + }); + } + + wasmtime::component::bindgen!({ + inline: " + default world result-playground { + import imports: interface { + empty-error: func(a: float64) -> result + } + + export empty-error: func(a: float64) -> result + }", + with: { + "imports": interfaces::imports, + }, + }); + + #[test] + fn run() -> Result<(), Error> { + let engine = engine(); + let component = Component::new( + &engine, + r#" + (component + (import "imports" (instance $i + (export "empty-error" (func (param "a" float64) (result (result float64)))) + )) + (core module $libc + (memory (export "memory") 1) + ) + (core instance $libc (instantiate $libc)) + (core module $m + (import "" "core_empty_error" (func $f (param f64 i32))) + (import "libc" "memory" (memory 0)) + (func (export "core_empty_error_export") (param f64) (result i32) + (call $f (local.get 0) (i32.const 8)) + (i32.const 8) + ) + ) + (core func $core_empty_error + (canon lower (func $i "empty-error") (memory $libc "memory")) + ) + (core instance $i (instantiate $m + (with "" (instance (export "core_empty_error" (func $core_empty_error)))) + (with "libc" (instance $libc)) + )) + (func $f_empty_error + (export "empty-error") + (param "a" float64) + (result (result float64)) + (canon lift (core func $i "core_empty_error_export") (memory $libc "memory")) + ) + ) + "#, + )?; + + #[derive(Default)] + struct MyImports {} + + impl interfaces::imports::Host for MyImports { + fn empty_error(&mut self, a: f64) -> Result, Error> { + if a == 0.0 { + Ok(Ok(a)) + } else if a == 1.0 { + Ok(Err(())) + } else { + Err(anyhow!("empty_error: trap")) + } + } + } + + let mut linker = Linker::new(&engine); + interfaces::imports::add_to_linker(&mut linker, |f: &mut MyImports| f)?; + + let mut store = Store::new(&engine, MyImports::default()); + let (results, _) = ResultPlayground::instantiate(&mut store, &component, &linker)?; + + assert_eq!( + results + .call_empty_error(&mut store, 0.0) + .expect("no trap") + .expect("no error returned"), + 0.0 + ); + + results + .call_empty_error(&mut store, 1.0) + .expect("no trap") + .err() + .expect("() error returned"); + + let e = results + .call_empty_error(&mut store, 2.0) + .err() + .expect("trap"); + assert_eq!( + format!("{}", e.source().expect("trap message is stored in source")), + "empty_error: trap" + ); + + Ok(()) + } +}