diff --git a/crates/component-util/src/lib.rs b/crates/component-util/src/lib.rs index 2c4a2a461d..827b168cb6 100644 --- a/crates/component-util/src/lib.rs +++ b/crates/component-util/src/lib.rs @@ -123,11 +123,35 @@ pub const REALLOC_AND_FREE: &str = r#" ;; save the current value of `$last` as the return value global.get $last - local.tee $ret + local.set $ret + + ;; bump our pointer + (global.set $last + (i32.add + (global.get $last) + (local.get $new_size))) + + ;; while `memory.size` is less than `$last`, grow memory + ;; by one page + (loop $loop + (if + (i32.lt_u + (i32.mul (memory.size) (i32.const 65536)) + (global.get $last)) + (then + i32.const 1 + memory.grow + ;; test to make sure growth succeeded + i32.const -1 + i32.eq + if unreachable end + + br $loop))) + ;; ensure anything necessary is set to valid data by spraying a bit ;; pattern that is invalid - global.get $last + local.get $ret i32.const 0xde local.get $new_size memory.fill @@ -142,10 +166,6 @@ pub const REALLOC_AND_FREE: &str = r#" memory.copy end - ;; bump our pointer - (global.set $last - (i32.add - (global.get $last) - (local.get $new_size))) + local.get $ret ) "#; diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs index 3cb16a68c8..19f8a899be 100644 --- a/crates/environ/src/fact.rs +++ b/crates/environ/src/fact.rs @@ -19,10 +19,13 @@ //! that. 
use crate::component::dfg::CoreDef; -use crate::component::{Adapter, AdapterOptions, ComponentTypes, StringEncoding, TypeFuncIndex}; -use crate::{FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; +use crate::component::{ + Adapter, AdapterOptions as AdapterOptionsDfg, ComponentTypes, InterfaceType, StringEncoding, + TypeFuncIndex, +}; +use crate::fact::transcode::Transcoder; +use crate::{EntityRef, FuncIndex, GlobalIndex, MemoryIndex, PrimaryMap}; use std::collections::HashMap; -use std::mem; use wasm_encoder::*; mod core_types; @@ -50,26 +53,28 @@ pub struct Module<'a> { /// Final list of imports that this module ended up using, in the same order /// as the imports in the import section. imports: Vec, - /// Intern'd imports and what index they were assigned. - imported: HashMap, - imported_memories: PrimaryMap, + /// Intern'd imports and what index they were assigned. Note that this map + /// covers all the index spaces for imports, not just one. + imported: HashMap, + /// Intern'd transcoders and what index they were assigned. + imported_transcoders: HashMap, // Current status of index spaces from the imports generated so far. - core_funcs: u32, - core_memories: u32, - core_globals: u32, + imported_funcs: PrimaryMap>, + imported_memories: PrimaryMap, + imported_globals: PrimaryMap, - /// Adapters which will be compiled once they're all registered. 
- adapters: Vec, + funcs: PrimaryMap, + translate_mem_funcs: HashMap<(InterfaceType, InterfaceType, Options, Options), FunctionId>, } struct AdapterData { /// Export name of this adapter name: String, /// Options specified during the `canon lift` operation - lift: Options, + lift: AdapterOptions, /// Options specified during the `canon lower` operation - lower: Options, + lower: AdapterOptions, /// The core wasm function that this adapter will be calling (the original /// function that was `canon lift`'d) callee: FuncIndex, @@ -78,14 +83,38 @@ struct AdapterData { called_as_export: bool, } -struct Options { +/// Configuration options which apply at the "global adapter" level. +/// +/// These options are typically unique per-adapter and generally aren't needed +/// when translating recursive types within an adapter. +struct AdapterOptions { + /// The ascribed type of this adapter. ty: TypeFuncIndex, - string_encoding: StringEncoding, + /// The global that represents the instance flags for where this adapter + /// came from. flags: GlobalIndex, - memory64: bool, - memory: Option, - realloc: Option, + /// The configured post-return function, if any. post_return: Option, + /// Other, more general, options configured. + options: Options, +} + +/// This type is split out of `AdapterOptions` and is specifically used to +/// deduplicate translation functions within a module. Consequently this has +/// as few fields as possible to minimize the number of functions generated +/// within an adapter module. +#[derive(PartialEq, Eq, Hash, Copy, Clone)] +struct Options { + /// The encoding that strings use from this adapter. + string_encoding: StringEncoding, + /// Whether or not the `memory` field, if present, is a 64-bit memory. + memory64: bool, + /// An optionally-specified memory where values may travel through for + /// types like lists. 
+ memory: Option, + /// An optionally-specified function to be used to allocate space for + /// types such as strings as they go into a module. + realloc: Option, } enum Context { @@ -102,12 +131,13 @@ impl<'a> Module<'a> { core_types: Default::default(), core_imports: Default::default(), imported: Default::default(), - adapters: Default::default(), imports: Default::default(), + imported_transcoders: Default::default(), + imported_funcs: PrimaryMap::new(), imported_memories: PrimaryMap::new(), - core_funcs: 0, - core_memories: 0, - core_globals: 0, + imported_globals: PrimaryMap::new(), + funcs: PrimaryMap::new(), + translate_mem_funcs: HashMap::new(), } } @@ -128,7 +158,7 @@ impl<'a> Module<'a> { // Import the core wasm function which was lifted using its appropriate // signature since the exported function this adapter generates will // call the lifted function. - let signature = self.signature(&lift, Context::Lift); + let signature = self.types.signature(&lift, Context::Lift); let ty = self .core_types .function(&signature.params, &signature.results); @@ -141,19 +171,24 @@ impl<'a> Module<'a> { self.import_func("post_return", name, ty, func.clone()) }); - self.adapters.push(AdapterData { - name: name.to_string(), - lift, - lower, - callee, - // FIXME(#4185) should be plumbed and handled as part of the new - // reentrance rules not yet implemented here. - called_as_export: true, - }); + // This will internally create the adapter as specified and append + // anything necessary to `self.funcs`. + trampoline::compile( + self, + &AdapterData { + name: name.to_string(), + lift, + lower, + callee, + // FIXME(#4185) should be plumbed and handled as part of the new + // reentrance rules not yet implemented here. 
+ called_as_export: true, + }, + ); } - fn import_options(&mut self, ty: TypeFuncIndex, options: &AdapterOptions) -> Options { - let AdapterOptions { + fn import_options(&mut self, ty: TypeFuncIndex, options: &AdapterOptionsDfg) -> AdapterOptions { + let AdapterOptionsDfg { instance, string_encoding, memory, @@ -192,23 +227,24 @@ impl<'a> Module<'a> { let ty = self.core_types.function(&[ptr, ptr, ptr, ptr], &[ptr]); self.import_func("realloc", "", ty, func.clone()) }); - Options { + + AdapterOptions { ty, - string_encoding: *string_encoding, flags, - memory64: *memory64, - memory, - realloc, post_return: None, + options: Options { + string_encoding: *string_encoding, + memory64: *memory64, + memory, + realloc, + }, } } fn import_func(&mut self, module: &str, name: &str, ty: u32, def: CoreDef) -> FuncIndex { - FuncIndex::from_u32( - self.import(module, name, EntityType::Function(ty), def, |m| { - &mut m.core_funcs - }), - ) + self.import(module, name, EntityType::Function(ty), def, |m| { + &mut m.imported_funcs + }) } fn import_global( @@ -218,9 +254,9 @@ impl<'a> Module<'a> { ty: GlobalType, def: CoreDef, ) -> GlobalIndex { - GlobalIndex::from_u32(self.import(module, name, EntityType::Global(ty), def, |m| { - &mut m.core_globals - })) + self.import(module, name, EntityType::Global(ty), def, |m| { + &mut m.imported_globals + }) } fn import_memory( @@ -230,82 +266,113 @@ impl<'a> Module<'a> { ty: MemoryType, def: CoreDef, ) -> MemoryIndex { - MemoryIndex::from_u32(self.import(module, name, EntityType::Memory(ty), def, |m| { - &mut m.core_memories - })) + self.import(module, name, EntityType::Memory(ty), def, |m| { + &mut m.imported_memories + }) } - fn import( + fn import>( &mut self, module: &str, name: &str, ty: EntityType, def: CoreDef, - new: impl FnOnce(&mut Self) -> &mut u32, - ) -> u32 { + map: impl FnOnce(&mut Self) -> &mut PrimaryMap, + ) -> K { if let Some(prev) = self.imported.get(&def) { - return *prev; + return K::new(*prev); } - let cnt = new(self); - 
*cnt += 1; - let ret = *cnt - 1; + let idx = map(self).push(def.clone().into()); self.core_imports.import(module, name, ty); - self.imported.insert(def.clone(), ret); - if let EntityType::Memory(_) = ty { - self.imported_memories.push(def.clone()); - } + self.imported.insert(def.clone(), idx.index()); self.imports.push(Import::CoreDef(def)); - ret + idx + } + + fn import_transcoder(&mut self, transcoder: transcode::Transcoder) -> FuncIndex { + *self + .imported_transcoders + .entry(transcoder) + .or_insert_with(|| { + // Add the import to the core wasm import section... + let name = transcoder.name(); + let ty = transcoder.ty(&mut self.core_types); + self.core_imports.import("transcode", &name, ty); + + // ... and also record the metadata for what this import + // corresponds to. + let from = self.imported_memories[transcoder.from_memory].clone(); + let to = self.imported_memories[transcoder.to_memory].clone(); + self.imports.push(Import::Transcode { + op: transcoder.op, + from, + from64: transcoder.from_memory64, + to, + to64: transcoder.to_memory64, + }); + + self.imported_funcs.push(None) + }) } /// Encodes this module into a WebAssembly binary. pub fn encode(&mut self) -> Vec { - let mut types = mem::take(&mut self.core_types); - let mut transcoders = transcode::Transcoders::new(self.core_funcs); - let mut adapter_funcs = Vec::new(); - for adapter in self.adapters.iter() { - adapter_funcs.push(trampoline::compile( - self, - &mut types, - &mut transcoders, - adapter, - )); - } - - // If any string transcoding imports were needed add imported items - // associated with them. 
- for (module, name, ty, transcoder) in transcoders.imports() { - self.core_imports.import(module, name, ty); - let from = self.imported_memories[transcoder.from_memory].clone(); - let to = self.imported_memories[transcoder.to_memory].clone(); - self.imports.push(Import::Transcode { - op: transcoder.op, - from, - from64: transcoder.from_memory64, - to, - to64: transcoder.to_memory64, - }); - self.core_funcs += 1; - } - - // Now that all functions are known as well as all imports the actual - // bodies of all adapters are assembled into a final module. + // Build the function/export sections of the wasm module in a first pass + // which will assign a final `FuncIndex` to all functions defined in + // `self.funcs`. let mut funcs = FunctionSection::new(); - let mut code = CodeSection::new(); let mut exports = ExportSection::new(); - let mut traps = traps::TrapSection::default(); - for (adapter, (function, func_traps)) in self.adapters.iter().zip(adapter_funcs) { - let idx = self.core_funcs + funcs.len(); - exports.export(&adapter.name, ExportKind::Func, idx); + let mut id_to_index = PrimaryMap::::new(); + for (id, func) in self.funcs.iter() { + assert!(func.filled_in); + let idx = FuncIndex::from_u32(self.imported_funcs.next_key().as_u32() + id.as_u32()); + let id2 = id_to_index.push(idx); + assert_eq!(id2, id); - let signature = self.signature(&adapter.lower, Context::Lower); - let ty = types.function(&signature.params, &signature.results); - funcs.function(ty); + funcs.function(func.ty); - code.raw(&function); - traps.append(idx, func_traps); + if let Some(name) = &func.export { + exports.export(name, ExportKind::Func, idx.as_u32()); + } } - self.core_types = types; + + // With all functions numbered the fragments of the body of each + // function can be assigned into one final adapter function. 
+ let mut code = CodeSection::new(); + let mut traps = traps::TrapSection::default(); + for (id, func) in self.funcs.iter() { + let mut func_traps = Vec::new(); + let mut body = Vec::new(); + + // Encode all locals used for this function + func.locals.len().encode(&mut body); + for (count, ty) in func.locals.iter() { + count.encode(&mut body); + ty.encode(&mut body); + } + + // Then encode each "chunk" of a body which may have optional traps + // specified within it. Traps get offset by the current length of + // the body and otherwise our `Call` instructions are "relocated" + // here to the final function index. + for chunk in func.body.iter() { + match chunk { + Body::Raw(code, traps) => { + let start = body.len(); + body.extend_from_slice(code); + for (offset, trap) in traps { + func_traps.push((start + offset, *trap)); + } + } + Body::Call(id) => { + Instruction::Call(id_to_index[*id].as_u32()).encode(&mut body); + } + } + } + code.raw(&body); + traps.append(id_to_index[id].as_u32(), func_traps); + } + let traps = traps.finish(); let mut result = wasm_encoder::Module::new(); @@ -367,3 +434,82 @@ impl Options { } } } + +/// Temporary index which is not the same as `FuncIndex`. +/// +/// This represents the nth generated function in the adapter module where the +/// final index of the function is not known at the time of generation since +/// more imports may be discovered (specifically string transcoders). +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +struct FunctionId(u32); +cranelift_entity::entity_impl!(FunctionId); + +/// A generated function to be added to an adapter module. +/// +/// At least one function is created per-adapter and depending on the type +/// hierarchy multiple functions may be generated per-adapter. +struct Function { + /// Whether or not the `body` has been finished. 
+ /// + /// Functions are added to a `Module` before they're defined so this is used + /// to assert that the function was in fact actually filled in by the + /// time we reach `Module::encode`. + filled_in: bool, + + /// The type signature that this function has, as an index into the core + /// wasm type index space of the generated adapter module. + ty: u32, + + /// The locals that are used by this function, organized by the number of + /// types of each local. + locals: Vec<(u32, ValType)>, + + /// If specified, the export name of this function. + export: Option, + + /// The contents of the function. + /// + /// See `Body` for more information, and the `Vec` here represents the + /// concatenation of all the `Body` fragments. + body: Vec, +} + +/// Representation of a fragment of the body of a core wasm function generated +/// for adapters. +/// +/// This enum comes in one of two flavors: +/// +/// 1. First a `Raw` variant is used to contain general instructions for the +/// wasm function. This is populated by `Compiler::instruction` primarily. +/// This also comes with a list of traps and the byte offset within the +/// first vector of where the trap information applies to. +/// +/// 2. A `Call` instruction variant for a `FunctionId` where the final +/// `FuncIndex` isn't known until emission time. +/// +/// The purpose of this representation is the `Body::Call` variant. This can't +/// be encoded as an instruction when it's generated due to not knowing the +/// final index of the function being called. During `Module::encode`, however, +/// all indices are known and `Body::Call` is turned into a final +/// `Instruction::Call`. +/// +/// One other possible representation in the future would be to encode a `Call` +/// instruction with a 5-byte leb to fill in later, but for now this felt +/// easier to represent. A 5-byte leb may be more efficient at compile-time if +/// necessary, however. 
+enum Body { + Raw(Vec, Vec<(usize, traps::Trap)>), + Call(FunctionId), +} + +impl Function { + fn new(export: Option, ty: u32) -> Function { + Function { + filled_in: false, + ty, + locals: Vec::new(), + export, + body: Vec::new(), + } + } +} diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs index 27313f13c9..f6b8b3fb73 100644 --- a/crates/environ/src/fact/signature.rs +++ b/crates/environ/src/fact/signature.rs @@ -1,8 +1,9 @@ //! Size, align, and flattening information about component model types. -use crate::component::{InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; -use crate::fact::{Context, Module, Options}; +use crate::component::{ComponentTypes, InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS}; +use crate::fact::{AdapterOptions, Context, Options}; use wasm_encoder::ValType; +use wasmtime_component_util::{DiscriminantSize, FlagsSize}; /// Metadata about a core wasm signature which is created for a component model /// signature. @@ -27,25 +28,25 @@ pub(crate) fn align_to(n: usize, align: usize) -> usize { (n + (align - 1)) & !(align - 1) } -impl Module<'_> { +impl ComponentTypes { /// Calculates the core wasm function signature for the component function /// type specified within `Context`. /// /// This is used to generate the core wasm signatures for functions that are /// imported (matching whatever was `canon lift`'d) and functions that are /// exported (matching the generated function from `canon lower`). 
- pub(super) fn signature(&self, options: &Options, context: Context) -> Signature { - let ty = &self.types[options.ty]; - let ptr_ty = options.ptr(); + pub(super) fn signature(&self, options: &AdapterOptions, context: Context) -> Signature { + let ty = &self[options.ty]; + let ptr_ty = options.options.ptr(); - let mut params = self.flatten_types(options, ty.params.iter().map(|(_, ty)| *ty)); + let mut params = self.flatten_types(&options.options, ty.params.iter().map(|(_, ty)| *ty)); let mut params_indirect = false; if params.len() > MAX_FLAT_PARAMS { params = vec![ptr_ty]; params_indirect = true; } - let mut results = self.flatten_types(options, [ty.result]); + let mut results = self.flatten_types(&options.options, [ty.result]); let mut results_indirect = false; if results.len() > MAX_FLAT_RESULTS { results_indirect = true; @@ -108,17 +109,17 @@ impl Module<'_> { dst.push(opts.ptr()); } InterfaceType::Record(r) => { - for field in self.types[*r].fields.iter() { + for field in self[*r].fields.iter() { self.push_flat(opts, &field.ty, dst); } } InterfaceType::Tuple(t) => { - for ty in self.types[*t].types.iter() { + for ty in self[*t].types.iter() { self.push_flat(opts, ty, dst); } } InterfaceType::Flags(f) => { - let flags = &self.types[*f]; + let flags = &self[*f]; let nflags = align_to(flags.names.len(), 32) / 32; for _ in 0..nflags { dst.push(ValType::I32); @@ -127,13 +128,13 @@ impl Module<'_> { InterfaceType::Enum(_) => dst.push(ValType::I32), InterfaceType::Option(t) => { dst.push(ValType::I32); - self.push_flat(opts, &self.types[*t], dst); + self.push_flat(opts, &self[*t], dst); } InterfaceType::Variant(t) => { dst.push(ValType::I32); let pos = dst.len(); let mut tmp = Vec::new(); - for case in self.types[*t].cases.iter() { + for case in self[*t].cases.iter() { self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst); } } @@ -141,13 +142,13 @@ impl Module<'_> { dst.push(ValType::I32); let pos = dst.len(); let mut tmp = Vec::new(); - for ty in 
self.types[*t].types.iter() { + for ty in self[*t].types.iter() { self.push_flat_variant(opts, ty, pos, &mut tmp, dst); } } InterfaceType::Expected(t) => { dst.push(ValType::I32); - let e = &self.types[*t]; + let e = &self[*t]; let pos = dst.len(); let mut tmp = Vec::new(); self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst); @@ -208,26 +209,26 @@ impl Module<'_> { } InterfaceType::Record(r) => { - self.record_size_align(opts, self.types[*r].fields.iter().map(|f| &f.ty)) + self.record_size_align(opts, self[*r].fields.iter().map(|f| &f.ty)) } - InterfaceType::Tuple(t) => self.record_size_align(opts, self.types[*t].types.iter()), - InterfaceType::Flags(f) => match self.types[*f].names.len() { - n if n <= 8 => (1, 1), - n if n <= 16 => (2, 2), - n if n <= 32 => (4, 4), - n => (4 * (align_to(n, 32) / 32), 4), + InterfaceType::Tuple(t) => self.record_size_align(opts, self[*t].types.iter()), + InterfaceType::Flags(f) => match FlagsSize::from_count(self[*f].names.len()) { + FlagsSize::Size0 => (0, 1), + FlagsSize::Size1 => (1, 1), + FlagsSize::Size2 => (2, 2), + FlagsSize::Size4Plus(n) => (n * 4, 4), }, - InterfaceType::Enum(t) => self.discrim_size_align(self.types[*t].names.len()), + InterfaceType::Enum(t) => self.discrim_size_align(self[*t].names.len()), InterfaceType::Option(t) => { - let ty = &self.types[*t]; + let ty = &self[*t]; self.variant_size_align(opts, [&InterfaceType::Unit, ty].into_iter()) } InterfaceType::Variant(t) => { - self.variant_size_align(opts, self.types[*t].cases.iter().map(|c| &c.ty)) + self.variant_size_align(opts, self[*t].cases.iter().map(|c| &c.ty)) } - InterfaceType::Union(t) => self.variant_size_align(opts, self.types[*t].types.iter()), + InterfaceType::Union(t) => self.variant_size_align(opts, self[*t].types.iter()), InterfaceType::Expected(t) => { - let e = &self.types[*t]; + let e = &self[*t]; self.variant_size_align(opts, [&e.ok, &e.err].into_iter()) } } @@ -260,14 +261,18 @@ impl Module<'_> { payload_size = payload_size.max(csize); 
align = align.max(calign); } - (align_to(discrim_size, align) + payload_size, align) + ( + align_to(align_to(discrim_size, align) + payload_size, align), + align, + ) } fn discrim_size_align<'a>(&self, cases: usize) -> (usize, usize) { - match cases { - n if n <= u8::MAX as usize => (1, 1), - n if n <= u16::MAX as usize => (2, 2), - _ => (4, 4), + match DiscriminantSize::from_count(cases) { + Some(DiscriminantSize::Size1) => (1, 1), + Some(DiscriminantSize::Size2) => (2, 2), + Some(DiscriminantSize::Size4) => (4, 4), + None => unreachable!(), } } } diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index b5337229a2..bc9095129f 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -16,16 +16,15 @@ //! can be somewhat arbitrary, an intentional decision. use crate::component::{ - InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, TypeFlagsIndex, - TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, TypeVariantIndex, - FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, + ComponentTypes, InterfaceType, StringEncoding, TypeEnumIndex, TypeExpectedIndex, + TypeFlagsIndex, TypeInterfaceIndex, TypeRecordIndex, TypeTupleIndex, TypeUnionIndex, + TypeVariantIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS, }; -use crate::fact::core_types::CoreTypes; use crate::fact::signature::{align_to, Signature}; -use crate::fact::transcode::{FixedEncoding as FE, Transcode, Transcoder, Transcoders}; +use crate::fact::transcode::{FixedEncoding as FE, Transcode, Transcoder}; use crate::fact::traps::Trap; -use crate::fact::{AdapterData, Context, Module, Options}; -use crate::GlobalIndex; +use crate::fact::{AdapterData, Body, Context, Function, FunctionId, Module, Options}; +use crate::{FuncIndex, GlobalIndex}; use std::collections::HashMap; use std::mem; use std::ops::Range; @@ -36,28 +35,13 @@ const MAX_STRING_BYTE_LENGTH: u32 = 1 << 31; 
const UTF16_TAG: u32 = 1 << 31; struct Compiler<'a, 'b> { - /// The module that the adapter will eventually be inserted into. - module: &'a Module<'a>, - - /// The type section of `module` - types: &'b mut CoreTypes, - - /// Imported functions to transcode between various string encodings. - transcoders: &'b mut Transcoders, - - /// Metadata about the adapter that is being compiled. - adapter: &'a AdapterData, + types: &'a ComponentTypes, + module: &'b mut Module<'a>, + result: FunctionId, /// The encoded WebAssembly function body so far, not including locals. code: Vec, - /// Generated locals that this function will use. - /// - /// The first entry in the tuple is the number of locals and the second - /// entry is the type of those locals. This is pushed during compilation as - /// locals become necessary. - locals: Vec<(u32, ValType)>, - /// Total number of locals generated so far. nlocals: u32, @@ -66,36 +50,90 @@ struct Compiler<'a, 'b> { /// well. traps: Vec<(usize, Trap)>, - /// The function signature of the lowered half of this trampoline, or the - /// signature of the function that's being generated. - lower_sig: &'a Signature, - - /// The function signature of the lifted half of this trampoline, or the - /// signature of the function that's imported the trampoline will call. - lift_sig: &'a Signature, + /// Indicates whether this call to `translate` is a "top level" one where + /// it's the first call from the root of the generated function. This is + /// used as a heuristic to know when to split helpers out to a separate + /// function. 
+ top_level_translate: bool, } -pub(super) fn compile( - module: &Module<'_>, - types: &mut CoreTypes, - transcoders: &mut Transcoders, - adapter: &AdapterData, -) -> (Vec, Vec<(usize, Trap)>) { - let lower_sig = &module.signature(&adapter.lower, Context::Lower); - let lift_sig = &module.signature(&adapter.lift, Context::Lift); +pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { + let lower_sig = module.types.signature(&adapter.lower, Context::Lower); + let lift_sig = module.types.signature(&adapter.lift, Context::Lift); + let ty = module + .core_types + .function(&lower_sig.params, &lower_sig.results); + let result = module + .funcs + .push(Function::new(Some(adapter.name.clone()), ty)); Compiler { + types: module.types, module, - types, - adapter, - transcoders, code: Vec::new(), - locals: Vec::new(), nlocals: lower_sig.params.len() as u32, traps: Vec::new(), - lower_sig, - lift_sig, + result, + top_level_translate: true, } - .compile() + .compile_adapter(adapter, &lower_sig, &lift_sig) +} + +/// Compiles a helper function which is used to translate `src` to `dst` +/// in-memory. +/// +/// The generated function takes two arguments: the source pointer and +/// destination pointer. The conversion operation is configured by the +/// `src_opts` and `dst_opts` specified as well. +fn compile_translate_mem( + module: &mut Module<'_>, + src: InterfaceType, + src_opts: &Options, + dst: InterfaceType, + dst_opts: &Options, +) -> FunctionId { + // If a helper for this translation has already been generated then reuse + // that. Note that this is key to this function where by doing this it + // prevents an exponentially sized output given any particular input type. + let key = (src, dst, *src_opts, *dst_opts); + if module.translate_mem_funcs.contains_key(&key) { + return module.translate_mem_funcs[&key]; + } + + // Generate a fresh `Function` with a unique id for what we're about to + // generate. 
+ let ty = module + .core_types + .function(&[src_opts.ptr(), dst_opts.ptr()], &[]); + let result = module.funcs.push(Function::new(None, ty)); + module.translate_mem_funcs.insert(key, result); + let mut compiler = Compiler { + types: module.types, + module, + code: Vec::new(), + nlocals: 2, + traps: Vec::new(), + result, + top_level_translate: true, + }; + // This function only does one thing which is to translate between memory, + // so only one call to `translate` is necessary. Note that the `addr_local` + // values come from the function arguments. + compiler.translate( + &src, + &Source::Memory(Memory { + opts: src_opts, + addr_local: 0, + offset: 0, + }), + &dst, + &Destination::Memory(Memory { + opts: dst_opts, + addr_local: 1, + offset: 0, + }), + ); + compiler.finish(); + result } /// Possible ways that a interface value is represented in the core wasm @@ -150,19 +188,24 @@ struct Memory<'a> { } impl Compiler<'_, '_> { - fn compile(&mut self) -> (Vec, Vec<(usize, Trap)>) { + fn compile_adapter( + mut self, + adapter: &AdapterData, + lower_sig: &Signature, + lift_sig: &Signature, + ) { // Check the instance flags required for this trampoline. // // This inserts the initial check required by `canon_lower` that the // caller instance can be left and additionally checks the // flags on the callee if necessary whether it can be entered. 
- self.trap_if_not_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, Trap::CannotLeave); - if self.adapter.called_as_export { - self.trap_if_not_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, Trap::CannotEnter); - self.set_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, false); + self.trap_if_not_flag(adapter.lower.flags, FLAG_MAY_LEAVE, Trap::CannotLeave); + if adapter.called_as_export { + self.trap_if_not_flag(adapter.lift.flags, FLAG_MAY_ENTER, Trap::CannotEnter); + self.set_flag(adapter.lift.flags, FLAG_MAY_ENTER, false); } else if self.module.debug { self.assert_not_flag( - self.adapter.lift.flags, + adapter.lift.flags, FLAG_MAY_ENTER, "may_enter should be unset", ); @@ -180,23 +223,22 @@ impl Compiler<'_, '_> { // TODO: if translation doesn't actually call any functions in either // instance then there's no need to set/clear the flag here and that can // be optimized away. - self.set_flag(self.adapter.lift.flags, FLAG_MAY_LEAVE, false); - let param_locals = self - .lower_sig + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, false); + let param_locals = lower_sig .params .iter() .enumerate() .map(|(i, ty)| (i as u32, *ty)) .collect::>(); - self.translate_params(¶m_locals); - self.set_flag(self.adapter.lift.flags, FLAG_MAY_LEAVE, true); + self.translate_params(adapter, ¶m_locals); + self.set_flag(adapter.lift.flags, FLAG_MAY_LEAVE, true); // With all the arguments on the stack the actual target function is // now invoked. The core wasm results of the function are then placed // into locals for result translation afterwards. 
- self.instruction(Call(self.adapter.callee.as_u32())); - let mut result_locals = Vec::with_capacity(self.lift_sig.results.len()); - for ty in self.lift_sig.results.iter().rev() { + self.instruction(Call(adapter.callee.as_u32())); + let mut result_locals = Vec::with_capacity(lift_sig.results.len()); + for ty in lift_sig.results.iter().rev() { let local = self.gen_local(*ty); self.instruction(LocalSet(local)); result_locals.push((local, *ty)); @@ -211,77 +253,75 @@ impl Compiler<'_, '_> { // // TODO: like above the management of the `MAY_LEAVE` flag can probably // be elided here for "simple" results. - self.set_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, false); - self.translate_results(¶m_locals, &result_locals); - self.set_flag(self.adapter.lower.flags, FLAG_MAY_LEAVE, true); + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, false); + self.translate_results(adapter, ¶m_locals, &result_locals); + self.set_flag(adapter.lower.flags, FLAG_MAY_LEAVE, true); // And finally post-return state is handled here once all results/etc // are all translated. 
- if let Some(func) = self.adapter.lift.post_return { + if let Some(func) = adapter.lift.post_return { for (result, _) in result_locals.iter() { self.instruction(LocalGet(*result)); } self.instruction(Call(func.as_u32())); } - if self.adapter.called_as_export { - self.set_flag(self.adapter.lift.flags, FLAG_MAY_ENTER, true); + if adapter.called_as_export { + self.set_flag(adapter.lift.flags, FLAG_MAY_ENTER, true); } self.finish() } - fn translate_params(&mut self, param_locals: &[(u32, ValType)]) { - let src_tys = &self.module.types[self.adapter.lower.ty].params; + fn translate_params(&mut self, adapter: &AdapterData, param_locals: &[(u32, ValType)]) { + let src_tys = &self.types[adapter.lower.ty].params; let src_tys = src_tys.iter().map(|(_, ty)| *ty).collect::>(); - let dst_tys = &self.module.types[self.adapter.lift.ty].params; + let dst_tys = &self.types[adapter.lift.ty].params; let dst_tys = dst_tys.iter().map(|(_, ty)| *ty).collect::>(); + let lift_opts = &adapter.lift.options; + let lower_opts = &adapter.lower.options; // TODO: handle subtyping assert_eq!(src_tys.len(), dst_tys.len()); let src_flat = self - .module - .flatten_types(&self.adapter.lower, src_tys.iter().copied()); - let dst_flat = self - .module - .flatten_types(&self.adapter.lift, dst_tys.iter().copied()); + .types + .flatten_types(lower_opts, src_tys.iter().copied()); + let dst_flat = self.types.flatten_types(lift_opts, dst_tys.iter().copied()); let src = if src_flat.len() <= MAX_FLAT_PARAMS { Source::Stack(Stack { locals: ¶m_locals[..src_flat.len()], - opts: &self.adapter.lower, + opts: lower_opts, }) } else { // If there are too many parameters then that means the parameters // are actually a tuple stored in linear memory addressed by the // first parameter local. 
let (addr, ty) = param_locals[0]; - assert_eq!(ty, self.adapter.lower.ptr()); + assert_eq!(ty, lower_opts.ptr()); let align = src_tys .iter() - .map(|t| self.module.align(&self.adapter.lower, t)) + .map(|t| self.types.align(lower_opts, t)) .max() .unwrap_or(1); - Source::Memory(self.memory_operand(&self.adapter.lower, addr, align)) + Source::Memory(self.memory_operand(lower_opts, addr, align)) }; let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { - Destination::Stack(&dst_flat, &self.adapter.lift) + Destination::Stack(&dst_flat, lift_opts) } else { // If there are too many parameters then space is allocated in the // destination module for the parameters via its `realloc` function. - let (size, align) = self - .module - .record_size_align(&self.adapter.lift, dst_tys.iter()); + let (size, align) = self.types.record_size_align(lift_opts, dst_tys.iter()); let size = MallocSize::Const(size); - Destination::Memory(self.malloc(&self.adapter.lift, size, align)) + Destination::Memory(self.malloc(lift_opts, size, align)) }; let srcs = src - .record_field_srcs(self.module, src_tys.iter().copied()) + .record_field_srcs(self.types, src_tys.iter().copied()) .zip(src_tys.iter()); let dsts = dst - .record_field_dsts(self.module, dst_tys.iter().copied()) + .record_field_dsts(self.types, dst_tys.iter().copied()) .zip(dst_tys.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(&src_ty, &src, &dst_ty, &dst); @@ -297,42 +337,45 @@ impl Compiler<'_, '_> { fn translate_results( &mut self, + adapter: &AdapterData, param_locals: &[(u32, ValType)], result_locals: &[(u32, ValType)], ) { - let src_ty = self.module.types[self.adapter.lift.ty].result; - let dst_ty = self.module.types[self.adapter.lower.ty].result; + let src_ty = self.types[adapter.lift.ty].result; + let dst_ty = self.types[adapter.lower.ty].result; + let lift_opts = &adapter.lift.options; + let lower_opts = &adapter.lower.options; - let src_flat = self.module.flatten_types(&self.adapter.lift, 
[src_ty]); - let dst_flat = self.module.flatten_types(&self.adapter.lower, [dst_ty]); + let src_flat = self.types.flatten_types(lift_opts, [src_ty]); + let dst_flat = self.types.flatten_types(lower_opts, [dst_ty]); let src = if src_flat.len() <= MAX_FLAT_RESULTS { Source::Stack(Stack { locals: result_locals, - opts: &self.adapter.lift, + opts: lift_opts, }) } else { // The original results to read from in this case come from the // return value of the function itself. The imported function will // return a linear memory address at which the values can be read // from. - let align = self.module.align(&self.adapter.lift, &src_ty); + let align = self.types.align(lift_opts, &src_ty); assert_eq!(result_locals.len(), 1); let (addr, ty) = result_locals[0]; - assert_eq!(ty, self.adapter.lift.ptr()); - Source::Memory(self.memory_operand(&self.adapter.lift, addr, align)) + assert_eq!(ty, lift_opts.ptr()); + Source::Memory(self.memory_operand(lift_opts, addr, align)) }; let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { - Destination::Stack(&dst_flat, &self.adapter.lower) + Destination::Stack(&dst_flat, lower_opts) } else { // This is slightly different than `translate_params` where the // return pointer was provided by the caller of this function // meaning the last parameter local is a pointer into linear memory. 
- let align = self.module.align(&self.adapter.lower, &dst_ty); + let align = self.types.align(lower_opts, &dst_ty); let (addr, ty) = *param_locals.last().expect("no retptr"); - assert_eq!(ty, self.adapter.lower.ptr()); - Destination::Memory(self.memory_operand(&self.adapter.lower, addr, align)) + assert_eq!(ty, lower_opts.ptr()); + Destination::Memory(self.memory_operand(lower_opts, addr, align)) }; self.translate(&src_ty, &src, &dst_ty, &dst); @@ -351,6 +394,107 @@ impl Compiler<'_, '_> { if let Destination::Memory(mem) = dst { self.assert_aligned(dst_ty, mem); } + + // Classify the source type as "primitive" or not as a heuristic to + // whether the translation should be split out into a helper function. + let src_primitive = match src_ty { + InterfaceType::Unit + | InterfaceType::Bool + | InterfaceType::U8 + | InterfaceType::S8 + | InterfaceType::U16 + | InterfaceType::S16 + | InterfaceType::U32 + | InterfaceType::S32 + | InterfaceType::U64 + | InterfaceType::S64 + | InterfaceType::Float32 + | InterfaceType::Float64 + | InterfaceType::Char + | InterfaceType::Flags(_) => true, + + InterfaceType::String + | InterfaceType::List(_) + | InterfaceType::Record(_) + | InterfaceType::Tuple(_) + | InterfaceType::Variant(_) + | InterfaceType::Union(_) + | InterfaceType::Enum(_) + | InterfaceType::Option(_) + | InterfaceType::Expected(_) => false, + }; + let top_level = mem::replace(&mut self.top_level_translate, false); + + // Use a number of heuristics to determine whether this translation + // should be split out into a helper function rather than translated + // inline. The goal of this heuristic is to avoid a function that is + // exponential in the size of a type. 
For example if everything + // were translated inline then this could get arbitrarily large + // + // (type $level0 (list u8)) + // (type $level1 (expected $level0 $level0)) + // (type $level2 (expected $level1 $level1)) + // (type $level3 (expected $level2 $level2)) + // (type $level4 (expected $level3 $level3)) + // ;; ... + // + // If everything we inlined then translation of `$level0` would appear + // in 2^n different locations depending on the depth of the type. By + // splitting out the translation to a helper function, though, it + // means there could be one function for each level, keeping the size + // of translation on par with the size of the module itself. + // + // The heuristics which go into this splitting currently are: + // + // * Both the source and destination must be memory. This skips "top + // level" translation for adapters where arguments/results come from + // direct parameters or get placed on the stack. + // + // * Primitive types are skipped here since they have no need to be + // split out. This is for types like integers and floats. + // + // * The "top level" of a function is also skipped. That basically + // means that the first call to `translate` will never split out + // a helper function (since if we're already in a helper function + // that could cause infinite recursion in the wasm). Otherwise + // this keeps the top-level list of types in adapters nice and inline + // too while only possibly considering splitting out deeper types. + // + // This heuristic may need tweaking over time naturally as more modules + // in the wild are seen and performance measurements are taken. For now + // this keeps the fuzzers happy by avoiding exponentially-sized output + // given an input. + if let (Source::Memory(src), Destination::Memory(dst)) = (src, dst) { + if !src_primitive && !top_level { + // Compile the helper function which will translate the source + // type to the destination type. 
The two parameters to this + // function are the source/destination pointers which are + // calculated here to pass through. Our own function then + // grows a `Body::Call` to the function generated. Note that + // `Body::Call` is used here instead of `Instruction::Call` + // because we don't know the final index of the generated + // function yet. It's filled in at the end of adapter module + // translation. + let helper = + compile_translate_mem(self.module, *src_ty, src.opts, *dst_ty, dst.opts); + + // TODO: overflow checks? + self.instruction(LocalGet(src.addr_local)); + if src.offset != 0 { + self.ptr_uconst(src.opts, src.offset); + self.ptr_add(src.opts); + } + self.instruction(LocalGet(dst.addr_local)); + if dst.offset != 0 { + self.ptr_uconst(dst.opts, dst.offset); + self.ptr_add(dst.opts); + } + self.flush_code(); + self.module.funcs[self.result].body.push(Body::Call(helper)); + self.top_level_translate = true; + return; + } + } match src_ty { InterfaceType::Unit => self.translate_unit(src, dst_ty, dst), InterfaceType::Bool => self.translate_bool(src, dst_ty, dst), @@ -376,6 +520,8 @@ impl Compiler<'_, '_> { InterfaceType::Option(t) => self.translate_option(*t, src, dst_ty, dst), InterfaceType::Expected(t) => self.translate_expected(*t, src, dst_ty, dst), } + + self.top_level_translate = top_level; } fn translate_unit(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { @@ -504,7 +650,7 @@ impl Compiler<'_, '_> { fn convert_u32_mask(&mut self, src: &Source<'_>, dst: &Destination<'_>, mask: u32) { self.push_dst_addr(dst); match src { - Source::Memory(mem) => self.i32_load16u(mem), + Source::Memory(mem) => self.i32_load(mem), Source::Stack(stack) => self.stack_get(stack, ValType::I32), } if mask != 0xffffffff { @@ -854,7 +1000,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + 
self.instruction(Call(transcode.as_u32())); dst } @@ -923,7 +1069,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); self.instruction(LocalGet(dst_byte_len)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); let src_len_tmp = self.gen_local(src.opts.ptr()); self.instruction(LocalSet(src_len_tmp)); @@ -978,7 +1124,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(dst_byte_len)); self.instruction(LocalGet(dst.len)); self.ptr_sub(dst.opts); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); // Add the second result, the amount of destination units encoded, // to `dst_len` so it's an accurate reflection of the final size of @@ -1075,7 +1221,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); // If the number of code units returned by transcode is not @@ -1137,14 +1283,18 @@ impl Compiler<'_, '_> { } }; - self.validate_string_inbounds(src, dst_byte_len); + let src_byte_len = self.gen_local(src.opts.ptr()); + self.convert_src_len_to_dst(dst_byte_len, dst.opts.ptr(), src.opts.ptr()); + self.instruction(LocalSet(src_byte_len)); + + self.validate_string_inbounds(src, src.len); self.validate_string_inbounds(&dst, dst_byte_len); let transcode = self.transcoder(src, &dst, Transcode::Utf16ToCompactProbablyUtf16); self.instruction(LocalGet(src.ptr)); self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode)); + self.instruction(Call(transcode.as_u32())); self.instruction(LocalSet(dst.len)); // Assert that the untagged code unit length is the same as the @@ -1222,7 +1372,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(src.ptr)); 
self.instruction(LocalGet(src.len)); self.instruction(LocalGet(dst.ptr)); - self.instruction(Call(transcode_latin1)); + self.instruction(Call(transcode_latin1.as_u32())); self.instruction(LocalSet(dst.len)); let src_len_tmp = self.gen_local(src.opts.ptr()); self.instruction(LocalSet(src_len_tmp)); @@ -1289,7 +1439,7 @@ impl Compiler<'_, '_> { self.instruction(LocalGet(dst.ptr)); self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); self.instruction(LocalGet(dst.len)); - self.instruction(Call(transcode_utf16)); + self.instruction(Call(transcode_utf16.as_u32())); self.instruction(LocalSet(dst.len)); // If the returned number of code units written to the destination @@ -1340,17 +1490,19 @@ impl Compiler<'_, '_> { self.instruction(End); } - fn transcoder(&mut self, src: &WasmString<'_>, dst: &WasmString<'_>, op: Transcode) -> u32 { - self.transcoders.import( - self.types, - Transcoder { - from_memory: src.opts.memory.unwrap(), - from_memory64: src.opts.memory64, - to_memory: dst.opts.memory.unwrap(), - to_memory64: dst.opts.memory64, - op, - }, - ) + fn transcoder( + &mut self, + src: &WasmString<'_>, + dst: &WasmString<'_>, + op: Transcode, + ) -> FuncIndex { + self.module.import_transcoder(Transcoder { + from_memory: src.opts.memory.unwrap(), + from_memory64: src.opts.memory64, + to_memory: dst.opts.memory.unwrap(), + to_memory64: dst.opts.memory64, + op, + }) } fn validate_string_inbounds(&mut self, s: &WasmString<'_>, byte_len: u32) { @@ -1386,7 +1538,7 @@ impl Compiler<'_, '_> { self.instruction(LocalTee(tmp)); self.instruction(LocalGet(s.ptr)); self.ptr_lt_u(s.opts); - self.ptr_br_if(s.opts, 0); + self.instruction(BrIf(0)); self.instruction(LocalGet(tmp)); } @@ -1408,15 +1560,15 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_element_ty = &self.module.types[src_ty]; + let src_element_ty = &self.types[src_ty]; let dst_element_ty = match dst_ty { - InterfaceType::List(r) => &self.module.types[*r], + 
InterfaceType::List(r) => &self.types[*r], _ => panic!("expected a list"), }; let src_opts = src.opts(); let dst_opts = dst.opts(); - let (src_size, src_align) = self.module.size_align(src_opts, src_element_ty); - let (dst_size, dst_align) = self.module.size_align(dst_opts, dst_element_ty); + let (src_size, src_align) = self.types.size_align(src_opts, src_element_ty); + let (dst_size, dst_align) = self.types.size_align(dst_opts, dst_element_ty); // Load the pointer/length of this list into temporary locals. These // will be referenced a good deal so this just makes it easier to deal @@ -1791,9 +1943,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Record(r) => &self.module.types[*r], + InterfaceType::Record(r) => &self.types[*r], _ => panic!("expected a record"), }; @@ -1805,7 +1957,7 @@ impl Compiler<'_, '_> { // fields' names let mut src_fields = HashMap::new(); for (i, src) in src - .record_field_srcs(self.module, src_ty.fields.iter().map(|f| f.ty)) + .record_field_srcs(self.types, src_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &src_ty.fields[i]; @@ -1821,7 +1973,7 @@ impl Compiler<'_, '_> { // // TODO: should that lookup be fallible with subtyping? 
for (i, dst) in dst - .record_field_dsts(self.module, dst_ty.fields.iter().map(|f| f.ty)) + .record_field_dsts(self.types, dst_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &dst_ty.fields[i]; @@ -1837,9 +1989,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Flags(r) => &self.module.types[*r], + InterfaceType::Flags(r) => &self.types[*r], _ => panic!("expected a record"), }; @@ -1863,8 +2015,8 @@ impl Compiler<'_, '_> { self.convert_u16_mask(src, dst, mask); } FlagsSize::Size4Plus(n) => { - let srcs = src.record_field_srcs(self.module, (0..n).map(|_| InterfaceType::U32)); - let dsts = dst.record_field_dsts(self.module, (0..n).map(|_| InterfaceType::U32)); + let srcs = src.record_field_srcs(self.types, (0..n).map(|_| InterfaceType::U32)); + let dsts = dst.record_field_dsts(self.types, (0..n).map(|_| InterfaceType::U32)); for (i, (src, dst)) in srcs.zip(dsts).enumerate() { let mask = if i == n - 1 && (cnt % 32 != 0) { (1 << (cnt % 32)) - 1 @@ -1884,9 +2036,9 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Tuple(t) => &self.module.types[*t], + InterfaceType::Tuple(t) => &self.types[*t], _ => panic!("expected a tuple"), }; @@ -1894,10 +2046,10 @@ impl Compiler<'_, '_> { assert_eq!(src_ty.types.len(), dst_ty.types.len()); let srcs = src - .record_field_srcs(self.module, src_ty.types.iter().copied()) + .record_field_srcs(self.types, src_ty.types.iter().copied()) .zip(src_ty.types.iter()); let dsts = dst - .record_field_dsts(self.module, dst_ty.types.iter().copied()) + .record_field_dsts(self.types, dst_ty.types.iter().copied()) .zip(dst_ty.types.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(src_ty, &src, dst_ty, &dst); @@ -1911,14 
+2063,14 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Variant(t) => &self.module.types[*t], + InterfaceType::Variant(t) => &self.types[*t], _ => panic!("expected a variant"), }; - let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap(); - let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap(); + let src_info = VariantInfo::new(self.types, src.opts(), src_ty.cases.iter().map(|c| c.ty)); + let dst_info = VariantInfo::new(self.types, dst.opts(), dst_ty.cases.iter().map(|c| c.ty)); let iter = src_ty.cases.iter().enumerate().map(|(src_i, src_case)| { let dst_i = dst_ty @@ -1936,7 +2088,7 @@ impl Compiler<'_, '_> { dst_ty: &dst_case.ty, } }); - self.convert_variant(src, src_disc_size, dst, dst_disc_size, iter); + self.convert_variant(src, &src_info, dst, &dst_info, iter); } fn translate_union( @@ -1946,18 +2098,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Union(t) => &self.module.types[*t], + InterfaceType::Union(t) => &self.types[*t], _ => panic!("expected an option"), }; assert_eq!(src_ty.types.len(), dst_ty.types.len()); + let src_info = VariantInfo::new(self.types, src.opts(), src_ty.types.iter().copied()); + let dst_info = VariantInfo::new(self.types, dst.opts(), dst_ty.types.iter().copied()); self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, src_ty .types .iter() @@ -1982,18 +2136,28 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Enum(t) => &self.module.types[*t], + InterfaceType::Enum(t) => &self.types[*t], 
_ => panic!("expected an option"), }; + let src_info = VariantInfo::new( + self.types, + src.opts(), + src_ty.names.iter().map(|_| InterfaceType::Unit), + ); + let dst_info = VariantInfo::new( + self.types, + dst.opts(), + dst_ty.names.iter().map(|_| InterfaceType::Unit), + ); let unit = &InterfaceType::Unit; self.convert_variant( src, - DiscriminantSize::from_count(src_ty.names.len()).unwrap(), + &src_info, dst, - DiscriminantSize::from_count(dst_ty.names.len()).unwrap(), + &dst_info, src_ty.names.iter().enumerate().map(|(src_i, src_name)| { let dst_i = dst_ty.names.iter().position(|n| n == src_name).unwrap(); let src_i = u32::try_from(src_i).unwrap(); @@ -2015,17 +2179,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Option(t) => &self.module.types[*t], + InterfaceType::Option(t) => &self.types[*t], _ => panic!("expected an option"), }; + let src_info = VariantInfo::new(self.types, src.opts(), [InterfaceType::Unit, *src_ty]); + let dst_info = VariantInfo::new(self.types, dst.opts(), [InterfaceType::Unit, *dst_ty]); + self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, [ VariantCase { src_i: 0, @@ -2051,17 +2218,20 @@ impl Compiler<'_, '_> { dst_ty: &InterfaceType, dst: &Destination, ) { - let src_ty = &self.module.types[src_ty]; + let src_ty = &self.types[src_ty]; let dst_ty = match dst_ty { - InterfaceType::Expected(t) => &self.module.types[*t], + InterfaceType::Expected(t) => &self.types[*t], _ => panic!("expected an expected"), }; + let src_info = VariantInfo::new(self.types, src.opts(), [src_ty.ok, src_ty.err]); + let dst_info = VariantInfo::new(self.types, dst.opts(), [dst_ty.ok, dst_ty.err]); + self.convert_variant( src, - DiscriminantSize::Size1, + &src_info, dst, - DiscriminantSize::Size1, + &dst_info, [ VariantCase { src_i: 0, @@ -2083,9 +2253,9 
@@ impl Compiler<'_, '_> { fn convert_variant<'a>( &mut self, src: &Source<'_>, - src_disc_size: DiscriminantSize, + src_info: &VariantInfo, dst: &Destination, - dst_disc_size: DiscriminantSize, + dst_info: &VariantInfo, src_cases: impl ExactSizeIterator>, ) { // The outermost block is special since it has the result type of the @@ -2095,7 +2265,7 @@ impl Compiler<'_, '_> { 0 => BlockType::Empty, 1 => BlockType::Result(dst_flat[0]), _ => { - let ty = self.types.function(&[], &dst_flat); + let ty = self.module.core_types.function(&[], &dst_flat); BlockType::FunctionType(ty) } }, @@ -2120,7 +2290,7 @@ impl Compiler<'_, '_> { // Load the discriminant match src { Source::Stack(s) => self.stack_get(&s.slice(0..1), ValType::I32), - Source::Memory(mem) => match src_disc_size { + Source::Memory(mem) => match src_info.size { DiscriminantSize::Size1 => self.i32_load8u(mem), DiscriminantSize::Size2 => self.i32_load16u(mem), DiscriminantSize::Size4 => self.i32_load(mem), @@ -2158,7 +2328,7 @@ impl Compiler<'_, '_> { self.instruction(I32Const(dst_i as i32)); match dst { Destination::Stack(stack, _) => self.stack_set(&stack[..1], ValType::I32), - Destination::Memory(mem) => match dst_disc_size { + Destination::Memory(mem) => match dst_info.size { DiscriminantSize::Size1 => self.i32_store8(mem), DiscriminantSize::Size2 => self.i32_store16(mem), DiscriminantSize::Size4 => self.i32_store(mem), @@ -2167,8 +2337,8 @@ impl Compiler<'_, '_> { // Translate the payload of this case using the various types from // the dst/src. 
- let src_payload = src.payload_src(self.module, src_disc_size, src_ty); - let dst_payload = dst.payload_dst(self.module, dst_disc_size, dst_ty); + let src_payload = src.payload_src(self.types, src_info, src_ty); + let dst_payload = dst.payload_dst(self.types, dst_info, dst_ty); self.translate(src_ty, &src_payload, dst_ty, &dst_payload); // If the results of this translation were placed on the stack then @@ -2251,7 +2421,7 @@ impl Compiler<'_, '_> { if !self.module.debug { return; } - let align = self.module.align(mem.opts, ty); + let align = self.types.align(mem.opts, ty); if align == 1 { return; } @@ -2299,9 +2469,10 @@ impl Compiler<'_, '_> { fn gen_local(&mut self, ty: ValType) -> u32 { // TODO: see if local reuse is necessary, right now this always // generates a new local. - match self.locals.last_mut() { + let locals = &mut self.module.funcs[self.result].locals; + match locals.last_mut() { Some((cnt, prev_ty)) if ty == *prev_ty => *cnt += 1, - _ => self.locals.push((1, ty)), + _ => locals.push((1, ty)), } self.nlocals += 1; self.nlocals - 1 @@ -2316,27 +2487,29 @@ impl Compiler<'_, '_> { self.instruction(Unreachable); } - fn finish(&mut self) -> (Vec, Vec<(usize, Trap)>) { + /// Flushes out the current `code` instructions (and `traps` if there are + /// any) into the destination function. + /// + /// This is a noop if no instructions have been encoded yet. + fn flush_code(&mut self) { + if self.code.is_empty() { + return; + } + self.module.funcs[self.result].body.push(Body::Raw( + mem::take(&mut self.code), + mem::take(&mut self.traps), + )); + } + + fn finish(mut self) { + // Append the final `end` instruction which all functions require, and + // then empty out the temporary buffer in `Compiler`. 
self.instruction(End); + self.flush_code(); - let mut bytes = Vec::new(); - - // Encode all locals used for this function - self.locals.len().encode(&mut bytes); - for (count, ty) in self.locals.iter() { - count.encode(&mut bytes); - ty.encode(&mut bytes); - } - - // Factor in the size of the encodings of locals into the offsets of - // traps. - for (offset, _) in self.traps.iter_mut() { - *offset += bytes.len(); - } - - // Then append the function we built and return - bytes.extend_from_slice(&self.code); - (bytes, mem::take(&mut self.traps)) + // Flag the function as "done" which helps with an assert later on in + // emission that everything was eventually finished. + self.module.funcs[self.result].filled_in = true; } /// Fetches the value contained with the local specified by `stack` and @@ -2361,8 +2534,8 @@ impl Compiler<'_, '_> { (ValType::I64, ValType::F64) => self.instruction(F64ReinterpretI64), (ValType::F64, ValType::F32) => self.instruction(F32DemoteF64), (ValType::I64, ValType::F32) => { - self.instruction(F64ReinterpretI64); - self.instruction(F32DemoteF64); + self.instruction(I32WrapI64); + self.instruction(F32ReinterpretI32); } // should not be possible given the `join` function for variants @@ -2405,8 +2578,8 @@ impl Compiler<'_, '_> { (ValType::F64, ValType::I64) => self.instruction(I64ReinterpretF64), (ValType::F32, ValType::F64) => self.instruction(F64PromoteF32), (ValType::F32, ValType::I64) => { - self.instruction(F64PromoteF32); - self.instruction(I64ReinterpretF64); + self.instruction(I32ReinterpretF32); + self.instruction(I64ExtendI32U); } // should not be possible given the `join` function for variants @@ -2654,7 +2827,7 @@ impl<'a> Source<'a> { /// offset for each memory-based type. 
fn record_field_srcs<'b>( &'b self, - module: &'b Module, + types: &'b ComponentTypes, fields: impl IntoIterator + 'b, ) -> impl Iterator> + 'b where @@ -2663,11 +2836,11 @@ impl<'a> Source<'a> { let mut offset = 0; fields.into_iter().map(move |ty| match self { Source::Memory(mem) => { - let mem = next_field_offset(&mut offset, module, &ty, mem); + let mem = next_field_offset(&mut offset, types, &ty, mem); Source::Memory(mem) } Source::Stack(stack) => { - let cnt = module.flatten_types(stack.opts, [ty]).len(); + let cnt = types.flatten_types(stack.opts, [ty]).len(); offset += cnt; Source::Stack(stack.slice(offset - cnt..offset)) } @@ -2677,17 +2850,17 @@ impl<'a> Source<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_src( &self, - module: &Module, - size: DiscriminantSize, + types: &ComponentTypes, + info: &VariantInfo, case: &InterfaceType, ) -> Source<'a> { match self { Source::Stack(s) => { - let flat_len = module.flatten_types(s.opts, [*case]).len(); + let flat_len = types.flatten_types(s.opts, [*case]).len(); Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) } Source::Memory(mem) => { - let mem = payload_offset(size, module, case, mem); + let mem = info.payload_offset(case, mem); Source::Memory(mem) } } @@ -2705,7 +2878,7 @@ impl<'a> Destination<'a> { /// Same as `Source::record_field_srcs` but for destinations. 
fn record_field_dsts<'b>( &'b self, - module: &'b Module, + types: &'b ComponentTypes, fields: impl IntoIterator + 'b, ) -> impl Iterator + 'b where @@ -2714,11 +2887,11 @@ impl<'a> Destination<'a> { let mut offset = 0; fields.into_iter().map(move |ty| match self { Destination::Memory(mem) => { - let mem = next_field_offset(&mut offset, module, &ty, mem); + let mem = next_field_offset(&mut offset, types, &ty, mem); Destination::Memory(mem) } Destination::Stack(s, opts) => { - let cnt = module.flatten_types(opts, [ty]).len(); + let cnt = types.flatten_types(opts, [ty]).len(); offset += cnt; Destination::Stack(&s[offset - cnt..offset], opts) } @@ -2728,17 +2901,17 @@ impl<'a> Destination<'a> { /// Returns the corresponding discriminant source and payload source f fn payload_dst( &self, - module: &Module, - size: DiscriminantSize, + types: &ComponentTypes, + info: &VariantInfo, case: &InterfaceType, ) -> Destination { match self { Destination::Stack(s, opts) => { - let flat_len = module.flatten_types(opts, [*case]).len(); + let flat_len = types.flatten_types(opts, [*case]).len(); Destination::Stack(&s[1..][..flat_len], opts) } Destination::Memory(mem) => { - let mem = payload_offset(size, module, case, mem); + let mem = info.payload_offset(case, mem); Destination::Memory(mem) } } @@ -2754,23 +2927,37 @@ impl<'a> Destination<'a> { fn next_field_offset<'a>( offset: &mut usize, - module: &Module, + types: &ComponentTypes, field: &InterfaceType, mem: &Memory<'a>, ) -> Memory<'a> { - let (size, align) = module.size_align(mem.opts, field); + let (size, align) = types.size_align(mem.opts, field); *offset = align_to(*offset, align) + size; mem.bump(*offset - size) } -fn payload_offset<'a>( - disc_size: DiscriminantSize, - module: &Module, - case: &InterfaceType, - mem: &Memory<'a>, -) -> Memory<'a> { - let align = module.align(mem.opts, case); - mem.bump(align_to(disc_size.into(), align)) +struct VariantInfo { + size: DiscriminantSize, + align: usize, +} + +impl VariantInfo { 
+ fn new(types: &ComponentTypes, options: &Options, iter: I) -> VariantInfo + where + I: IntoIterator, + I::IntoIter: ExactSizeIterator, + { + let iter = iter.into_iter(); + let size = DiscriminantSize::from_count(iter.len()).unwrap(); + VariantInfo { + size, + align: usize::from(size).max(iter.map(|i| types.align(options, &i)).max().unwrap_or(1)), + } + } + + fn payload_offset<'a>(&self, _case: &InterfaceType, mem: &Memory<'a>) -> Memory<'a> { + mem.bump(align_to(self.size.into(), self.align)) + } } impl<'a> Memory<'a> { diff --git a/crates/environ/src/fact/transcode.rs b/crates/environ/src/fact/transcode.rs index 865fef316e..7d72413050 100644 --- a/crates/environ/src/fact/transcode.rs +++ b/crates/environ/src/fact/transcode.rs @@ -1,15 +1,8 @@ use crate::fact::core_types::CoreTypes; use crate::MemoryIndex; use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use wasm_encoder::{EntityType, ValType}; -pub struct Transcoders { - imported: HashMap, - prev_func_imports: u32, - imports: Vec<(String, EntityType, Transcoder)>, -} - #[derive(Copy, Clone, Hash, Eq, PartialEq)] pub struct Transcoder { pub from_memory: MemoryIndex, @@ -46,33 +39,8 @@ pub enum FixedEncoding { Latin1, } -impl Transcoders { - pub fn new(prev_func_imports: u32) -> Transcoders { - Transcoders { - imported: HashMap::new(), - prev_func_imports, - imports: Vec::new(), - } - } - - pub fn import(&mut self, types: &mut CoreTypes, transcoder: Transcoder) -> u32 { - *self.imported.entry(transcoder).or_insert_with(|| { - let idx = self.prev_func_imports + (self.imports.len() as u32); - self.imports - .push((transcoder.name(), transcoder.ty(types), transcoder)); - idx - }) - } - - pub fn imports(&self) -> impl Iterator { - self.imports - .iter() - .map(|(name, ty, transcoder)| ("transcode", &name[..], *ty, transcoder)) - } -} - impl Transcoder { - fn name(&self) -> String { + pub fn name(&self) -> String { format!( "{} (mem{} => mem{})", self.op.desc(), @@ -81,7 +49,7 @@ impl Transcoder { ) 
} - fn ty(&self, types: &mut CoreTypes) -> EntityType { + pub fn ty(&self, types: &mut CoreTypes) -> EntityType { let from_ptr = if self.from_memory64 { ValType::I64 } else { diff --git a/crates/fuzzing/src/generators/component_types.rs b/crates/fuzzing/src/generators/component_types.rs index 2d93f29d72..bd3c64755c 100644 --- a/crates/fuzzing/src/generators/component_types.rs +++ b/crates/fuzzing/src/generators/component_types.rs @@ -8,6 +8,7 @@ use arbitrary::{Arbitrary, Unstructured}; use component_fuzz_util::{Declarations, EXPORT_FUNCTION, IMPORT_FUNCTION}; +use std::any::Any; use std::fmt::Debug; use std::ops::ControlFlow; use wasmtime::component::{self, Component, Lift, Linker, Lower, Val}; @@ -141,25 +142,29 @@ macro_rules! define_static_api_test { let mut config = Config::new(); config.wasm_component_model(true); let engine = Engine::new(&config).unwrap(); - let component = Component::new( - &engine, - declarations.make_component().as_bytes() - ).unwrap(); + let wat = declarations.make_component(); + let wat = wat.as_bytes(); + crate::oracles::log_wasm(wat); + let component = Component::new(&engine, wat).unwrap(); let mut linker = Linker::new(&engine); linker .root() .func_wrap( IMPORT_FUNCTION, - |cx: StoreContextMut<'_, ($(Option<$param>,)* Option)>, + |cx: StoreContextMut<'_, Box>, $($param_name: $param,)*| { - let ($($param_expected_name,)* result) = cx.data(); - $(assert_eq!($param_name, *$param_expected_name.as_ref().unwrap());)* - Ok(result.as_ref().unwrap().clone()) + log::trace!("received parameters {:?}", ($(&$param_name,)*)); + let data: &($($param,)* R,) = + cx.data().downcast_ref().unwrap(); + let ($($param_expected_name,)* result,) = data; + $(assert_eq!($param_name, *$param_expected_name);)* + log::trace!("returning result {:?}", result); + Ok(result.clone()) }, ) .unwrap(); - let mut store = Store::new(&engine, Default::default()); + let mut store: Store> = Store::new(&engine, Box::new(())); let instance = linker.instantiate(&mut store, 
&component).unwrap(); let func = instance .get_typed_func::<($($param,)*), R, _>(&mut store, EXPORT_FUNCTION) @@ -168,9 +173,17 @@ macro_rules! define_static_api_test { while input.arbitrary()? { $(let $param_name = input.arbitrary::<$param>()?;)* let result = input.arbitrary::()?; - *store.data_mut() = ($(Some($param_name.clone()),)* Some(result.clone())); - - assert_eq!(func.call(&mut store, ($($param_name,)*)).unwrap(), result); + *store.data_mut() = Box::new(( + $($param_name.clone(),)* + result.clone(), + )); + log::trace!( + "passing in parameters {:?}", + ($(&$param_name,)*), + ); + let actual = func.call(&mut store, ($($param_name,)*)).unwrap(); + log::trace!("got result {:?}", actual); + assert_eq!(actual, result); func.post_return(&mut store).unwrap(); } diff --git a/crates/fuzzing/src/oracles.rs b/crates/fuzzing/src/oracles.rs index 4e7d090c4f..bdaca94ae5 100644 --- a/crates/fuzzing/src/oracles.rs +++ b/crates/fuzzing/src/oracles.rs @@ -1089,20 +1089,25 @@ pub fn dynamic_component_api_target(input: &mut arbitrary::Unstructured) -> arbi let engine = component_test_util::engine(); let mut store = Store::new(&engine, (Box::new([]) as Box<[Val]>, None)); - let component = - Component::new(&engine, case.declarations().make_component().as_bytes()).unwrap(); + let wat = case.declarations().make_component(); + let wat = wat.as_bytes(); + log_wasm(wat); + let component = Component::new(&engine, wat).unwrap(); let mut linker = Linker::new(&engine); linker .root() .func_new(&component, IMPORT_FUNCTION, { move |cx: StoreContextMut<'_, (Box<[Val]>, Option)>, args: &[Val]| -> Result { + log::trace!("received arguments {args:?}"); let (expected_args, result) = cx.data(); assert_eq!(args.len(), expected_args.len()); for (expected, actual) in expected_args.iter().zip(args) { assert_eq!(expected, actual); } - Ok(result.as_ref().unwrap().clone()) + let result = result.as_ref().unwrap().clone(); + log::trace!("returning result {result:?}"); + Ok(result) } }) .unwrap(); @@ 
-1122,10 +1127,10 @@ pub fn dynamic_component_api_target(input: &mut arbitrary::Unstructured) -> arbi *store.data_mut() = (args.clone(), Some(result.clone())); - assert_eq!( - func.call_and_post_return(&mut store, &args).unwrap(), - result - ); + log::trace!("passing args {args:?}"); + let actual = func.call_and_post_return(&mut store, &args).unwrap(); + log::trace!("received return {actual:?}"); + assert_eq!(actual, result); } Ok(()) diff --git a/crates/misc/component-fuzz-util/src/lib.rs b/crates/misc/component-fuzz-util/src/lib.rs index 9b14266dcd..2d5b75fd92 100644 --- a/crates/misc/component-fuzz-util/src/lib.rs +++ b/crates/misc/component-fuzz-util/src/lib.rs @@ -8,7 +8,8 @@ use arbitrary::{Arbitrary, Unstructured}; use proc_macro2::{Ident, TokenStream}; -use quote::{format_ident, quote}; +use quote::{format_ident, quote, ToTokens}; +use std::borrow::Cow; use std::fmt::{self, Debug, Write}; use std::iter; use std::ops::Deref; @@ -328,7 +329,7 @@ fn variant_size_and_alignment<'a>( } } -fn make_import_and_export(params: &[Type], result: &Type) -> Box { +fn make_import_and_export(params: &[Type], result: &Type) -> String { let params_lowered = params .iter() .flat_map(|ty| ty.lowered()) @@ -400,7 +401,6 @@ fn make_import_and_export(params: &[Type], result: &Type) -> Box { )"# ) } - .into() } fn make_rust_name(name_counter: &mut u32) -> Ident { @@ -509,7 +509,7 @@ pub fn rust_type(ty: &Type, name_counter: &mut u32, declarations: &mut TokenStre let name = make_rust_name(name_counter); declarations.extend(quote! 
{ - #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Clone, Arbitrary)] + #[derive(ComponentType, Lift, Lower, PartialEq, Debug, Copy, Clone, Arbitrary)] #[component(enum)] enum #name { #cases @@ -677,13 +677,17 @@ fn write_component_type( #[derive(Debug)] pub struct Declarations { /// Type declarations (if any) referenced by `params` and/or `result` - pub types: Box, + pub types: Cow<'static, str>, /// Parameter declarations used for the imported and exported functions - pub params: Box, + pub params: Cow<'static, str>, /// Result declaration used for the imported and exported functions - pub result: Box, + pub result: Cow<'static, str>, /// A WAT fragment representing the core function import and export to use for testing - pub import_and_export: Box, + pub import_and_export: Cow<'static, str>, + /// String encoding to use for host -> component + pub encoding1: StringEncoding, + /// String encoding to use for component -> host + pub encoding2: StringEncoding, } impl Declarations { @@ -694,7 +698,44 @@ impl Declarations { params, result, import_and_export, + encoding1, + encoding2, } = self; + let mk_component = |name: &str, encoding: StringEncoding| { + format!( + r#" + (component ${name} + (import "echo" (func $f (type $sig))) + + (core instance $libc (instantiate $libc)) + + (core func $f_lower (canon lower + (func $f) + (memory $libc "memory") + (realloc (func $libc "realloc")) + string-encoding={encoding} + )) + + (core instance $i (instantiate $m + (with "libc" (instance $libc)) + (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) + )) + + (func (export "echo") (type $sig) + (canon lift + (core func $i "echo") + (memory $libc "memory") + (realloc (func $libc "realloc")) + string-encoding={encoding} + ) + ) + ) + "# + ) + }; + + let c1 = mk_component("c1", *encoding2); + let c2 = mk_component("c2", *encoding1); format!( r#" @@ -704,18 +745,6 @@ impl Declarations { {REALLOC_AND_FREE} ) - (core instance $libc (instantiate $libc)) - - 
{types} - - (import "{IMPORT_FUNCTION}" (func $f {params} {result})) - - (core func $f_lower (canon lower - (func $f) - (memory $libc "memory") - (realloc (func $libc "realloc")) - )) - (core module $m (memory (import "libc" "memory") 1) (func $realloc (import "libc" "realloc") (param i32 i32 i32 i32) (result i32)) @@ -723,18 +752,16 @@ impl Declarations { {import_and_export} ) - (core instance $i (instantiate $m - (with "libc" (instance $libc)) - (with "host" (instance (export "{IMPORT_FUNCTION}" (func $f_lower)))) - )) + {types} - (func (export "echo") {params} {result} - (canon lift - (core func $i "echo") - (memory $libc "memory") - (realloc (func $libc "realloc")) - ) - ) + (type $sig (func {params} {result})) + (import "{IMPORT_FUNCTION}" (func $f (type $sig))) + + {c1} + {c2} + (instance $c1 (instantiate $c1 (with "echo" (func $f)))) + (instance $c2 (instantiate $c2 (with "echo" (func $c1 "echo")))) + (export "echo" (func $c2 "echo")) )"#, ) .into() @@ -748,6 +775,10 @@ pub struct TestCase { pub params: Box<[Type]>, /// The type of the result to be returned by the function pub result: Type, + /// String encoding to use from host-to-component. + pub encoding1: StringEncoding, + /// String encoding to use from component-to-host. 
+ pub encoding2: StringEncoding, } impl TestCase { @@ -781,7 +812,9 @@ impl TestCase { types: types.into(), params, result, - import_and_export, + import_and_export: import_and_export.into(), + encoding1: self.encoding1, + encoding2: self.encoding2, } } } @@ -795,6 +828,36 @@ impl<'a> Arbitrary<'a> for TestCase { .take(MAX_ARITY) .collect::>>()?, result: input.arbitrary()?, + encoding1: input.arbitrary()?, + encoding2: input.arbitrary()?, }) } } + +#[derive(Copy, Clone, Debug, Arbitrary)] +pub enum StringEncoding { + Utf8, + Utf16, + Latin1OrUtf16, +} + +impl fmt::Display for StringEncoding { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + StringEncoding::Utf8 => fmt::Display::fmt(&"utf8", f), + StringEncoding::Utf16 => fmt::Display::fmt(&"utf16", f), + StringEncoding::Latin1OrUtf16 => fmt::Display::fmt(&"latin1+utf16", f), + } + } +} + +impl ToTokens for StringEncoding { + fn to_tokens(&self, tokens: &mut TokenStream) { + let me = match self { + StringEncoding::Utf8 => quote!(Utf8), + StringEncoding::Utf16 => quote!(Utf16), + StringEncoding::Latin1OrUtf16 => quote!(Latin1OrUtf16), + }; + tokens.extend(quote!(component_fuzz_util::StringEncoding::#me)); + } +} diff --git a/crates/wasmtime/src/component/values.rs b/crates/wasmtime/src/component/values.rs index 7c3f550152..65fd2280b9 100644 --- a/crates/wasmtime/src/component/values.rs +++ b/crates/wasmtime/src/component/values.rs @@ -4,12 +4,13 @@ use crate::store::StoreOpaque; use crate::{AsContextMut, StoreContextMut, ValRaw}; use anyhow::{anyhow, bail, Context, Error, Result}; use std::collections::HashMap; +use std::fmt; use std::iter; use std::mem::MaybeUninit; use std::ops::Deref; use wasmtime_component_util::{DiscriminantSize, FlagsSize}; -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct List { ty: types::List, values: Box<[Val]>, @@ -45,7 +46,17 @@ impl Deref for List { } } -#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for List { + fn 
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_list(); + for val in self.iter() { + f.entry(val); + } + f.finish() + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Record { ty: types::Record, values: Box<[Val]>, @@ -105,6 +116,16 @@ impl Record { } } +impl fmt::Debug for Record { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut f = f.debug_struct("Record"); + for (name, val) in self.fields() { + f.field(name, val); + } + f.finish() + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Tuple { ty: types::Tuple, @@ -144,7 +165,7 @@ impl Tuple { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct Variant { ty: types::Variant, discriminant: u32, @@ -197,6 +218,14 @@ impl Variant { } } +impl fmt::Debug for Variant { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple(self.discriminant()) + .field(self.payload()) + .finish() + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Enum { ty: types::Enum, @@ -273,7 +302,7 @@ impl Union { } } -#[derive(Debug, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub struct Option { ty: types::Option, discriminant: u32, @@ -313,7 +342,13 @@ impl Option { } } -#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for Option { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value().fmt(f) + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Expected { ty: types::Expected, discriminant: u32, @@ -358,7 +393,13 @@ impl Expected { } } -#[derive(Debug, PartialEq, Eq, Clone)] +impl fmt::Debug for Expected { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.value().fmt(f) + } +} + +#[derive(PartialEq, Eq, Clone)] pub struct Flags { ty: types::Flags, count: u32, @@ -408,6 +449,16 @@ impl Flags { } } +impl fmt::Debug for Flags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut set = f.debug_set(); + for flag in self.flags() { + set.entry(&flag); + } 
+ set.finish() + } +} + /// Represents possible runtime values which a component function can either consume or produce #[derive(Debug, PartialEq, Eq, Clone)] pub enum Val { diff --git a/fuzz/build.rs b/fuzz/build.rs index b8b45a36f4..97c6e64ddf 100644 --- a/fuzz/build.rs +++ b/fuzz/build.rs @@ -77,6 +77,8 @@ mod component { params, result, import_and_export, + encoding1, + encoding2, } = case.declarations(); let test = format_ident!("static_api_test{}", case.params.len()); @@ -95,11 +97,16 @@ mod component { let test = quote!(#index => component_types::#test::<#rust_params #rust_result>( input, - &Declarations { - types: #types.into(), - params: #params.into(), - result: #result.into(), - import_and_export: #import_and_export.into() + { + static DECLS: Declarations = Declarations { + types: Cow::Borrowed(#types), + params: Cow::Borrowed(#params), + result: Cow::Borrowed(#result), + import_and_export: Cow::Borrowed(#import_and_export), + encoding1: #encoding1, + encoding2: #encoding2, + }; + &DECLS } ),); @@ -116,6 +123,7 @@ mod component { use std::sync::{Arc, Once}; use wasmtime::component::{ComponentType, Lift, Lower}; use wasmtime_fuzzing::generators::component_types; + use std::borrow::Cow; const SEED: u64 = #seed; diff --git a/tests/misc_testsuite/component-model/fused.wast b/tests/misc_testsuite/component-model/fused.wast index 6de762471e..77dca93edc 100644 --- a/tests/misc_testsuite/component-model/fused.wast +++ b/tests/misc_testsuite/component-model/fused.wast @@ -925,7 +925,7 @@ (i32.eqz (local.get 0)) if (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) - (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 8)) (unreachable)) + (if (f32.ne (f32.reinterpret_i32 (i32.wrap_i64 (local.get 2))) (f32.const 8)) (unreachable)) else (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 9)) (unreachable)) @@ -935,7 +935,7 @@ (i32.eqz (local.get 0)) if (if (i32.ne (local.get 1) (i32.const 0)) 
(unreachable)) - (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 10)) (unreachable)) + (if (f32.ne (f32.reinterpret_i32 (i32.wrap_i64 (local.get 2))) (f32.const 10)) (unreachable)) else (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) (if (i64.ne (local.get 2) (i64.const 11)) (unreachable)) @@ -983,10 +983,10 @@ (call $c (i32.const 0) (i32.const 0) (i64.const 6)) (call $c (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 7))) - (call $d (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 8))) + (call $d (i32.const 0) (i32.const 0) (i64.extend_i32_u (i32.reinterpret_f32 (f32.const 8)))) (call $d (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 9))) - (call $e (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 10))) + (call $e (i32.const 0) (i32.const 0) (i64.extend_i32_u (i32.reinterpret_f32 (f32.const 10)))) (call $e (i32.const 1) (i32.const 1) (i64.const 11)) ) (start $start)