diff --git a/Cargo.lock b/Cargo.lock
index e475783635..2c6b5766b5 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3507,9 +3507,22 @@ dependencies = [
  "wasm-encoder",
  "wasmparser",
  "wasmprinter",
+ "wasmtime-component-util",
  "wasmtime-types",
 ]
 
+[[package]]
+name = "wasmtime-environ-fuzz"
+version = "0.0.0"
+dependencies = [
+ "arbitrary",
+ "env_logger 0.9.0",
+ "libfuzzer-sys",
+ "wasmparser",
+ "wasmprinter",
+ "wasmtime-environ",
+]
+
 [[package]]
 name = "wasmtime-fiber"
 version = "0.40.0"
diff --git a/Cargo.toml b/Cargo.toml
index 561dd6e277..558528a410 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -80,6 +80,7 @@ members = [
   "crates/bench-api",
   "crates/c-api",
   "crates/cli-flags",
+  "crates/environ/fuzz",
   "examples/fib-debug/wasm",
   "examples/wasi/wasm",
   "examples/tokio/wasm",
diff --git a/crates/environ/Cargo.toml b/crates/environ/Cargo.toml
index ff2d9707de..2f8622ec95 100644
--- a/crates/environ/Cargo.toml
+++ b/crates/environ/Cargo.toml
@@ -24,9 +24,14 @@ object = { version = "0.29.0", default-features = false, features = ['read_core'
 target-lexicon = "0.12"
 wasm-encoder = { version = "0.14.0", optional = true }
 wasmprinter = { version = "0.2.37", optional = true }
+wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true }
 
 [badges]
 maintenance = { status = "actively-developed" }
 
 [features]
-component-model = ["dep:wasm-encoder", "dep:wasmprinter"]
+component-model = [
+  "dep:wasm-encoder",
+  "dep:wasmprinter",
+  "dep:wasmtime-component-util",
+]
diff --git a/crates/environ/fuzz/.gitignore b/crates/environ/fuzz/.gitignore
new file mode 100644
index 0000000000..b400c27826
--- /dev/null
+++ b/crates/environ/fuzz/.gitignore
@@ -0,0 +1,2 @@
+corpus
+artifacts
diff --git a/crates/environ/fuzz/Cargo.toml b/crates/environ/fuzz/Cargo.toml
new file mode 100644
index 0000000000..c4cd8a6c3a
--- /dev/null
+++ b/crates/environ/fuzz/Cargo.toml
@@ -0,0 +1,27 @@
+[package]
+name = "wasmtime-environ-fuzz"
+version = "0.0.0"
+authors = ["Automatically generated"]
+publish = false
+edition = "2018"
+
+[package.metadata]
+cargo-fuzz = true
+
+[dependencies]
+arbitrary = { version = "1.1.0", features = ["derive"] }
+env_logger = "0.9.0"
+libfuzzer-sys = "0.4"
+wasmparser = "0.87.0"
+wasmprinter = "0.2.37"
+wasmtime-environ = { path = ".." }
+
+[[bin]]
+name = "fact-valid-module"
+path = "fuzz_targets/fact-valid-module.rs"
+test = false
+doc = false
+required-features = ["component-model"]
+
+[features]
+component-model = ["wasmtime-environ/component-model"]
diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs
new file mode 100644
index 0000000000..246c450f94
--- /dev/null
+++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs
@@ -0,0 +1,252 @@
+//! A simple fuzzer for FACT
+//!
+//! This is an intentionally small fuzzer, primarily useful while developing
+//! FACT itself. It creates arbitrary adapter signatures, generates the
+//! required trampoline for each adapter, and asserts that the final output
+//! wasm module is valid. It does not validate the *correctness* of the
+//! trampolines, only that they encode to valid wasm.
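As background on how this target consumes its input: `libfuzzer-sys` cooperates with the `arbitrary` crate so the raw byte buffer from libFuzzer is decoded into a structured value before the body runs. A minimal sketch of that mechanism (the names below are illustrative, not part of this change):

```rust
use arbitrary::{Arbitrary, Unstructured};

// Stand-in for a structured fuzz input like `GenAdapterModule` below: any
// `#[derive(Arbitrary)]` type can be synthesized from a raw byte stream.
#[derive(Arbitrary, Debug)]
struct Example {
    debug: bool,
    sizes: Vec<u16>,
}

fn main() {
    // libFuzzer supplies these bytes during fuzzing; fixed bytes shown here.
    let data = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08];
    let mut u = Unstructured::new(&data);
    // `fuzz_target!(|module: GenAdapterModule| ...)` performs this decoding
    // implicitly and rejects inputs that fail to decode.
    let example = Example::arbitrary(&mut u);
    println!("{example:?}");
}
```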
+
+#![no_main]
+
+use arbitrary::{Arbitrary, Unstructured};
+use libfuzzer_sys::fuzz_target;
+use std::fmt;
+use wasmparser::{Validator, WasmFeatures};
+use wasmtime_environ::component::*;
+use wasmtime_environ::fact::Module;
+
+#[derive(Arbitrary, Debug)]
+struct GenAdapterModule {
+    debug: bool,
+    adapters: Vec<GenAdapter>,
+}
+
+#[derive(Arbitrary, Debug)]
+struct GenAdapter {
+    ty: FuncType,
+    post_return: bool,
+    lift_memory64: bool,
+    lower_memory64: bool,
+    lift_encoding: GenStringEncoding,
+    lower_encoding: GenStringEncoding,
+}
+
+#[derive(Arbitrary, Debug)]
+struct FuncType {
+    params: Vec<ValType>,
+    result: ValType,
+}
+
+#[derive(Arbitrary, Debug)]
+enum ValType {
+    Unit,
+    U8,
+    S8,
+    U16,
+    S16,
+    U32,
+    S32,
+    U64,
+    S64,
+    Float32,
+    Float64,
+    Record(Vec<ValType>),
+    Tuple(Vec<ValType>),
+    Variant(NonZeroLenVec<ValType>),
+}
+
+#[derive(Copy, Clone, Arbitrary, Debug)]
+enum GenStringEncoding {
+    Utf8,
+    Utf16,
+    CompactUtf16,
+}
+
+pub struct NonZeroLenVec<T>(Vec<T>);
+
+impl<'a, T: Arbitrary<'a>> Arbitrary<'a> for NonZeroLenVec<T> {
+    fn arbitrary(u: &mut Unstructured<'a>) -> arbitrary::Result<Self> {
+        let mut items = Vec::arbitrary(u)?;
+        if items.is_empty() {
+            items.push(u.arbitrary()?);
+        }
+        Ok(NonZeroLenVec(items))
+    }
+}
+
+impl<T: fmt::Debug> fmt::Debug for NonZeroLenVec<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.0.fmt(f)
+    }
+}
+
+fuzz_target!(|module: GenAdapterModule| {
+    drop(env_logger::try_init());
+
+    let mut types = ComponentTypesBuilder::default();
+
+    // Manufactures a unique `CoreDef` so that each function import is
+    // distinct from all others.
+    let mut next_def = 0;
+    let mut dummy_def = || {
+        next_def += 1;
+        CoreDef::Adapter(AdapterIndex::from_u32(next_def))
+    };
+
+    // Manufactures a `CoreExport` for a memory with the shape specified. Note
+    // that we can't import as many memories as functions so these are
+    // intentionally limited. Once a handful of memories are generated of each
+    // type then they start getting reused.
+    let mut next_memory = 0;
+    let mut memories32 = Vec::new();
+    let mut memories64 = Vec::new();
+    let mut dummy_memory = |memory64: bool| {
+        let dst = if memory64 {
+            &mut memories64
+        } else {
+            &mut memories32
+        };
+        let idx = if dst.len() < 5 {
+            next_memory += 1;
+            dst.push(next_memory - 1);
+            next_memory - 1
+        } else {
+            dst[0]
+        };
+        CoreExport {
+            instance: RuntimeInstanceIndex::from_u32(idx),
+            item: ExportItem::Name(String::new()),
+        }
+    };
+
+    let mut adapters = Vec::new();
+    for adapter in module.adapters.iter() {
+        let mut params = Vec::new();
+        for param in adapter.ty.params.iter() {
+            params.push((None, intern(&mut types, param)));
+        }
+        let result = intern(&mut types, &adapter.ty.result);
+        let signature = types.add_func_type(TypeFunc {
+            params: params.into(),
+            result,
+        });
+        adapters.push(Adapter {
+            lift_ty: signature,
+            lower_ty: signature,
+            lower_options: AdapterOptions {
+                instance: RuntimeComponentInstanceIndex::from_u32(0),
+                string_encoding: adapter.lower_encoding.into(),
+                memory64: adapter.lower_memory64,
+                // Pessimistically assume that memory/realloc are going to be
+                // required for this trampoline and provide it. Avoids doing
+                // calculations to figure out whether they're necessary and
+                // simplifies the fuzzer here without reducing coverage within
+                // FACT itself.
+                memory: Some(dummy_memory(adapter.lower_memory64)),
+                realloc: Some(dummy_def()),
+                // Lowering never allows `post-return`
+                post_return: None,
+            },
+            lift_options: AdapterOptions {
+                instance: RuntimeComponentInstanceIndex::from_u32(1),
+                string_encoding: adapter.lift_encoding.into(),
+                memory64: adapter.lift_memory64,
+                memory: Some(dummy_memory(adapter.lift_memory64)),
+                realloc: Some(dummy_def()),
+                post_return: if adapter.post_return {
+                    Some(dummy_def())
+                } else {
+                    None
+                },
+            },
+            func: dummy_def(),
+        });
+    }
+    let types = types.finish();
+    let mut fact_module = Module::new(&types, module.debug);
+    for (i, adapter) in adapters.iter().enumerate() {
+        fact_module.adapt(&format!("adapter{i}"), adapter);
+    }
+    let wasm = fact_module.encode();
+    let result = Validator::new_with_features(WasmFeatures {
+        multi_memory: true,
+        memory64: true,
+        ..WasmFeatures::default()
+    })
+    .validate_all(&wasm);
+
+    let err = match result {
+        Ok(_) => return,
+        Err(e) => e,
+    };
+    eprintln!("invalid wasm module: {err:?}");
+    for adapter in module.adapters.iter() {
+        eprintln!("adapter: {adapter:?}");
+    }
+    std::fs::write("invalid.wasm", &wasm).unwrap();
+    match wasmprinter::print_bytes(&wasm) {
+        Ok(s) => std::fs::write("invalid.wat", &s).unwrap(),
+        Err(_) => drop(std::fs::remove_file("invalid.wat")),
+    }
+
+    panic!()
+});
+
+fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
+    match ty {
+        ValType::Unit => InterfaceType::Unit,
+        ValType::U8 => InterfaceType::U8,
+        ValType::S8 => InterfaceType::S8,
+        ValType::U16 => InterfaceType::U16,
+        ValType::S16 => InterfaceType::S16,
+        ValType::U32 => InterfaceType::U32,
+        ValType::S32 => InterfaceType::S32,
+        ValType::U64 => InterfaceType::U64,
+        ValType::S64 => InterfaceType::S64,
+        ValType::Float32 => InterfaceType::Float32,
+        ValType::Float64 => InterfaceType::Float64,
+        ValType::Record(tys) => {
+            let ty = TypeRecord {
+                fields: tys
+                    .iter()
+                    .enumerate()
+                    .map(|(i, ty)| RecordField {
+                        name: format!("f{i}"),
+                        ty: intern(types, ty),
+                    })
+                    .collect(),
+            };
+            InterfaceType::Record(types.add_record_type(ty))
+        }
+        ValType::Tuple(tys) => {
+            let ty = TypeTuple {
+                types: tys.iter().map(|ty| intern(types, ty)).collect(),
+            };
+            InterfaceType::Tuple(types.add_tuple_type(ty))
+        }
+        ValType::Variant(NonZeroLenVec(cases)) => {
+            let ty = TypeVariant {
+                cases: cases
+                    .iter()
+                    .enumerate()
+                    .map(|(i, ty)| VariantCase {
+                        name: format!("c{i}"),
+                        ty: intern(types, ty),
+                    })
+                    .collect(),
+            };
+            InterfaceType::Variant(types.add_variant_type(ty))
+        }
+    }
+}
+
+impl From<GenStringEncoding> for StringEncoding {
+    fn from(gen: GenStringEncoding) -> StringEncoding {
+        match gen {
+            GenStringEncoding::Utf8 => StringEncoding::Utf8,
+            GenStringEncoding::Utf16 => StringEncoding::Utf16,
+            GenStringEncoding::CompactUtf16 => StringEncoding::CompactUtf16,
+        }
+    }
+}
diff --git a/crates/environ/src/component/translate/adapt.rs b/crates/environ/src/component/translate/adapt.rs
index e097e05fc2..b8afc0557a 100644
--- a/crates/environ/src/component/translate/adapt.rs
+++ b/crates/environ/src/component/translate/adapt.rs
@@ -159,6 +159,8 @@ pub struct AdapterOptions {
     pub string_encoding: StringEncoding,
     /// An optional memory definition supplied.
     pub memory: Option<CoreExport<MemoryIndex>>,
+    /// If `memory` is specified, whether it's a 64-bit memory.
+    pub memory64: bool,
     /// An optional definition of `realloc` to use.
     pub realloc: Option<CoreDef>,
     /// An optional definition of a `post-return` to use.
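The new `memory64` flag is what the rest of this diff threads through to code generation: every address calculation in a generated adapter must agree with the index type of the linear memory it touches. Later in this diff the choice is centralized in `Options::ptr` in `fact.rs`; the rule itself is just the following (a sketch mirroring that helper):

```rust
use wasm_encoder::ValType;

// Pointer-width rule that `memory64` feeds (mirrors `Options::ptr`,
// added to `fact.rs` later in this diff).
fn ptr_ty(memory64: bool) -> ValType {
    if memory64 {
        ValType::I64 // 64-bit memories are addressed with i64
    } else {
        ValType::I32 // 32-bit memories are addressed with i32
    }
}

fn main() {
    assert_eq!(ptr_ty(false), ValType::I32);
    assert_eq!(ptr_ty(true), ValType::I64);
}
```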
@@ -563,6 +565,7 @@ impl DefinedItems {
             let AdapterOptions {
                 instance: _,
                 string_encoding: _,
+                memory64: _,
                 memory,
                 realloc,
                 post_return,
diff --git a/crates/environ/src/component/translate/inline.rs b/crates/environ/src/component/translate/inline.rs
index 0cae7a624e..b7507cd362 100644
--- a/crates/environ/src/component/translate/inline.rs
+++ b/crates/environ/src/component/translate/inline.rs
@@ -47,7 +47,7 @@
 use crate::component::translate::adapt::{Adapter, AdapterOptions, Adapters};
 use crate::component::translate::*;
-use crate::{PrimaryMap, SignatureIndex};
+use crate::{EntityType, PrimaryMap, SignatureIndex};
 use indexmap::IndexMap;
 
 pub(super) fn run(
@@ -67,6 +67,7 @@ pub(super) fn run(
         runtime_post_return_interner: Default::default(),
         runtime_memory_interner: Default::default(),
         runtime_always_trap_interner: Default::default(),
+        runtime_instances: PrimaryMap::default(),
     };
 
     // The initial arguments to the root component are all host imports. This
@@ -145,6 +146,9 @@ struct Inliner<'a> {
     runtime_post_return_interner: HashMap<CoreDef, RuntimePostReturnIndex>,
     runtime_memory_interner: HashMap<CoreExport<MemoryIndex>, RuntimeMemoryIndex>,
     runtime_always_trap_interner: HashMap<SignatureIndex, RuntimeAlwaysTrapIndex>,
+
+    /// Origin information about where each runtime instance came from
+    runtime_instances: PrimaryMap<RuntimeInstanceIndex, InstanceModule>,
 }
 
 /// A "stack frame" as part of the inlining process, or the progress through
@@ -540,6 +544,7 @@ impl<'a> Inliner<'a> {
             // and an initializer is recorded to indicate that it's being
             // instantiated.
             ModuleInstantiate(module, args) => {
+                let instance_module;
                 let init = match &frame.modules[*module] {
                     ModuleDef::Static(idx) => {
                         let mut defs = Vec::new();
@@ -549,6 +554,7 @@
                                 self.core_def_of_module_instance_export(frame, instance, name),
                             );
                         }
+                        instance_module = InstanceModule::Static(*idx);
                         InstantiateModule::Static(*idx, defs.into())
                     }
                     ModuleDef::Import(path, ty) => {
@@ -562,12 +568,15 @@
                                 .insert(name.to_string(), def);
                         }
                         let index = self.runtime_import(path);
+                        instance_module = InstanceModule::Import(*ty);
                         InstantiateModule::Import(index, defs)
                     }
                 };
 
                 let idx = RuntimeInstanceIndex::from_u32(self.result.num_runtime_instances);
                 self.result.num_runtime_instances += 1;
+                let idx2 = self.runtime_instances.push(instance_module);
+                assert_eq!(idx, idx2);
                 self.result
                     .initializers
                     .push(GlobalInitializer::InstantiateModule(init));
@@ -822,12 +831,32 @@
                     _ => unreachable!(),
                 })
             });
+        let memory64 = match &memory {
+            Some(memory) => match &self.runtime_instances[memory.instance] {
+                InstanceModule::Static(idx) => match &memory.item {
+                    ExportItem::Index(i) => {
+                        let plan = &self.nested_modules[*idx].module.memory_plans[*i];
+                        plan.memory.memory64
+                    }
+                    ExportItem::Name(_) => unreachable!(),
+                },
+                InstanceModule::Import(ty) => match &memory.item {
+                    ExportItem::Name(name) => match self.types[*ty].exports[name] {
+                        EntityType::Memory(m) => m.memory64,
+                        _ => unreachable!(),
+                    },
+                    ExportItem::Index(_) => unreachable!(),
+                },
+            },
+            None => false,
+        };
         let realloc = options.realloc.map(|i| frame.funcs[i].clone());
         let post_return = options.post_return.map(|i| frame.funcs[i].clone());
         AdapterOptions {
             instance: frame.instance,
             string_encoding: options.string_encoding,
             memory,
+            memory64,
             realloc,
             post_return,
         }
@@ -1064,3 +1093,8 @@ impl<'a> ComponentItemDef<'a> {
         Ok(item)
     }
 }
+
+enum InstanceModule {
+    Static(StaticModuleIndex),
+    Import(TypeModuleIndex),
+}
diff --git a/crates/environ/src/component/types.rs b/crates/environ/src/component/types.rs
index 040c24318f..4fec635477 100644
---
a/crates/environ/src/component/types.rs +++ b/crates/environ/src/component/types.rs @@ -574,7 +574,7 @@ impl ComponentTypesBuilder { .collect(), result: self.valtype(&ty.result), }; - intern(&mut self.functions, &mut self.component_types.functions, ty) + self.add_func_type(ty) } fn defined_type(&mut self, ty: &wasmparser::ComponentDefinedType<'_>) -> InterfaceType { @@ -636,7 +636,7 @@ impl ComponentTypesBuilder { }) .collect(), }; - intern(&mut self.records, &mut self.component_types.records, record) + self.add_record_type(record) } fn variant_type(&mut self, cases: &[wasmparser::VariantCase<'_>]) -> TypeVariantIndex { @@ -654,18 +654,14 @@ impl ComponentTypesBuilder { }) .collect(), }; - intern( - &mut self.variants, - &mut self.component_types.variants, - variant, - ) + self.add_variant_type(variant) } fn tuple_type(&mut self, types: &[wasmparser::ComponentValType]) -> TypeTupleIndex { let tuple = TypeTuple { types: types.iter().map(|ty| self.valtype(ty)).collect(), }; - intern(&mut self.tuples, &mut self.component_types.tuples, tuple) + self.add_tuple_type(tuple) } fn flags_type(&mut self, flags: &[&str]) -> TypeFlagsIndex { @@ -704,6 +700,26 @@ impl ComponentTypesBuilder { expected, ) } + + /// Interns a new function type within this type information. + pub fn add_func_type(&mut self, ty: TypeFunc) -> TypeFuncIndex { + intern(&mut self.functions, &mut self.component_types.functions, ty) + } + + /// Interns a new record type within this type information. + pub fn add_record_type(&mut self, ty: TypeRecord) -> TypeRecordIndex { + intern(&mut self.records, &mut self.component_types.records, ty) + } + + /// Interns a new tuple type within this type information. + pub fn add_tuple_type(&mut self, ty: TypeTuple) -> TypeTupleIndex { + intern(&mut self.tuples, &mut self.component_types.tuples, ty) + } + + /// Interns a new variant type within this type information. + pub fn add_variant_type(&mut self, ty: TypeVariant) -> TypeVariantIndex { + intern(&mut self.variants, &mut self.component_types.variants, ty) + } } // Forward the indexing impl to the internal `TypeTables` diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs index bb21075021..b7c0dd5d87 100644 --- a/crates/environ/src/fact.rs +++ b/crates/environ/src/fact.rs @@ -23,6 +23,7 @@ use crate::component::{ }; use crate::{FuncIndex, GlobalIndex, MemoryIndex}; use std::collections::HashMap; +use std::mem; use wasm_encoder::*; mod core_types; @@ -90,6 +91,7 @@ enum Context { } impl<'a> Module<'a> { + /// Creates an empty module. pub fn new(types: &'a ComponentTypes, debug: bool) -> Module<'a> { Module { debug, @@ -110,20 +112,24 @@ impl<'a> Module<'a> { /// The `name` provided is the export name of the adapter from the final /// module, and `adapter` contains all metadata necessary for compilation. pub fn adapt(&mut self, name: &str, adapter: &Adapter) { - // Import core wasm function which was lifted using its appropriate + // Import any items required by the various canonical options + // (memories, reallocs, etc) + let mut lift = self.import_options(adapter.lift_ty, &adapter.lift_options); + let lower = self.import_options(adapter.lower_ty, &adapter.lower_options); + + // Lowering options are not allowed to specify post-return as per the + // current canonical abi specification. + assert!(adapter.lower_options.post_return.is_none()); + + // Import the core wasm function which was lifted using its appropriate // signature since the exported function this adapter generates will // call the lifted function. 
-        let signature = self.signature(adapter.lift_ty, Context::Lift);
+        let signature = self.signature(&lift, Context::Lift);
         let ty = self
             .core_types
             .function(&signature.params, &signature.results);
         let callee = self.import_func("callee", name, ty, adapter.func.clone());
 
-        // Next import any items required by the various canonical options
-        // (memories, reallocs, etc)
-        let mut lift = self.import_options(adapter.lift_ty, &adapter.lift_options);
-        let lower = self.import_options(adapter.lower_ty, &adapter.lower_options);
-
         // Handle post-return specifically here where we have `core_ty` and the
         // results of `core_ty` are the parameters to the post-return function.
         lift.post_return = adapter.lift_options.post_return.as_ref().map(|func| {
@@ -131,10 +137,6 @@
             self.import_func("post_return", name, ty, func.clone())
         });
 
-        // Lowering options are not allowed to specify post-return as per the
-        // current canonical abi specification.
-        assert!(adapter.lower_options.post_return.is_none());
-
         self.adapters.push(AdapterData {
             name: name.to_string(),
             lift,
@@ -151,10 +153,10 @@
             instance,
             string_encoding,
             memory,
+            memory64,
             realloc,
             post_return: _, // handled above
         } = options;
-        let memory64 = false; // FIXME(#4311) should be plumbed from somewhere
         let flags = self.import_global(
             "flags",
             &format!("instance{}", instance.as_u32()),
@@ -172,13 +174,17 @@
                     minimum: 0,
                     maximum: None,
                     shared: false,
-                    memory64,
+                    memory64: *memory64,
                 },
                 memory.clone().into(),
             )
         });
         let realloc = realloc.as_ref().map(|func| {
-            let ptr = if memory64 { ValType::I64 } else { ValType::I32 };
+            let ptr = if *memory64 {
+                ValType::I64
+            } else {
+                ValType::I32
+            };
             let ty = self.core_types.function(&[ptr, ptr, ptr, ptr], &[ptr]);
             self.import_func("realloc", "", ty, func.clone())
         });
@@ -186,7 +192,7 @@
             ty,
             string_encoding: *string_encoding,
             flags,
-            memory64,
+            memory64: *memory64,
             memory,
             realloc,
             post_return: None,
@@ -245,26 +251,27 @@
         ret
     }
 
+    /// Encodes this module into a WebAssembly binary.
     pub fn encode(&mut self) -> Vec<u8> {
         let mut funcs = FunctionSection::new();
         let mut code = CodeSection::new();
         let mut exports = ExportSection::new();
         let mut traps = traps::TrapSection::default();
+        let mut types = mem::take(&mut self.core_types);
 
         for adapter in self.adapters.iter() {
             let idx = self.core_funcs + funcs.len();
             exports.export(&adapter.name, ExportKind::Func, idx);
-            let signature = self.signature(adapter.lower.ty, Context::Lower);
-            let ty = self
-                .core_types
-                .function(&signature.params, &signature.results);
+            let signature = self.signature(&adapter.lower, Context::Lower);
+            let ty = types.function(&signature.params, &signature.results);
             funcs.function(ty);
-            let (function, func_traps) = trampoline::compile(self, adapter);
+            let (function, func_traps) = trampoline::compile(self, &mut types, adapter);
             code.raw(&function);
             traps.append(idx, func_traps);
         }
+        self.core_types = types;
         let traps = traps.finish();
 
         let mut result = wasm_encoder::Module::new();
@@ -288,3 +295,13 @@
         &self.imports
     }
 }
+
+impl Options {
+    fn ptr(&self) -> ValType {
+        if self.memory64 {
+            ValType::I64
+        } else {
+            ValType::I32
+        }
+    }
+}
diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs
index e289103da1..7194a36757 100644
--- a/crates/environ/src/fact/signature.rs
+++ b/crates/environ/src/fact/signature.rs
@@ -1,11 +1,12 @@
 //! Size, align, and flattening information about component model types.
 
-use crate::component::{InterfaceType, TypeFuncIndex, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS};
-use crate::fact::{Context, Module};
+use crate::component::{InterfaceType, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS};
+use crate::fact::{Context, Module, Options};
 use wasm_encoder::ValType;
 
 /// Metadata about a core wasm signature which is created for a component model
 /// signature.
+#[derive(Debug)]
 pub struct Signature {
     /// Core wasm parameters.
     pub params: Vec<ValType>,
@@ -33,13 +34,14 @@ impl Module<'_> {
     /// This is used to generate the core wasm signatures for functions that are
     /// imported (matching whatever was `canon lift`'d) and functions that are
     /// exported (matching the generated function from `canon lower`).
-    pub(super) fn signature(&self, ty: TypeFuncIndex, context: Context) -> Signature {
-        let ty = &self.types[ty];
+    pub(super) fn signature(&self, options: &Options, context: Context) -> Signature {
+        let ty = &self.types[options.ty];
+        let ptr_ty = options.ptr();
 
         let mut params = self.flatten_types(ty.params.iter().map(|(_, ty)| *ty));
         let mut params_indirect = false;
         if params.len() > MAX_FLAT_PARAMS {
-            params = vec![ValType::I32];
+            params = vec![ptr_ty];
             params_indirect = true;
         }
 
@@ -51,13 +53,13 @@
             // For a lifted function too-many-results gets translated to a
             // returned pointer where results are read from. The callee
             // allocates space here.
-            Context::Lift => results = vec![ValType::I32],
+            Context::Lift => results = vec![ptr_ty],
             // For a lowered function too-many-results becomes a return
             // pointer which is passed as the last argument. The caller
             // allocates space here.
             Context::Lower => {
                 results.truncate(0);
-                params.push(ValType::I32);
+                params.push(ptr_ty);
             }
         }
     }
diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs
index 40694493d4..12a751b134 100644
--- a/crates/environ/src/fact/trampoline.rs
+++ b/crates/environ/src/fact/trampoline.rs
@@ -16,9 +16,10 @@
 //! can be somewhat arbitrary, an intentional decision.
 
 use crate::component::{
-    InterfaceType, TypeRecordIndex, TypeTupleIndex, FLAG_MAY_ENTER, FLAG_MAY_LEAVE,
-    MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
+    InterfaceType, TypeRecordIndex, TypeTupleIndex, TypeVariantIndex, FLAG_MAY_ENTER,
+    FLAG_MAY_LEAVE, MAX_FLAT_PARAMS, MAX_FLAT_RESULTS,
 };
+use crate::fact::core_types::CoreTypes;
 use crate::fact::signature::{align_to, Signature};
 use crate::fact::traps::Trap;
 use crate::fact::{AdapterData, Context, Module, Options};
@@ -27,11 +28,15 @@
 use std::collections::HashMap;
 use std::mem;
 use std::ops::Range;
 use wasm_encoder::{BlockType, Encode, Instruction, Instruction::*, MemArg, ValType};
+use wasmtime_component_util::DiscriminantSize;
 
-struct Compiler<'a> {
+struct Compiler<'a, 'b> {
     /// The module that the adapter will eventually be inserted into.
     module: &'a Module<'a>,
 
+    /// The type section of `module`
+    types: &'b mut CoreTypes,
+
     /// Metadata about the adapter that is being compiled.
     adapter: &'a AdapterData,
 
@@ -62,11 +67,16 @@ struct Compiler<'a> {
     lift_sig: &'a Signature,
 }
 
-pub(super) fn compile(module: &Module<'_>, adapter: &AdapterData) -> (Vec<u8>, Vec<(usize, Trap)>) {
-    let lower_sig = &module.signature(adapter.lower.ty, Context::Lower);
-    let lift_sig = &module.signature(adapter.lift.ty, Context::Lift);
+pub(super) fn compile(
+    module: &Module<'_>,
+    types: &mut CoreTypes,
+    adapter: &AdapterData,
+) -> (Vec<u8>, Vec<(usize, Trap)>) {
+    let lower_sig = &module.signature(&adapter.lower, Context::Lower);
+    let lift_sig = &module.signature(&adapter.lift, Context::Lift);
     Compiler {
         module,
+        types,
         adapter,
         code: Vec::new(),
         locals: Vec::new(),
@@ -94,10 +104,13 @@ enum Source<'a> {
 }
 
 /// Same as `Source` but for where values are translated into.
-enum Destination {
+enum Destination<'a> {
     /// This value is destined for the WebAssembly stack which means that
     /// results are simply pushed as we go along.
-    Stack,
+    ///
+    /// The types listed are the types that are expected to be on the stack at
+    /// the end of translation.
+    Stack(&'a [ValType]),
 
     /// This value is to be placed in linear memory described by `Memory`.
     Memory(Memory),
@@ -114,6 +127,8 @@ struct Stack<'a> {
 
 /// Representation of where a value is going to be stored in linear memory.
 struct Memory {
+    /// Whether or not the `addr_local` is a 64-bit type.
+    memory64: bool,
     /// The index of the local that contains the base address of where the
     /// storage is happening.
     addr_local: u32,
@@ -125,7 +140,7 @@
     memory_idx: u32,
 }
 
-impl Compiler<'_> {
+impl Compiler<'_, '_> {
     fn compile(&mut self) -> (Vec<u8>, Vec<(usize, Trap)>) {
         // Check the instance flags required for this trampoline.
         //
@@ -237,7 +252,7 @@
         };
 
         let dst = if dst_flat.len() <= MAX_FLAT_PARAMS {
-            Destination::Stack
+            Destination::Stack(&dst_flat)
         } else {
             // If there are too many parameters then space is allocated in the
             // destination module for the parameters via its `realloc` function.
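As background for the `MAX_FLAT_PARAMS` checks here and in `signature.rs` above: the canonical ABI flattens a component-level signature into core wasm parameters only up to a fixed limit (16 in `wasmtime-environ` at the time of writing); beyond that, everything is passed indirectly through a single pointer into linear memory. A sketch of the rule, with the pointer type now depending on `memory64`:

```rust
use wasm_encoder::ValType;

// Sketch of the canonical-ABI spill rule used by `signature()` and the
// branch above; the constant is inlined here for illustration.
fn flatten_params(flat: Vec<ValType>, ptr: ValType) -> (Vec<ValType>, bool) {
    const MAX_FLAT_PARAMS: usize = 16;
    if flat.len() > MAX_FLAT_PARAMS {
        // Too many parameters: pass one pointer to a linear-memory area
        // instead (space allocated via the destination's `realloc` when
        // lowering into another module).
        (vec![ptr], true)
    } else {
        (flat, false)
    }
}

fn main() {
    let many = vec![ValType::I32; 20];
    let (params, indirect) = flatten_params(many, ValType::I32);
    assert_eq!(params, vec![ValType::I32]);
    assert!(indirect);
}
```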
@@ -246,10 +261,10 @@ impl Compiler<'_> { }; let srcs = src - .record_field_sources(self.module, src_tys.iter().copied()) + .record_field_srcs(self.module, src_tys.iter().copied()) .zip(src_tys.iter()); let dsts = dst - .record_field_sources(self.module, dst_tys.iter().copied()) + .record_field_dsts(self.module, dst_tys.iter().copied()) .zip(dst_tys.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(&src_ty, &src, &dst_ty, &dst); @@ -291,7 +306,7 @@ impl Compiler<'_> { }; let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { - Destination::Stack + Destination::Stack(&dst_flat) } else { // This is slightly different than `translate_params` where the // return pointer was provided by the caller of this function @@ -322,9 +337,18 @@ impl Compiler<'_> { InterfaceType::Unit => self.translate_unit(src, dst_ty, dst), InterfaceType::Bool => self.translate_bool(src, dst_ty, dst), InterfaceType::U8 => self.translate_u8(src, dst_ty, dst), + InterfaceType::S8 => self.translate_s8(src, dst_ty, dst), + InterfaceType::U16 => self.translate_u16(src, dst_ty, dst), + InterfaceType::S16 => self.translate_s16(src, dst_ty, dst), InterfaceType::U32 => self.translate_u32(src, dst_ty, dst), + InterfaceType::S32 => self.translate_s32(src, dst_ty, dst), + InterfaceType::U64 => self.translate_u64(src, dst_ty, dst), + InterfaceType::S64 => self.translate_s64(src, dst_ty, dst), + InterfaceType::Float32 => self.translate_f32(src, dst_ty, dst), + InterfaceType::Float64 => self.translate_f64(src, dst_ty, dst), InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst), InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst), + InterfaceType::Variant(v) => self.translate_variant(*v, src, dst_ty, dst), InterfaceType::String => { // consider this field used for now until this is fully @@ -362,7 +386,7 @@ impl Compiler<'_> { match dst { Destination::Memory(mem) => self.i32_store8(mem), - Destination::Stack => {} + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), } } @@ -372,11 +396,67 @@ impl Compiler<'_> { self.push_dst_addr(dst); match src { Source::Memory(mem) => self.i32_load8u(mem), - Source::Stack(stack) => self.stack_get(stack, ValType::I32), + Source::Stack(stack) => { + self.stack_get(stack, ValType::I32); + self.instruction(I32Const(0xff)); + self.instruction(I32And); + } } match dst { Destination::Memory(mem) => self.i32_store8(mem), - Destination::Stack => {} + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + } + } + + fn translate_s8(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::S8)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i32_load8s(mem), + Source::Stack(stack) => { + self.stack_get(stack, ValType::I32); + self.instruction(I32Extend8S); + } + } + match dst { + Destination::Memory(mem) => self.i32_store8(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + } + } + + fn translate_u16(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::U16)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i32_load16u(mem), + Source::Stack(stack) => { + self.stack_get(stack, ValType::I32); + self.instruction(I32Const(0xffff)); + self.instruction(I32And); + } + } + match dst { + Destination::Memory(mem) => self.i32_store16(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + 
} + } + + fn translate_s16(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::S16)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i32_load16s(mem), + Source::Stack(stack) => { + self.stack_get(stack, ValType::I32); + self.instruction(I32Extend16S); + } + } + match dst { + Destination::Memory(mem) => self.i32_store16(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), } } @@ -390,7 +470,77 @@ impl Compiler<'_> { } match dst { Destination::Memory(mem) => self.i32_store(mem), - Destination::Stack => {} + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + } + } + + fn translate_s32(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::S32)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i32_load(mem), + Source::Stack(stack) => self.stack_get(stack, ValType::I32), + } + match dst { + Destination::Memory(mem) => self.i32_store(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + } + } + + fn translate_u64(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::U64)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i64_load(mem), + Source::Stack(stack) => self.stack_get(stack, ValType::I64), + } + match dst { + Destination::Memory(mem) => self.i64_store(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I64), + } + } + + fn translate_s64(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::S64)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.i64_load(mem), + Source::Stack(stack) => self.stack_get(stack, ValType::I64), + } + match dst { + Destination::Memory(mem) => self.i64_store(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::I64), + } + } + + fn translate_f32(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::Float32)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.f32_load(mem), + Source::Stack(stack) => self.stack_get(stack, ValType::F32), + } + match dst { + Destination::Memory(mem) => self.f32_store(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::F32), + } + } + + fn translate_f64(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { + // TODO: subtyping + assert!(matches!(dst_ty, InterfaceType::Float64)); + self.push_dst_addr(dst); + match src { + Source::Memory(mem) => self.f64_load(mem), + Source::Stack(stack) => self.stack_get(stack, ValType::F64), + } + match dst { + Destination::Memory(mem) => self.f64_store(mem), + Destination::Stack(stack) => self.stack_set(stack, ValType::F64), } } @@ -415,7 +565,7 @@ impl Compiler<'_> { // fields' names let mut src_fields = HashMap::new(); for (i, src) in src - .record_field_sources(self.module, src_ty.fields.iter().map(|f| f.ty)) + .record_field_srcs(self.module, src_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &src_ty.fields[i]; @@ -431,7 +581,7 @@ impl Compiler<'_> { // // TODO: should that lookup be fallible with subtyping? 
for (i, dst) in dst - .record_field_sources(self.module, dst_ty.fields.iter().map(|f| f.ty)) + .record_field_dsts(self.module, dst_ty.fields.iter().map(|f| f.ty)) .enumerate() { let field = &dst_ty.fields[i]; @@ -457,16 +607,145 @@ impl Compiler<'_> { assert_eq!(src_ty.types.len(), dst_ty.types.len()); let srcs = src - .record_field_sources(self.module, src_ty.types.iter().copied()) + .record_field_srcs(self.module, src_ty.types.iter().copied()) .zip(src_ty.types.iter()); let dsts = dst - .record_field_sources(self.module, dst_ty.types.iter().copied()) + .record_field_dsts(self.module, dst_ty.types.iter().copied()) .zip(dst_ty.types.iter()); for ((src, src_ty), (dst, dst_ty)) in srcs.zip(dsts) { self.translate(src_ty, &src, dst_ty, &dst); } } + fn translate_variant( + &mut self, + src_ty: TypeVariantIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_ty = &self.module.types[src_ty]; + let dst_ty = match dst_ty { + InterfaceType::Variant(t) => &self.module.types[*t], + _ => panic!("expected a variant"), + }; + + let src_disc_size = DiscriminantSize::from_count(src_ty.cases.len()).unwrap(); + let dst_disc_size = DiscriminantSize::from_count(dst_ty.cases.len()).unwrap(); + + // The outermost block is special since it has the result type of the + // translation here. That will depend on the `dst`. + let outer_block_ty = match dst { + Destination::Stack(dst_flat) => match dst_flat.len() { + 0 => BlockType::Empty, + 1 => BlockType::Result(dst_flat[0]), + _ => { + let ty = self.types.function(&[], &dst_flat); + BlockType::FunctionType(ty) + } + }, + Destination::Memory(_) => BlockType::Empty, + }; + self.instruction(Block(outer_block_ty)); + + // After the outermost block generate a new block for each of the + // remaining cases. + for _ in 0..src_ty.cases.len() - 1 { + self.instruction(Block(BlockType::Empty)); + } + + // Generate a block for an invalid variant discriminant + self.instruction(Block(BlockType::Empty)); + + // And generate one final block that we'll be jumping out of with the + // `br_table` + self.instruction(Block(BlockType::Empty)); + + // Load the discriminant + match src { + Source::Stack(s) => self.stack_get(&s.slice(0..1), ValType::I32), + Source::Memory(mem) => match src_disc_size { + DiscriminantSize::Size1 => self.i32_load8u(mem), + DiscriminantSize::Size2 => self.i32_load16u(mem), + DiscriminantSize::Size4 => self.i32_load(mem), + }, + } + + // Generate the `br_table` for the discriminant. Each case has an + // offset of 1 to skip the trapping block. + let mut targets = Vec::new(); + for i in 0..src_ty.cases.len() { + targets.push((i + 1) as u32); + } + self.instruction(BrTable(targets[..].into(), 0)); + self.instruction(End); // end the `br_table` block + + self.trap(Trap::InvalidDiscriminant); + self.instruction(End); // end the "invalid discriminant" block + + // Translate each case individually within its own block. Note that the + // iteration order here places the first case in the innermost block + // and the last case in the outermost block. This matches the order + // of the jump targets in the `br_table` instruction. + for (src_i, src_case) in src_ty.cases.iter().enumerate() { + let dst_i = dst_ty + .cases + .iter() + .position(|c| c.name == src_case.name) + .unwrap(); + let dst_case = &dst_ty.cases[dst_i]; + let dst_i = u32::try_from(dst_i).unwrap() as i32; + + // Translate the discriminant here, noting that `dst_i` may be + // different than `src_i`. 
+ self.push_dst_addr(dst); + self.instruction(I32Const(dst_i)); + match dst { + Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32), + Destination::Memory(mem) => match dst_disc_size { + DiscriminantSize::Size1 => self.i32_store8(mem), + DiscriminantSize::Size2 => self.i32_store16(mem), + DiscriminantSize::Size4 => self.i32_store(mem), + }, + } + + // Translate the payload of this case using the various types from + // the dst/src. + let src_payload = src.payload_src(self.module, src_disc_size, &src_case.ty); + let dst_payload = dst.payload_dst(self.module, dst_disc_size, &dst_case.ty); + self.translate(&src_case.ty, &src_payload, &dst_case.ty, &dst_payload); + + // If the results of this translation were placed on the stack then + // the stack values may need to be padded with more zeros due to + // this particular case being possibly smaller than the entire + // variant. That's handled here by pushing remaining zeros after + // accounting for the discriminant pushed as well as the results of + // this individual payload. + if let Destination::Stack(payload_results) = dst_payload { + if let Destination::Stack(dst_results) = dst { + let remaining = &dst_results[1..][payload_results.len()..]; + for ty in remaining { + match ty { + ValType::I32 => self.instruction(I32Const(0)), + ValType::I64 => self.instruction(I64Const(0)), + ValType::F32 => self.instruction(F32Const(0.0)), + ValType::F64 => self.instruction(F64Const(0.0)), + _ => unreachable!(), + } + } + } + } + + // Branch to the outermost block. Note that this isn't needed for + // the outermost case since it simply falls through. + let src_len = src_ty.cases.len(); + if src_i != src_len - 1 { + self.instruction(Br((src_len - src_i - 1) as u32)); + } + self.instruction(End); // end this case's block + } + } + fn trap_if_not_flag(&mut self, flags_global: GlobalIndex, flag_to_test: i32, trap: Trap) { self.instruction(GlobalGet(flags_global.as_u32())); self.instruction(I32Const(flag_to_test)); @@ -498,17 +777,25 @@ impl Compiler<'_> { self.instruction(GlobalSet(flags_global.as_u32())); } - fn verify_aligned(&mut self, local: u32, align: usize) { + fn verify_aligned(&mut self, memory: &Memory, align: usize) { // If the alignment is 1 then everything is trivially aligned and the // check can be omitted. 
         if align == 1 {
             return;
         }
-        self.instruction(LocalGet(local));
+        self.instruction(LocalGet(memory.addr_local));
         assert!(align.is_power_of_two());
-        let mask = i32::try_from(align - 1).unwrap();
-        self.instruction(I32Const(mask));
-        self.instruction(I32And);
+        if memory.memory64 {
+            let mask = i64::try_from(align - 1).unwrap();
+            self.instruction(I64Const(mask));
+            self.instruction(I64And);
+            self.instruction(I64Const(0));
+            self.instruction(I64Ne);
+        } else {
+            let mask = i32::try_from(align - 1).unwrap();
+            self.instruction(I32Const(mask));
+            self.instruction(I32And);
+        }
         self.instruction(If(BlockType::Empty));
         self.trap(Trap::UnalignedPointer);
         self.instruction(End);
@@ -524,11 +811,21 @@
         }
         assert!(align.is_power_of_two());
         self.instruction(LocalGet(mem.addr_local));
-        self.instruction(I32Const(mem.i32_offset()));
-        self.instruction(I32Add);
-        let mask = i32::try_from(align - 1).unwrap();
-        self.instruction(I32Const(mask));
-        self.instruction(I32And);
+        if mem.memory64 {
+            self.instruction(I64Const(i64::from(mem.offset)));
+            self.instruction(I64Add);
+            let mask = i64::try_from(align - 1).unwrap();
+            self.instruction(I64Const(mask));
+            self.instruction(I64And);
+            self.instruction(I64Const(0));
+            self.instruction(I64Ne);
+        } else {
+            self.instruction(I32Const(mem.i32_offset()));
+            self.instruction(I32Add);
+            let mask = i32::try_from(align - 1).unwrap();
+            self.instruction(I32Const(mask));
+            self.instruction(I32And);
+        }
         self.instruction(If(BlockType::Empty));
         self.trap(Trap::AssertFailed("pointer not aligned"));
         self.instruction(End);
@@ -555,12 +852,14 @@
     fn memory_operand(&mut self, opts: &Options, addr_local: u32, align: usize) -> Memory {
         let memory = opts.memory.unwrap();
-        self.verify_aligned(addr_local, align);
-        Memory {
+        let ret = Memory {
+            memory64: opts.memory64,
             addr_local,
             offset: 0,
             memory_idx: memory.as_u32(),
-        }
+        };
+        self.verify_aligned(&ret, align);
+        ret
     }
 
     fn gen_local(&mut self, ty: ValType) -> u32 {
@@ -606,6 +905,13 @@
         (bytes, mem::take(&mut self.traps))
     }
 
+    /// Fetches the value contained within the local specified by `stack` and
+    /// converts it to `dst_ty`.
+    ///
+    /// This is only intended for use in primitive operations where `stack` is
+    /// guaranteed to have only one local. The type of the local on the stack is
+    /// then converted to `dst_ty` appropriately. Note that the types may be
+    /// different due to the "flattening" of variant types.
     fn stack_get(&mut self, stack: &Stack<'_>, dst_ty: ValType) {
         assert_eq!(stack.locals.len(), 1);
         let (idx, src_ty) = stack.locals[0];
@@ -646,16 +952,90 @@
         }
     }
 
+    /// Converts the top value on the WebAssembly stack which has type
+    /// `src_ty` to `dst_tys[0]`.
+    ///
+    /// This is only intended for conversion of primitives where the `dst_tys`
+    /// list is known to be of length 1.
+ fn stack_set(&mut self, dst_tys: &[ValType], src_ty: ValType) { + assert_eq!(dst_tys.len(), 1); + let dst_ty = dst_tys[0]; + match (src_ty, dst_ty) { + (ValType::I32, ValType::I32) + | (ValType::I64, ValType::I64) + | (ValType::F32, ValType::F32) + | (ValType::F64, ValType::F64) => {} + + (ValType::F32, ValType::I32) => self.instruction(I32ReinterpretF32), + (ValType::I32, ValType::I64) => self.instruction(I64ExtendI32U), + (ValType::F64, ValType::I64) => self.instruction(I64ReinterpretF64), + (ValType::F32, ValType::F64) => self.instruction(F64PromoteF32), + (ValType::F32, ValType::I64) => { + self.instruction(F64PromoteF32); + self.instruction(I64ReinterpretF64); + } + + // should not be possible given the `join` function for variants + (ValType::I64, ValType::I32) + | (ValType::F64, ValType::I32) + | (ValType::I32, ValType::F32) + | (ValType::I64, ValType::F32) + | (ValType::F64, ValType::F32) + | (ValType::I32, ValType::F64) + | (ValType::I64, ValType::F64) + + // not used in the component model + | (ValType::ExternRef, _) + | (_, ValType::ExternRef) + | (ValType::FuncRef, _) + | (_, ValType::FuncRef) + | (ValType::V128, _) + | (_, ValType::V128) => { + panic!("cannot get {dst_ty:?} from {src_ty:?} local"); + } + } + } + fn i32_load8u(&mut self, mem: &Memory) { self.instruction(LocalGet(mem.addr_local)); self.instruction(I32Load8_U(mem.memarg(0))); } + fn i32_load8s(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(I32Load8_S(mem.memarg(0))); + } + + fn i32_load16u(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(I32Load16_U(mem.memarg(1))); + } + + fn i32_load16s(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(I32Load16_S(mem.memarg(1))); + } + fn i32_load(&mut self, mem: &Memory) { self.instruction(LocalGet(mem.addr_local)); self.instruction(I32Load(mem.memarg(2))); } + fn i64_load(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(I64Load(mem.memarg(3))); + } + + fn f32_load(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(F32Load(mem.memarg(2))); + } + + fn f64_load(&mut self, mem: &Memory) { + self.instruction(LocalGet(mem.addr_local)); + self.instruction(F64Load(mem.memarg(3))); + } + fn push_dst_addr(&mut self, dst: &Destination) { if let Destination::Memory(mem) = dst { self.instruction(LocalGet(mem.addr_local)); @@ -666,9 +1046,25 @@ impl Compiler<'_> { self.instruction(I32Store8(mem.memarg(0))); } + fn i32_store16(&mut self, mem: &Memory) { + self.instruction(I32Store16(mem.memarg(1))); + } + fn i32_store(&mut self, mem: &Memory) { self.instruction(I32Store(mem.memarg(2))); } + + fn i64_store(&mut self, mem: &Memory) { + self.instruction(I64Store(mem.memarg(3))); + } + + fn f32_store(&mut self, mem: &Memory) { + self.instruction(F32Store(mem.memarg(2))); + } + + fn f64_store(&mut self, mem: &Memory) { + self.instruction(F64Store(mem.memarg(3))); + } } impl<'a> Source<'a> { @@ -678,7 +1074,7 @@ impl<'a> Source<'a> { /// This will automatically slice stack-based locals to the appropriate /// width for each component type and additionally calculate the appropriate /// offset for each memory-based type. 
-    fn record_field_sources<'b>(
+    fn record_field_srcs<'b>(
         &'b self,
         module: &'b Module,
         fields: impl IntoIterator<Item = InterfaceType> + 'b,
@@ -689,9 +1085,8 @@
         let mut offset = 0;
         fields.into_iter().map(move |ty| match self {
             Source::Memory(mem) => {
-                let (size, align) = module.size_align(&ty);
-                offset = align_to(offset, align) + size;
-                Source::Memory(mem.bump(offset - size))
+                let mem = next_field_offset(&mut offset, module, &ty, mem);
+                Source::Memory(mem)
             }
             Source::Stack(stack) => {
                 let cnt = module.flatten_types([ty]).len();
@@ -700,26 +1095,90 @@
             }
         })
     }
+
+    /// Returns the source for the payload of the variant case `case`, given
+    /// the size of the variant's discriminant.
+    fn payload_src(
+        &self,
+        module: &Module,
+        size: DiscriminantSize,
+        case: &InterfaceType,
+    ) -> Source<'a> {
+        match self {
+            Source::Stack(s) => {
+                let flat_len = module.flatten_types([*case]).len();
+                Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len))
+            }
+            Source::Memory(mem) => {
+                let mem = payload_offset(size, module, case, mem);
+                Source::Memory(mem)
+            }
+        }
+    }
 }
 
-impl Destination {
-    /// Same as `Source::record_field_sources` but for destinations.
-    fn record_field_sources<'a>(
-        &'a self,
-        module: &'a Module,
-        fields: impl IntoIterator<Item = InterfaceType> + 'a,
-    ) -> impl Iterator<Item = Destination> + 'a {
+impl<'a> Destination<'a> {
+    /// Same as `Source::record_field_srcs` but for destinations.
+    fn record_field_dsts<'b>(
+        &'b self,
+        module: &'b Module,
+        fields: impl IntoIterator<Item = InterfaceType> + 'b,
+    ) -> impl Iterator<Item = Destination> + 'b
+    where
+        'a: 'b,
+    {
         let mut offset = 0;
         fields.into_iter().map(move |ty| match self {
-            // TODO: dedupe with above?
             Destination::Memory(mem) => {
-                let (size, align) = module.size_align(&ty);
-                offset = align_to(offset, align) + size;
-                Destination::Memory(mem.bump(offset - size))
+                let mem = next_field_offset(&mut offset, module, &ty, mem);
+                Destination::Memory(mem)
+            }
+            Destination::Stack(s) => {
+                let cnt = module.flatten_types([ty]).len();
+                offset += cnt;
+                Destination::Stack(&s[offset - cnt..offset])
             }
-            Destination::Stack => Destination::Stack,
         })
     }
+
+    /// Returns the destination for the payload of the variant case `case`,
+    /// given the size of the variant's discriminant.
+    fn payload_dst(
+        &self,
+        module: &Module,
+        size: DiscriminantSize,
+        case: &InterfaceType,
+    ) -> Destination {
+        match self {
+            Destination::Stack(s) => {
+                let flat_len = module.flatten_types([*case]).len();
+                Destination::Stack(&s[1..][..flat_len])
+            }
+            Destination::Memory(mem) => {
+                let mem = payload_offset(size, module, case, mem);
+                Destination::Memory(mem)
+            }
+        }
+    }
+}
+
+fn next_field_offset(
+    offset: &mut usize,
+    module: &Module,
+    field: &InterfaceType,
+    mem: &Memory,
+) -> Memory {
+    let (size, align) = module.size_align(field);
+    *offset = align_to(*offset, align) + size;
+    mem.bump(*offset - size)
+}
+
+fn payload_offset(
+    disc_size: DiscriminantSize,
+    module: &Module,
+    case: &InterfaceType,
+    mem: &Memory,
+) -> Memory {
+    let align = module.align(case);
+    mem.bump(align_to(disc_size.into(), align))
 }
 
 impl Memory {
@@ -737,6 +1196,7 @@
     fn bump(&self, offset: usize) -> Memory {
         Memory {
+            memory64: self.memory64,
             addr_local: self.addr_local,
             memory_idx: self.memory_idx,
             offset: self.offset + u32::try_from(offset).unwrap(),
@@ -751,13 +1211,3 @@ impl<'a> Stack<'a> {
         }
     }
 }
-
-impl Options {
-    fn ptr(&self) -> ValType {
-        if self.memory64 {
-            ValType::I64
-        } else {
-            ValType::I32
-        }
-    }
-}
diff --git a/crates/environ/src/fact/traps.rs b/crates/environ/src/fact/traps.rs
index 6b5410ef1b..4883fb6ffa 100644
--- a/crates/environ/src/fact/traps.rs
+++
b/crates/environ/src/fact/traps.rs @@ -27,6 +27,7 @@ pub enum Trap { CannotLeave, CannotEnter, UnalignedPointer, + InvalidDiscriminant, AssertFailed(&'static str), } @@ -99,6 +100,7 @@ impl fmt::Display for Trap { Trap::CannotLeave => "cannot leave instance".fmt(f), Trap::CannotEnter => "cannot enter instance".fmt(f), Trap::UnalignedPointer => "pointer not aligned correctly".fmt(f), + Trap::InvalidDiscriminant => "invalid variant discriminant".fmt(f), Trap::AssertFailed(s) => write!(f, "assertion failure: {}", s), } } diff --git a/crates/environ/src/lib.rs b/crates/environ/src/lib.rs index f74794a291..9d4c07d999 100644 --- a/crates/environ/src/lib.rs +++ b/crates/environ/src/lib.rs @@ -54,7 +54,7 @@ pub use object; #[cfg(feature = "component-model")] pub mod component; #[cfg(feature = "component-model")] -mod fact; +pub mod fact; // Reexport all of these type-level since they're quite commonly used and it's // much easier to refer to everything through one crate rather than importing diff --git a/tests/misc_testsuite/component-model/fused.wast b/tests/misc_testsuite/component-model/fused.wast index c2af88ed83..6faf0b8496 100644 --- a/tests/misc_testsuite/component-model/fused.wast +++ b/tests/misc_testsuite/component-model/fused.wast @@ -680,3 +680,396 @@ (instance $c2 (instantiate $c2 (with "r" (func $c1 "r")))) ) "unreachable") + +;; simple variant translation +(component + (type $a (variant (case "x" unit))) + (type $b (variant (case "y" unit))) + + (component $c1 + (core module $m + (func (export "r") (param i32) (result i32) + (if (i32.ne (local.get 0) (i32.const 0)) (unreachable)) + i32.const 0 + ) + ) + (core instance $m (instantiate $m)) + (func (export "r") (param $a) (result $b) (canon lift (core func $m "r"))) + ) + (component $c2 + (import "r" (func $r (param $a) (result $b))) + (core func $r (canon lower (func $r))) + + (core module $m + (import "" "r" (func $r (param i32) (result i32))) + (func $start + i32.const 0 + call $r + i32.const 0 + i32.ne + if unreachable end + ) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance (export "r" (func $r)))) + )) + ) + (instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "r" (func $c1 "r")))) +) + +;; invalid variant discriminant in a parameter +(assert_trap + (component + (type $a (variant (case "x" unit))) + + (component $c1 + (core module $m + (func (export "r") (param i32)) + ) + (core instance $m (instantiate $m)) + (func (export "r") (param $a) (canon lift (core func $m "r"))) + ) + (component $c2 + (import "r" (func $r (param $a))) + (core func $r (canon lower (func $r))) + + (core module $m + (import "" "r" (func $r (param i32))) + (func $start + i32.const 1 + call $r + ) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance (export "r" (func $r)))) + )) + ) + (instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "r" (func $c1 "r")))) + ) + "unreachable") + +;; invalid variant discriminant in a result +(assert_trap + (component + (type $a (variant (case "x" unit))) + + (component $c1 + (core module $m + (func (export "r") (result i32) i32.const 1) + ) + (core instance $m (instantiate $m)) + (func (export "r") (result $a) (canon lift (core func $m "r"))) + ) + (component $c2 + (import "r" (func $r (result $a))) + (core func $r (canon lower (func $r))) + + (core module $m + (import "" "r" (func $r (result i32))) + (func $start call $r drop) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance (export "r" (func $r)))) + )) + ) + 
(instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "r" (func $c1 "r")))) + ) + "unreachable") + + +;; extra bits are chopped off +(component + (component $c1 + (core module $m + (func (export "u") (param i32) + (if (i32.ne (local.get 0) (i32.const 0)) (unreachable)) + ) + (func (export "s") (param i32) + (if (i32.ne (local.get 0) (i32.const -1)) (unreachable)) + ) + ) + (core instance $m (instantiate $m)) + (func (export "u8") (param u8) (canon lift (core func $m "u"))) + (func (export "u16") (param u16) (canon lift (core func $m "u"))) + (func (export "s8") (param s8) (canon lift (core func $m "s"))) + (func (export "s16") (param s16) (canon lift (core func $m "s"))) + ) + (component $c2 + (import "" (instance $i + (export "u8" (func (param u8))) + (export "s8" (func (param s8))) + (export "u16" (func (param u16))) + (export "s16" (func (param s16))) + )) + + (core func $u8 (canon lower (func $i "u8"))) + (core func $s8 (canon lower (func $i "s8"))) + (core func $u16 (canon lower (func $i "u16"))) + (core func $s16 (canon lower (func $i "s16"))) + + (core module $m + (import "" "u8" (func $u8 (param i32))) + (import "" "s8" (func $s8 (param i32))) + (import "" "u16" (func $u16 (param i32))) + (import "" "s16" (func $s16 (param i32))) + + (func $start + (call $u8 (i32.const 0)) + (call $u8 (i32.const 0xff00)) + (call $s8 (i32.const -1)) + (call $s8 (i32.const 0xff)) + (call $s8 (i32.const 0xffff)) + + (call $u16 (i32.const 0)) + (call $u16 (i32.const 0xff0000)) + (call $s16 (i32.const -1)) + (call $s16 (i32.const 0xffff)) + (call $s16 (i32.const 0xffffff)) + ) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance + (export "u8" (func $u8)) + (export "s8" (func $s8)) + (export "u16" (func $u16)) + (export "s16" (func $s16)) + )) + )) + ) + (instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "" (instance $c1)))) +) + +;; translation of locals between different types +(component + (type $a (variant (case "a" u8) (case "b" float32))) + (type $b (variant (case "a" u16) (case "b" s64))) + (type $c (variant (case "a" u64) (case "b" float64))) + (type $d (variant (case "a" float32) (case "b" float64))) + (type $e (variant (case "a" float32) (case "b" s64))) + + (component $c1 + (core module $m + (func (export "a") (param i32 i32 i32) + (i32.eqz (local.get 0)) + if + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (i32.ne (local.get 2) (i32.const 2)) (unreachable)) + else + (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) + (if (f32.ne (f32.reinterpret_i32 (local.get 2)) (f32.const 3)) (unreachable)) + end + ) + (func (export "b") (param i32 i32 i64) + (i32.eqz (local.get 0)) + if + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (i64.ne (local.get 2) (i64.const 4)) (unreachable)) + else + (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) + (if (i64.ne (local.get 2) (i64.const 5)) (unreachable)) + end + ) + (func (export "c") (param i32 i32 i64) + (i32.eqz (local.get 0)) + if + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (i64.ne (local.get 2) (i64.const 6)) (unreachable)) + else + (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) + (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 7)) (unreachable)) + end + ) + (func (export "d") (param i32 i32 i64) + (i32.eqz (local.get 0)) + if + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 8)) (unreachable)) + else + (if (i32.ne (local.get 1) (i32.const 1)) 
(unreachable)) + (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 9)) (unreachable)) + end + ) + (func (export "e") (param i32 i32 i64) + (i32.eqz (local.get 0)) + if + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (f64.ne (f64.reinterpret_i64 (local.get 2)) (f64.const 10)) (unreachable)) + else + (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) + (if (i64.ne (local.get 2) (i64.const 11)) (unreachable)) + end + ) + ) + (core instance $m (instantiate $m)) + (func (export "a") (param bool) (param $a) (canon lift (core func $m "a"))) + (func (export "b") (param bool) (param $b) (canon lift (core func $m "b"))) + (func (export "c") (param bool) (param $c) (canon lift (core func $m "c"))) + (func (export "d") (param bool) (param $d) (canon lift (core func $m "d"))) + (func (export "e") (param bool) (param $e) (canon lift (core func $m "e"))) + ) + (component $c2 + (import "" (instance $i + (export "a" (func (param bool) (param $a))) + (export "b" (func (param bool) (param $b))) + (export "c" (func (param bool) (param $c))) + (export "d" (func (param bool) (param $d))) + (export "e" (func (param bool) (param $e))) + )) + + (core func $a (canon lower (func $i "a"))) + (core func $b (canon lower (func $i "b"))) + (core func $c (canon lower (func $i "c"))) + (core func $d (canon lower (func $i "d"))) + (core func $e (canon lower (func $i "e"))) + + (core module $m + (import "" "a" (func $a (param i32 i32 i32))) + (import "" "b" (func $b (param i32 i32 i64))) + (import "" "c" (func $c (param i32 i32 i64))) + (import "" "d" (func $d (param i32 i32 i64))) + (import "" "e" (func $e (param i32 i32 i64))) + + (func $start + ;; upper bits should get masked + (call $a (i32.const 0) (i32.const 0) (i32.const 0xff_02)) + (call $a (i32.const 1) (i32.const 1) (i32.reinterpret_f32 (f32.const 3))) + + ;; upper bits should get masked + (call $b (i32.const 0) (i32.const 0) (i64.const 0xff_00_04)) + (call $b (i32.const 1) (i32.const 1) (i64.const 5)) + + (call $c (i32.const 0) (i32.const 0) (i64.const 6)) + (call $c (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 7))) + + (call $d (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 8))) + (call $d (i32.const 1) (i32.const 1) (i64.reinterpret_f64 (f64.const 9))) + + (call $e (i32.const 0) (i32.const 0) (i64.reinterpret_f64 (f64.const 10))) + (call $e (i32.const 1) (i32.const 1) (i64.const 11)) + ) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance + (export "a" (func $a)) + (export "b" (func $b)) + (export "c" (func $c)) + (export "d" (func $d)) + (export "e" (func $e)) + )) + )) + ) + (instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "" (instance $c1)))) +) + +;; different size variants +(component + (type $a (variant + (case "a" unit) + (case "b" float32) + (case "c" (tuple float32 u32)) + (case "d" (tuple float32 unit u64 u8)) + )) + + (component $c1 + (core module $m + (func (export "a") (param i32 i32 f32 i64 i32) + (if (i32.eq (local.get 0) (i32.const 0)) + (block + (if (i32.ne (local.get 1) (i32.const 0)) (unreachable)) + (if (f32.ne (local.get 2) (f32.const 0)) (unreachable)) + (if (i64.ne (local.get 3) (i64.const 0)) (unreachable)) + (if (i32.ne (local.get 4) (i32.const 0)) (unreachable)) + ) + ) + (if (i32.eq (local.get 0) (i32.const 1)) + (block + (if (i32.ne (local.get 1) (i32.const 1)) (unreachable)) + (if (f32.ne (local.get 2) (f32.const 1)) (unreachable)) + (if (i64.ne (local.get 3) (i64.const 0)) (unreachable)) + (if (i32.ne (local.get 4) (i32.const 0)) 
(unreachable)) + ) + ) + (if (i32.eq (local.get 0) (i32.const 2)) + (block + (if (i32.ne (local.get 1) (i32.const 2)) (unreachable)) + (if (f32.ne (local.get 2) (f32.const 2)) (unreachable)) + (if (i64.ne (local.get 3) (i64.const 2)) (unreachable)) + (if (i32.ne (local.get 4) (i32.const 0)) (unreachable)) + ) + ) + (if (i32.eq (local.get 0) (i32.const 3)) + (block + (if (i32.ne (local.get 1) (i32.const 3)) (unreachable)) + (if (f32.ne (local.get 2) (f32.const 3)) (unreachable)) + (if (i64.ne (local.get 3) (i64.const 3)) (unreachable)) + (if (i32.ne (local.get 4) (i32.const 3)) (unreachable)) + ) + ) + (if (i32.gt_u (local.get 0) (i32.const 3)) + (unreachable)) + ) + ) + (core instance $m (instantiate $m)) + (func (export "a") (param u8) (param $a) (canon lift (core func $m "a"))) + ) + (component $c2 + (import "" (instance $i + (export "a" (func (param u8) (param $a))) + )) + + (core func $a (canon lower (func $i "a"))) + + (core module $m + (import "" "a" (func $a (param i32 i32 f32 i64 i32))) + + (func $start + ;; variant a + (call $a + (i32.const 0) + (i32.const 0) + (f32.const 0) + (i64.const 0) + (i32.const 0)) + ;; variant b + (call $a + (i32.const 1) + (i32.const 1) + (f32.const 1) + (i64.const 0) + (i32.const 0)) + ;; variant c + (call $a + (i32.const 2) + (i32.const 2) + (f32.const 2) + (i64.const 2) + (i32.const 0)) + ;; variant d + (call $a + (i32.const 3) + (i32.const 3) + (f32.const 3) + (i64.const 3) + (i32.const 3)) + ) + (start $start) + ) + (core instance (instantiate $m + (with "" (instance + (export "a" (func $a)) + )) + )) + ) + (instance $c1 (instantiate $c1)) + (instance $c2 (instantiate $c2 (with "" (instance $c1)))) +)
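
A closing note on the variant tests above: the trampoline compiler sizes a variant's discriminant from its case count, which is what selects between the `i32.load8_u`/`i32.load16_u`/`i32.load` forms seen earlier in `translate_variant`. An illustrative reimplementation of the rule behind `DiscriminantSize::from_count` (the real helper lives in `wasmtime-component-util`; exact boundary conditions are defined there):

```rust
// Illustrative reimplementation of the discriminant sizing rule encoded by
// `DiscriminantSize::from_count`: a variant's tag occupies 1, 2, or 4 bytes
// depending on how many cases must be distinguishable.
#[derive(Debug, PartialEq)]
enum DiscriminantSize {
    Size1,
    Size2,
    Size4,
}

fn from_count(count: usize) -> Option<DiscriminantSize> {
    if count <= 0xff {
        Some(DiscriminantSize::Size1)
    } else if count <= 0xffff {
        Some(DiscriminantSize::Size2)
    } else if count <= 0xffff_ffff {
        Some(DiscriminantSize::Size4)
    } else {
        None
    }
}

fn main() {
    assert_eq!(from_count(4), Some(DiscriminantSize::Size1));
    assert_eq!(from_count(300), Some(DiscriminantSize::Size2));
    assert_eq!(from_count(70_000), Some(DiscriminantSize::Size4));
}
```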