diff --git a/Cargo.lock b/Cargo.lock
index b08c039d89..4c1164bf66 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3496,7 +3496,10 @@ name = "wasmtime-environ"
 version = "0.40.0"
 dependencies = [
  "anyhow",
+ "atty",
+ "clap 3.2.8",
  "cranelift-entity",
+ "env_logger 0.9.0",
  "gimli",
  "indexmap",
  "log",
@@ -3509,6 +3512,7 @@ dependencies = [
  "wasmprinter",
  "wasmtime-component-util",
  "wasmtime-types",
+ "wat",
 ]

[[package]]
diff --git a/crates/environ/Cargo.toml b/crates/environ/Cargo.toml
index 9c4e152981..1c4a3701a4 100644
--- a/crates/environ/Cargo.toml
+++ b/crates/environ/Cargo.toml
@@ -26,6 +26,16 @@ wasm-encoder = { version = "0.15.0", optional = true }
 wasmprinter = { version = "0.2.38", optional = true }
 wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true }

+[dev-dependencies]
+atty = "0.2.14"
+clap = { version = "3.2.8", features = ['derive'] }
+env_logger = "0.9.0"
+wat = "1.0.47"
+
+[[example]]
+name = "factc"
+required-features = ['component-model']
+
 [badges]
 maintenance = { status = "actively-developed" }
diff --git a/crates/environ/examples/factc.rs b/crates/environ/examples/factc.rs
new file mode 100644
index 0000000000..763b0cdf0c
--- /dev/null
+++ b/crates/environ/examples/factc.rs
@@ -0,0 +1,195 @@
+use anyhow::{bail, Context, Result};
+use clap::Parser;
+use std::io::Write;
+use std::path::PathBuf;
+use wasmparser::{Payload, Validator, WasmFeatures};
+use wasmtime_environ::component::*;
+use wasmtime_environ::fact::Module;
+
+/// A small helper utility to explore adapter modules generated by Wasmtime's
+/// adapter fusion compiler.
+///
+/// This utility takes a `*.wat` file as input which is expected to be a valid
+/// WebAssembly component. The component is parsed and any type definition for a
+/// component function gets a generated adapter for it, as if the caller/callee
+/// had used that type for the adapter.
+///
+/// For example with an input that looks like:
+///
+/// (component
+///   (type (func (param u32) (result (list u8))))
+/// )
+///
+/// This tool can be used to generate an adapter for that signature.
+#[derive(Parser)]
+struct Factc {
+    /// Whether or not debug code is inserted into the generated adapter.
+    #[clap(long)]
+    debug: bool,
+
+    /// Whether or not the lifting options (the callee of the exported adapter)
+    /// use a 64-bit memory as opposed to a 32-bit memory.
+    #[clap(long)]
+    lift64: bool,
+
+    /// Whether or not the lowering options (the caller of the exported adapter)
+    /// use a 64-bit memory as opposed to a 32-bit memory.
+    #[clap(long)]
+    lower64: bool,
+
+    /// Whether or not a call to a configured `post-return` function is
+    /// enabled.
+    #[clap(long)]
+    post_return: bool,
+
+    /// Whether or not to skip validation of the generated adapter module.
+    #[clap(long)]
+    skip_validate: bool,
+
+    /// Where to place the generated adapter module. Standard output is used if
+    /// this is not specified.
+    #[clap(short, long)]
+    output: Option<PathBuf>,
+
+    /// Output the text format for WebAssembly instead of the binary format.
+    #[clap(short, long)]
+    text: bool,
+
+    /// The input `*.wat` file, expected to contain a valid WebAssembly component.
+    input: PathBuf,
+}
+
+fn main() -> Result<()> {
+    Factc::parse().execute()
+}
+
+impl Factc {
+    fn execute(self) -> Result<()> {
+        env_logger::init();
+
+        let mut types = ComponentTypesBuilder::default();
+
+        // Manufactures a unique `CoreDef` so each generated adapter gets a
+        // unique function import.
+        let mut next_def = 0;
+        let mut dummy_def = || {
+            next_def += 1;
+            CoreDef::Adapter(AdapterIndex::from_u32(next_def))
+        };
+
+        // Manufactures a `CoreExport` for a memory with the shape specified. Note
+        // that we can't import as many memories as functions so these are
+        // intentionally limited. Once a handful of memories are generated of each
+        // type then they start getting reused.
+        let mut next_memory = 0;
+        let mut memories32 = Vec::new();
+        let mut memories64 = Vec::new();
+        let mut dummy_memory = |memory64: bool| {
+            let dst = if memory64 {
+                &mut memories64
+            } else {
+                &mut memories32
+            };
+            let idx = if dst.len() < 5 {
+                next_memory += 1;
+                dst.push(next_memory - 1);
+                next_memory - 1
+            } else {
+                dst[0]
+            };
+            CoreExport {
+                instance: RuntimeInstanceIndex::from_u32(idx),
+                item: ExportItem::Name(String::new()),
+            }
+        };
+
+        let mut adapters = Vec::new();
+        let input = wat::parse_file(&self.input)?;
+        types.push_type_scope();
+        let mut validator = Validator::new_with_features(WasmFeatures {
+            component_model: true,
+            ..Default::default()
+        });
+        for payload in wasmparser::Parser::new(0).parse_all(&input) {
+            let payload = payload?;
+            validator.payload(&payload)?;
+            let section = match payload {
+                Payload::ComponentTypeSection(s) => s,
+                _ => continue,
+            };
+            for ty in section {
+                let ty = types.intern_component_type(&ty?)?;
+                types.push_component_typedef(ty);
+                let ty = match ty {
+                    TypeDef::ComponentFunc(ty) => ty,
+                    _ => continue,
+                };
+                adapters.push(Adapter {
+                    lift_ty: ty,
+                    lower_ty: ty,
+                    lower_options: AdapterOptions {
+                        instance: RuntimeComponentInstanceIndex::from_u32(0),
+                        string_encoding: StringEncoding::Utf8,
+                        memory64: self.lower64,
+                        // Pessimistically assume that memory/realloc are going to be
+                        // required for this trampoline and provide them. Avoids doing
+                        // calculations to figure out whether they're necessary and
+                        // simplifies this tool without reducing coverage within FACT
+                        // itself.
+                        memory: Some(dummy_memory(self.lower64)),
+                        realloc: Some(dummy_def()),
+                        // Lowering never allows `post-return`
+                        post_return: None,
+                    },
+                    lift_options: AdapterOptions {
+                        instance: RuntimeComponentInstanceIndex::from_u32(1),
+                        string_encoding: StringEncoding::Utf8,
+                        memory64: self.lift64,
+                        memory: Some(dummy_memory(self.lift64)),
+                        realloc: Some(dummy_def()),
+                        post_return: if self.post_return {
+                            Some(dummy_def())
+                        } else {
+                            None
+                        },
+                    },
+                    func: dummy_def(),
+                });
+            }
+        }
+        types.pop_type_scope();
+
+        let types = types.finish();
+        let mut fact_module = Module::new(&types, self.debug);
+        for (i, adapter) in adapters.iter().enumerate() {
+            fact_module.adapt(&format!("adapter{i}"), adapter);
+        }
+        let wasm = fact_module.encode();
+        if !self.skip_validate {
+            Validator::new_with_features(WasmFeatures {
+                multi_memory: true,
+                memory64: true,
+                ..WasmFeatures::default()
+            })
+            .validate_all(&wasm)
+            .context("failed to validate generated module")?;
+        }
+
+        let output = if self.text {
+            wasmprinter::print_bytes(&wasm)
+                .context("failed to convert binary wasm to text")?
+                .into_bytes()
+        } else if self.output.is_none() && atty::is(atty::Stream::Stdout) {
+            bail!("cannot print binary wasm output to a terminal unless `-t` flag is passed")
+        } else {
+            wasm
+        };
+
+        match &self.output {
+            Some(file) => std::fs::write(file, output).context("failed to write output file")?,
+            None => std::io::stdout().write_all(&output).context("failed to write to stdout")?,
+        }
+
+        Ok(())
+    }
+}
diff --git a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs
index 3aaafca092..225efed322 100644
--- a/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs
+++ b/crates/environ/fuzz/fuzz_targets/fact-valid-module.rs
@@ -52,6 +52,7 @@ enum ValType {
     Float32,
     Float64,
     Char,
+    List(Box<ValType>),
     Record(Vec<ValType>),
     // Up to 65 flags to exercise up to 3 u32 values
     Flags(UsizeInRange<0, 65>),
@@ -230,6 +231,10 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
         ValType::Float32 => InterfaceType::Float32,
         ValType::Float64 => InterfaceType::Float64,
         ValType::Char => InterfaceType::Char,
+        ValType::List(ty) => {
+            let ty = intern(types, ty);
+            InterfaceType::List(types.add_interface_type(ty))
+        }
         ValType::Record(tys) => {
             let ty = TypeRecord {
                 fields: tys
diff --git a/crates/environ/src/fact.rs b/crates/environ/src/fact.rs
index b7c0dd5d87..b0c24aaf70 100644
--- a/crates/environ/src/fact.rs
+++ b/crates/environ/src/fact.rs
@@ -304,4 +304,12 @@ impl Options {
             ValType::I32
         }
     }
+
+    fn ptr_size(&self) -> u8 {
+        if self.memory64 {
+            8
+        } else {
+            4
+        }
+    }
 }
diff --git a/crates/environ/src/fact/signature.rs b/crates/environ/src/fact/signature.rs
index 7194a36757..27313f13c9 100644
--- a/crates/environ/src/fact/signature.rs
+++ b/crates/environ/src/fact/signature.rs
@@ -38,14 +38,14 @@ impl Module<'_> {
         let ty = &self.types[options.ty];
         let ptr_ty = options.ptr();
-        let mut params = self.flatten_types(ty.params.iter().map(|(_, ty)| *ty));
+        let mut params = self.flatten_types(options, ty.params.iter().map(|(_, ty)| *ty));
         let mut params_indirect = false;
         if params.len() > MAX_FLAT_PARAMS {
             params = vec![ptr_ty];
             params_indirect = true;
         }
-        let mut results = self.flatten_types([ty.result]);
+        let mut results = self.flatten_types(options, [ty.result]);
         let mut results_indirect = false;
         if results.len() > MAX_FLAT_RESULTS {
             results_indirect = true;
@@ -73,18 +73,19 @@ impl Module<'_> {
     /// Pushes the flat version of a list of component types into a final result
     /// list.
-    pub(crate) fn flatten_types(
+    pub(super) fn flatten_types(
         &self,
+        opts: &Options,
         tys: impl IntoIterator<Item = InterfaceType>,
     ) -> Vec<ValType> {
         let mut result = Vec::new();
         for ty in tys {
-            self.push_flat(&ty, &mut result);
+            self.push_flat(opts, &ty, &mut result);
         }
         result
     }

-    fn push_flat(&self, ty: &InterfaceType, dst: &mut Vec<ValType>) {
+    fn push_flat(&self, opts: &Options, ty: &InterfaceType, dst: &mut Vec<ValType>) {
         match ty {
             InterfaceType::Unit => {}
@@ -103,17 +104,17 @@ impl Module<'_> {
             InterfaceType::Float64 => dst.push(ValType::F64),
             InterfaceType::String | InterfaceType::List(_) => {
-                dst.push(ValType::I32);
-                dst.push(ValType::I32);
+                dst.push(opts.ptr());
+                dst.push(opts.ptr());
             }
             InterfaceType::Record(r) => {
                 for field in self.types[*r].fields.iter() {
-                    self.push_flat(&field.ty, dst);
+                    self.push_flat(opts, &field.ty, dst);
                 }
             }
             InterfaceType::Tuple(t) => {
                 for ty in self.types[*t].types.iter() {
-                    self.push_flat(ty, dst);
+                    self.push_flat(opts, ty, dst);
                 }
             }
             InterfaceType::Flags(f) => {
@@ -126,14 +127,14 @@ impl Module<'_> {
             InterfaceType::Enum(_) => dst.push(ValType::I32),
             InterfaceType::Option(t) => {
                 dst.push(ValType::I32);
-                self.push_flat(&self.types[*t], dst);
+                self.push_flat(opts, &self.types[*t], dst);
             }
             InterfaceType::Variant(t) => {
                 dst.push(ValType::I32);
                 let pos = dst.len();
                 let mut tmp = Vec::new();
                 for case in self.types[*t].cases.iter() {
-                    self.push_flat_variant(&case.ty, pos, &mut tmp, dst);
+                    self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst);
                 }
             }
             InterfaceType::Union(t) => {
@@ -141,7 +142,7 @@ impl Module<'_> {
                 let pos = dst.len();
                 let mut tmp = Vec::new();
                 for ty in self.types[*t].types.iter() {
-                    self.push_flat_variant(ty, pos, &mut tmp, dst);
+                    self.push_flat_variant(opts, ty, pos, &mut tmp, dst);
                 }
             }
             InterfaceType::Expected(t) => {
@@ -149,21 +150,22 @@ impl Module<'_> {
                 let e = &self.types[*t];
                 let pos = dst.len();
                 let mut tmp = Vec::new();
-                self.push_flat_variant(&e.ok, pos, &mut tmp, dst);
-                self.push_flat_variant(&e.err, pos, &mut tmp, dst);
+                self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst);
+                self.push_flat_variant(opts, &e.err, pos, &mut tmp, dst);
             }
         }
     }

     fn push_flat_variant(
         &self,
+        opts: &Options,
         ty: &InterfaceType,
         pos: usize,
         tmp: &mut Vec<ValType>,
         dst: &mut Vec<ValType>,
     ) {
         tmp.truncate(0);
-        self.push_flat(ty, tmp);
+        self.push_flat(opts, ty, tmp);
         for (i, a) in tmp.iter().enumerate() {
             match dst.get_mut(pos + i) {
                 Some(b) => join(*a, b),
@@ -182,8 +184,8 @@ impl Module<'_> {
         }
     }

-    pub(crate) fn align(&self, ty: &InterfaceType) -> usize {
-        self.size_align(ty).1
+    pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> usize {
+        self.size_align(opts, ty).1
     }

     /// Returns a (size, align) pair corresponding to the byte-size and
     /// byte-alignment of the type specified.
     //
     // TODO: this is probably inefficient to entirely recalculate at all phases,
     // seems like it would be best to intern this in some sort of map somewhere.
-    pub(crate) fn size_align(&self, ty: &InterfaceType) -> (usize, usize) {
+    pub(super) fn size_align(&self, opts: &Options, ty: &InterfaceType) -> (usize, usize) {
         match ty {
             InterfaceType::Unit => (0, 1),
             InterfaceType::Bool | InterfaceType::S8 | InterfaceType::U8 => (1, 1),
@@ -201,12 +203,14 @@
             | InterfaceType::Char
             | InterfaceType::Float32 => (4, 4),
             InterfaceType::S64 | InterfaceType::U64 | InterfaceType::Float64 => (8, 8),
-            InterfaceType::String | InterfaceType::List(_) => (8, 4),
+            InterfaceType::String | InterfaceType::List(_) => {
+                ((2 * opts.ptr_size()).into(), opts.ptr_size().into())
+            }
             InterfaceType::Record(r) => {
-                self.record_size_align(self.types[*r].fields.iter().map(|f| &f.ty))
+                self.record_size_align(opts, self.types[*r].fields.iter().map(|f| &f.ty))
             }
-            InterfaceType::Tuple(t) => self.record_size_align(self.types[*t].types.iter()),
+            InterfaceType::Tuple(t) => self.record_size_align(opts, self.types[*t].types.iter()),
             InterfaceType::Flags(f) => match self.types[*f].names.len() {
                 n if n <= 8 => (1, 1),
                 n if n <= 16 => (2, 2),
@@ -216,27 +220,28 @@
             InterfaceType::Enum(t) => self.discrim_size_align(self.types[*t].names.len()),
             InterfaceType::Option(t) => {
                 let ty = &self.types[*t];
-                self.variant_size_align([&InterfaceType::Unit, ty].into_iter())
+                self.variant_size_align(opts, [&InterfaceType::Unit, ty].into_iter())
             }
             InterfaceType::Variant(t) => {
-                self.variant_size_align(self.types[*t].cases.iter().map(|c| &c.ty))
+                self.variant_size_align(opts, self.types[*t].cases.iter().map(|c| &c.ty))
             }
-            InterfaceType::Union(t) => self.variant_size_align(self.types[*t].types.iter()),
+            InterfaceType::Union(t) => self.variant_size_align(opts, self.types[*t].types.iter()),
             InterfaceType::Expected(t) => {
                 let e = &self.types[*t];
-                self.variant_size_align([&e.ok, &e.err].into_iter())
+                self.variant_size_align(opts, [&e.ok, &e.err].into_iter())
             }
         }
     }

-    pub(crate) fn record_size_align<'a>(
+    pub(super) fn record_size_align<'a>(
         &self,
+        opts: &Options,
         fields: impl Iterator<Item = &'a InterfaceType>,
     ) -> (usize, usize) {
         let mut size = 0;
         let mut align = 1;
         for ty in fields {
-            let (fsize, falign) = self.size_align(ty);
+            let (fsize, falign) = self.size_align(opts, ty);
             size = align_to(size, falign) + fsize;
             align = align.max(falign);
         }
@@ -245,12 +250,13 @@

     fn variant_size_align<'a>(
         &self,
+        opts: &Options,
         cases: impl ExactSizeIterator<Item = &'a InterfaceType>,
     ) -> (usize, usize) {
         let (discrim_size, mut align) = self.discrim_size_align(cases.len());
         let mut payload_size = 0;
         for ty in cases {
-            let (csize, calign) = self.size_align(ty);
+            let (csize, calign) = self.size_align(opts, ty);
             payload_size = payload_size.max(csize);
             align = align.max(calign);
         }
diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs
index 8276af892a..405e8489c0 100644
--- a/crates/environ/src/fact/trampoline.rs
+++ b/crates/environ/src/fact/trampoline.rs
@@ -101,7 +101,7 @@ enum Source<'a> {
     /// This value is stored in linear memory described by the `Memory`
     /// structure.
-    Memory(Memory),
+    Memory(Memory<'a>),
 }

 /// Same as `Source` but for where values are translated into.
@@ -111,10 +111,10 @@ enum Destination<'a> {
     ///
     /// The types listed are the types that are expected to be on the stack at
     /// the end of translation.
-    Stack(&'a [ValType]),
+    Stack(&'a [ValType], &'a Options),

     /// This value is to be placed in linear memory described by `Memory`.
- Memory(Memory), + Memory(Memory<'a>), } struct Stack<'a> { @@ -124,21 +124,20 @@ struct Stack<'a> { /// up the component value. Each list has the index of the local being /// accessed as well as the type of the local itself. locals: &'a [(u32, ValType)], + /// The lifting/lowering options for where this stack of values comes from + opts: &'a Options, } /// Representation of where a value is going to be stored in linear memory. -struct Memory { - /// Whether or not the `addr_local` is a 64-bit type. - memory64: bool, +struct Memory<'a> { + /// The lifting/lowering options with memory configuration + opts: &'a Options, /// The index of the local that contains the base address of where the /// storage is happening. addr_local: u32, /// A "static" offset that will be baked into wasm instructions for where /// memory loads/stores happen. offset: u32, - /// The index of memory in the wasm module memory index space that this - /// memory is referring to. - memory_idx: u32, } impl Compiler<'_, '_> { @@ -231,12 +230,17 @@ impl Compiler<'_, '_> { // TODO: handle subtyping assert_eq!(src_tys.len(), dst_tys.len()); - let src_flat = self.module.flatten_types(src_tys.iter().copied()); - let dst_flat = self.module.flatten_types(dst_tys.iter().copied()); + let src_flat = self + .module + .flatten_types(&self.adapter.lower, src_tys.iter().copied()); + let dst_flat = self + .module + .flatten_types(&self.adapter.lift, dst_tys.iter().copied()); let src = if src_flat.len() <= MAX_FLAT_PARAMS { Source::Stack(Stack { locals: ¶m_locals[..src_flat.len()], + opts: &self.adapter.lower, }) } else { // If there are too many parameters then that means the parameters @@ -246,18 +250,21 @@ impl Compiler<'_, '_> { assert_eq!(ty, self.adapter.lower.ptr()); let align = src_tys .iter() - .map(|t| self.module.align(t)) + .map(|t| self.module.align(&self.adapter.lower, t)) .max() .unwrap_or(1); Source::Memory(self.memory_operand(&self.adapter.lower, addr, align)) }; let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { - Destination::Stack(&dst_flat) + Destination::Stack(&dst_flat, &self.adapter.lift) } else { // If there are too many parameters then space is allocated in the // destination module for the parameters via its `realloc` function. - let (size, align) = self.module.record_size_align(dst_tys.iter()); + let (size, align) = self + .module + .record_size_align(&self.adapter.lift, dst_tys.iter()); + let size = MallocSize::Const(size); Destination::Memory(self.malloc(&self.adapter.lift, size, align)) }; @@ -287,19 +294,20 @@ impl Compiler<'_, '_> { let src_ty = self.module.types[self.adapter.lift.ty].result; let dst_ty = self.module.types[self.adapter.lower.ty].result; - let src_flat = self.module.flatten_types([src_ty]); - let dst_flat = self.module.flatten_types([dst_ty]); + let src_flat = self.module.flatten_types(&self.adapter.lift, [src_ty]); + let dst_flat = self.module.flatten_types(&self.adapter.lower, [dst_ty]); let src = if src_flat.len() <= MAX_FLAT_RESULTS { Source::Stack(Stack { locals: result_locals, + opts: &self.adapter.lift, }) } else { // The original results to read from in this case come from the // return value of the function itself. The imported function will // return a linear memory address at which the values can be read // from. 
- let align = self.module.align(&src_ty); + let align = self.module.align(&self.adapter.lift, &src_ty); assert_eq!(result_locals.len(), 1); let (addr, ty) = result_locals[0]; assert_eq!(ty, self.adapter.lift.ptr()); @@ -307,12 +315,12 @@ impl Compiler<'_, '_> { }; let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { - Destination::Stack(&dst_flat) + Destination::Stack(&dst_flat, &self.adapter.lower) } else { // This is slightly different than `translate_params` where the // return pointer was provided by the caller of this function // meaning the last parameter local is a pointer into linear memory. - let align = self.module.align(&dst_ty); + let align = self.module.align(&self.adapter.lower, &dst_ty); let (addr, ty) = *param_locals.last().expect("no retptr"); assert_eq!(ty, self.adapter.lower.ptr()); Destination::Memory(self.memory_operand(&self.adapter.lower, addr, align)) @@ -348,6 +356,7 @@ impl Compiler<'_, '_> { InterfaceType::Float32 => self.translate_f32(src, dst_ty, dst), InterfaceType::Float64 => self.translate_f64(src, dst_ty, dst), InterfaceType::Char => self.translate_char(src, dst_ty, dst), + InterfaceType::List(t) => self.translate_list(*t, src, dst_ty, dst), InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst), InterfaceType::Flags(f) => self.translate_flags(*f, src, dst_ty, dst), InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst), @@ -363,10 +372,6 @@ impl Compiler<'_, '_> { drop(&self.adapter.lift.string_encoding); unimplemented!("don't know how to translate strings") } - - // TODO: this needs to be filled out for all the other interface - // types. - ty => unimplemented!("don't know how to translate {ty:?}"), } } @@ -393,7 +398,7 @@ impl Compiler<'_, '_> { match dst { Destination::Memory(mem) => self.i32_store8(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -421,7 +426,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store8(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -438,7 +443,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store8(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -466,7 +471,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store16(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -483,7 +488,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store16(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -505,7 +510,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -519,7 +524,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i32_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } } @@ -533,7 +538,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => 
self.i64_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I64), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I64), } } @@ -547,7 +552,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.i64_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::I64), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I64), } } @@ -561,7 +566,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.f32_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::F32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::F32), } } @@ -575,7 +580,7 @@ impl Compiler<'_, '_> { } match dst { Destination::Memory(mem) => self.f64_store(mem), - Destination::Stack(stack) => self.stack_set(stack, ValType::F64), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::F64), } } @@ -627,7 +632,402 @@ impl Compiler<'_, '_> { Destination::Memory(mem) => { self.i32_store(mem); } - Destination::Stack(stack) => self.stack_set(stack, ValType::I32), + Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), + } + } + + fn translate_list( + &mut self, + src_ty: TypeInterfaceIndex, + src: &Source<'_>, + dst_ty: &InterfaceType, + dst: &Destination, + ) { + let src_element_ty = &self.module.types[src_ty]; + let dst_element_ty = match dst_ty { + InterfaceType::List(r) => &self.module.types[*r], + _ => panic!("expected a list"), + }; + let src_opts = src.opts(); + let dst_opts = dst.opts(); + let (src_size, src_align) = self.module.size_align(src_opts, src_element_ty); + let (dst_size, dst_align) = self.module.size_align(dst_opts, dst_element_ty); + + // Load the pointer/length of this list into temporary locals. These + // will be referenced a good deal so this just makes it easier to deal + // with them consistently below rather than trying to reload from memory + // for example. + let src_ptr = self.gen_local(src_opts.ptr()); + let src_len = self.gen_local(src_opts.ptr()); + match src { + Source::Stack(s) => { + assert_eq!(s.locals.len(), 2); + self.stack_get(&s.slice(0..1), src_opts.ptr()); + self.instruction(LocalSet(src_ptr)); + self.stack_get(&s.slice(1..2), src_opts.ptr()); + self.instruction(LocalSet(src_len)); + } + Source::Memory(mem) => { + self.ptr_load(mem); + self.instruction(LocalSet(src_ptr)); + self.ptr_load(&mem.bump(src_opts.ptr_size().into())); + self.instruction(LocalSet(src_len)); + } + } + + // Create a `Memory` operand which will internally assert that the + // `src_ptr` value is properly aligned. + let src_mem = self.memory_operand(src_opts, src_ptr, src_align); + + // Next the byte size of the allocation in the destination is + // determined. Note that this is pretty tricky because pointer widths + // could be changing and otherwise everything must stay within the + // 32-bit size-space. This internally will ensure that `src_len * + // dst_size` doesn't overflow 32-bits and will place the final result in + // `dst_byte_len` where `dst_byte_len` has the appropriate type for the + // destination. + let dst_byte_len = self.gen_local(dst_opts.ptr()); + self.calculate_dst_byte_len( + src_len, + dst_byte_len, + src_opts.ptr(), + dst_opts.ptr(), + dst_size, + ); + + // Here `realloc` is invoked (in a `malloc`-like fashion) to allocate + // space for the list in the destination memory. This will also + // internally insert checks that the returned pointer is aligned + // correctly for the destination. 
+        let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), dst_align);
+
+        // At this point we have aligned pointers, a length, and a byte length
+        // for the destination. The spec also requires this translation to
+        // ensure that the range of memory within the source and destination
+        // memories is valid. This implementation, though, tries to optimize
+        // those checks somewhat. The thinking is that if we hit an out-of-bounds
+        // memory access during translation that's the same as a trap up-front.
+        // This means we can generally minimize up-front checks in favor of
+        // simply trying to load out-of-bounds memory.
+        //
+        // This doesn't mean we can avoid a check entirely though. One major
+        // worry here is integer overflow of the pointers in linear memory as
+        // they're incremented to move to the next element as part of
+        // translation. For example if the entire 32-bit address space were
+        // valid and the base pointer was `0xffff_fff0` with a byte size of 17,
+        // that should not be a valid list, but "simply defer to the loop below"
+        // would cause a wraparound to occur and no trap would be detected.
+        //
+        // To solve this a check is inserted here that the `base + byte_len`
+        // calculation doesn't overflow the 32-bit address space. Note though
+        // that this is only done for 32-bit memories, not 64-bit memories.
+        // Given the iteration of the loop below the only worry is when the
+        // address space is 100% mapped and wraparound is possible. Otherwise if
+        // anything in the address space is unmapped then we're guaranteed to
+        // hit a trap as we march from the base pointer to the end of the array.
+        // It's assumed that it's impossible for a 64-bit memory to have the
+        // entire address space mapped, so this isn't a concern for 64-bit
+        // memories.
+        //
+        // Technically this is only a concern for 32-bit memories if the entire
+        // address space is mapped, so `memory.size` could be used to skip most
+        // of the check here, but it's assumed that the `memory.size` check is
+        // probably more expensive than just checking for 32-bit overflow by
+        // using 64-bit arithmetic. This should hypothetically be tested though!
+        //
+        // TODO: the most-optimal thing here is to probably, once per adapter,
+        // call `memory.size` and put that in a local. If that is not the
+        // maximum for a 32-bit memory then this entire bounds-check here can be
+        // skipped.
+        if !src_opts.memory64 && src_size > 0 {
+            self.instruction(LocalGet(src_mem.addr_local));
+            self.instruction(I64ExtendI32U);
+            if src_size < dst_size {
+                // If the source byte size is less than the destination size
+                // then we can leverage the fact that `dst_byte_len` was already
+                // calculated and didn't overflow so this is also guaranteed to
+                // not overflow.
+                self.instruction(LocalGet(src_len));
+                self.instruction(I64ExtendI32U);
+                if src_size != 1 {
+                    self.instruction(I64Const(i64::try_from(src_size).unwrap()));
+                    self.instruction(I64Mul);
+                }
+            } else if src_size == dst_size {
+                // If the source byte size is the same as the destination byte
+                // size then that can be reused. Note that the destination byte
+                // size is already guaranteed to fit in 32 bits, even if it's
+                // stored in a 64-bit local.
+                self.instruction(LocalGet(dst_byte_len));
+                if dst_opts.ptr() == ValType::I32 {
+                    self.instruction(I64ExtendI32U);
+                }
+            } else {
+                // Otherwise if the source byte size is larger than the
+                // destination byte size then the source byte size needs to be
+                // calculated fresh here. Note, though, that the result of this
+                // multiplication is not checked for overflow. The reason for
+                // that is that the result here flows into the check below about
+                // overflow and if this computation overflows it should be
+                // guaranteed to overflow the next computation.
+                //
+                // In general what's being checked here is:
+                //
+                //      src_mem.addr_local + src_len * src_size
+                //
+                // These three values are all 32-bits originally and if they're
+                // all assumed to be `u32::MAX` then:
+                //
+                //      let max = u64::from(u32::MAX);
+                //      let result = max + max * max;
+                //      assert_eq!(result, 0xffffffff00000000);
+                //
+                // This means that once an upper bit is set it's guaranteed to
+                // stay set as part of this computation, so the multiplication
+                // here is left unchecked to fall through into the addition
+                // below.
+                self.instruction(LocalGet(src_len));
+                self.instruction(I64ExtendI32U);
+                self.instruction(I64Const(i64::try_from(src_size).unwrap()));
+                self.instruction(I64Mul);
+            }
+            self.instruction(I64Add);
+            self.instruction(I64Const(32));
+            self.instruction(I64ShrU);
+            self.instruction(I32WrapI64);
+            self.instruction(If(BlockType::Empty));
+            self.trap(Trap::ListByteLengthOverflow);
+            self.instruction(End);
+        }
+
+        // If the destination is a 32-bit memory then its overflow check is
+        // relatively simple since we've already calculated the byte length of
+        // the destination above and can reuse that in this check.
+        if !dst_opts.memory64 && dst_size > 0 {
+            self.instruction(LocalGet(dst_mem.addr_local));
+            self.instruction(I64ExtendI32U);
+            self.instruction(LocalGet(dst_byte_len));
+            self.instruction(I64ExtendI32U);
+            self.instruction(I64Add);
+            self.instruction(I64Const(32));
+            self.instruction(I64ShrU);
+            self.instruction(I32WrapI64);
+            self.instruction(If(BlockType::Empty));
+            self.trap(Trap::ListByteLengthOverflow);
+            self.instruction(End);
+        }
+
+        // This is the main body of the loop to actually translate list types.
+        // Note that if both element sizes are 0 then this won't actually do
+        // anything so the loop is removed entirely.
+        if src_size > 0 || dst_size > 0 {
+            let cur_dst_ptr = self.gen_local(dst_opts.ptr());
+            let cur_src_ptr = self.gen_local(src_opts.ptr());
+            let remaining = self.gen_local(src_opts.ptr());
+
+            let iconst = |i: i32, ty: ValType| match ty {
+                ValType::I32 => I32Const(i32::try_from(i).unwrap()),
+                ValType::I64 => I64Const(i64::try_from(i).unwrap()),
+                _ => unreachable!(),
+            };
+            let src_add = if src_opts.memory64 { I64Add } else { I32Add };
+            let dst_add = if dst_opts.memory64 { I64Add } else { I32Add };
+            let src_eqz = if src_opts.memory64 { I64Eqz } else { I32Eqz };
+
+            // This block encompasses the entire loop and is used to exit before even
+            // entering the loop if the list size is zero.
+            self.instruction(Block(BlockType::Empty));
+
+            // Set the `remaining` local and only continue if it's > 0
+            self.instruction(LocalGet(src_len));
+            self.instruction(LocalTee(remaining));
+            self.instruction(src_eqz.clone());
+            self.instruction(BrIf(0));
+
+            // Initialize the source and destination pointers to their initial values
+            self.instruction(LocalGet(src_mem.addr_local));
+            self.instruction(LocalSet(cur_src_ptr));
+            self.instruction(LocalGet(dst_mem.addr_local));
+            self.instruction(LocalSet(cur_dst_ptr));
+
+            self.instruction(Loop(BlockType::Empty));
+
+            // Translate the next element in the list
+            let element_src = Source::Memory(Memory {
+                opts: src_opts,
+                offset: 0,
+                addr_local: cur_src_ptr,
+            });
+            let element_dst = Destination::Memory(Memory {
+                opts: dst_opts,
+                offset: 0,
+                addr_local: cur_dst_ptr,
+            });
+            self.translate(src_element_ty, &element_src, dst_element_ty, &element_dst);
+
+            // Update the two loop pointers
+            if src_size > 0 {
+                let src_size = i32::try_from(src_size).unwrap();
+                self.instruction(LocalGet(cur_src_ptr));
+                self.instruction(iconst(src_size, src_opts.ptr()));
+                self.instruction(src_add.clone());
+                self.instruction(LocalSet(cur_src_ptr));
+            }
+            if dst_size > 0 {
+                let dst_size = i32::try_from(dst_size).unwrap();
+                self.instruction(LocalGet(cur_dst_ptr));
+                self.instruction(iconst(dst_size, dst_opts.ptr()));
+                self.instruction(dst_add.clone());
+                self.instruction(LocalSet(cur_dst_ptr));
+            }
+
+            // Update the remaining count, branching back to the loop head if
+            // it's nonzero and falling through to break out if it's zero now.
+            self.instruction(LocalGet(remaining));
+            self.instruction(iconst(-1, src_opts.ptr()));
+            self.instruction(src_add.clone());
+            self.instruction(LocalTee(remaining));
+            self.instruction(src_eqz.clone());
+            self.instruction(I32Eqz); // negate: continue while `remaining != 0`
+            self.instruction(BrIf(0));
+            self.instruction(End); // end of loop
+            self.instruction(End); // end of block
+        }
+
+        // Store the ptr/length in the desired destination
+        match dst {
+            Destination::Stack(s, _) => {
+                self.instruction(LocalGet(dst_mem.addr_local));
+                self.stack_set(&s[..1], dst_opts.ptr());
+                self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr());
+                self.stack_set(&s[1..], dst_opts.ptr());
+            }
+            Destination::Memory(mem) => {
+                self.instruction(LocalGet(mem.addr_local));
+                self.instruction(LocalGet(dst_mem.addr_local));
+                self.ptr_store(mem);
+                self.instruction(LocalGet(mem.addr_local));
+                self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr());
+                self.ptr_store(&mem.bump(dst_opts.ptr_size().into()));
+            }
+        }
+    }
+
+    fn calculate_dst_byte_len(
+        &mut self,
+        src_len_local: u32,
+        dst_len_local: u32,
+        src_ptr_ty: ValType,
+        dst_ptr_ty: ValType,
+        dst_elt_size: usize,
+    ) {
+        // Zero-size types are easy to handle here because the byte size of the
+        // destination is always zero.
+        if dst_elt_size == 0 {
+            if dst_ptr_ty == ValType::I64 {
+                self.instruction(I64Const(0));
+            } else {
+                self.instruction(I32Const(0));
+            }
+            self.instruction(LocalSet(dst_len_local));
+            return;
+        }
+
+        // For one-byte elements in the destination the check here can be a bit
+        // more optimal than the general case below. In these situations if the
+        // source pointer type is 32-bit then we're guaranteed to not overflow,
+        // so the source length is simply cast to the destination's type.
+        //
+        // If the source is 64-bit then all that needs to be checked is to
+        // ensure that it does not have the upper 32-bits set.
+        if dst_elt_size == 1 {
+            if let ValType::I64 = src_ptr_ty {
+                self.instruction(LocalGet(src_len_local));
+                self.instruction(I64Const(32));
+                self.instruction(I64ShrU);
+                self.instruction(I32WrapI64);
+                self.instruction(If(BlockType::Empty));
+                self.trap(Trap::ListByteLengthOverflow);
+                self.instruction(End);
+            }
+            self.convert_src_len_to_dst(src_len_local, src_ptr_ty, dst_ptr_ty);
+            self.instruction(LocalSet(dst_len_local));
+            return;
+        }
+
+        // The main check implemented by this function is to verify that the
+        // byte length computed from `src_len_local` fits in the 32-bit range.
+        // Byte sizes for lists must always fit in 32-bits to get transferred
+        // to 32-bit memories.
+        self.instruction(Block(BlockType::Empty));
+        self.instruction(Block(BlockType::Empty));
+        self.instruction(LocalGet(src_len_local));
+        match src_ptr_ty {
+            // The source's list length is guaranteed to fit in 32-bits,
+            // so simply extend it up to a 64-bit type for the
+            // multiplication below.
+            ValType::I32 => self.instruction(I64ExtendI32U),
+
+            // If the source is a 64-bit memory then if the item length doesn't
+            // fit in 32-bits the byte length definitely won't, so generate a
+            // branch to our overflow trap here if any of the upper 32-bits are set.
+            ValType::I64 => {
+                self.instruction(I64Const(32));
+                self.instruction(I64ShrU);
+                self.instruction(I32WrapI64);
+                self.instruction(BrIf(0));
+                self.instruction(LocalGet(src_len_local));
+            }
+
+            _ => unreachable!(),
+        }
+
+        // Next perform a 64-bit multiplication with the element byte size that
+        // is itself guaranteed to fit in 32-bits. The result is then checked
+        // to see if we overflowed the 32-bit space. The two input operands to
+        // the multiplication are guaranteed to be 32-bits at most which means
+        // that the 64-bit multiplication itself cannot overflow.
+        //
+        // The result of the multiplication is also saved into a local so it
+        // can be read back afterwards.
+        let tmp = if dst_ptr_ty != ValType::I64 {
+            self.gen_local(ValType::I64)
+        } else {
+            dst_len_local
+        };
+        self.instruction(I64Const(u32::try_from(dst_elt_size).unwrap().into()));
+        self.instruction(I64Mul);
+        self.instruction(LocalTee(tmp));
+        // Branch to success if the upper 32-bits are zero, otherwise
+        // fall-through to the trap.
+        self.instruction(I64Const(32));
+        self.instruction(I64ShrU);
+        self.instruction(I64Eqz);
+        self.instruction(BrIf(1));
+        self.instruction(End);
+        self.trap(Trap::ListByteLengthOverflow);
+        self.instruction(End);
+
+        // If a fresh local was used to store the result of the multiplication
+        // then convert it down to 32-bits which should be guaranteed to not
+        // lose information at this point.
+        if dst_ptr_ty != ValType::I64 {
+            self.instruction(LocalGet(tmp));
+            self.instruction(I32WrapI64);
+            self.instruction(LocalSet(dst_len_local));
+        }
+    }
+
+    fn convert_src_len_to_dst(
+        &mut self,
+        src_len_local: u32,
+        src_ptr_ty: ValType,
+        dst_ptr_ty: ValType,
+    ) {
+        self.instruction(LocalGet(src_len_local));
+        match (src_ptr_ty, dst_ptr_ty) {
+            (ValType::I32, ValType::I64) => self.instruction(I64ExtendI32U),
+            (ValType::I64, ValType::I32) => self.instruction(I32WrapI64),
+            (src, dst) => assert_eq!(src, dst),
+        }
+    }
@@ -938,7 +1338,7 @@ impl Compiler<'_, '_> {
         // The outermost block is special since it has the result type of the
         // translation here. That will depend on the `dst`.
let outer_block_ty = match dst { - Destination::Stack(dst_flat) => match dst_flat.len() { + Destination::Stack(dst_flat, _) => match dst_flat.len() { 0 => BlockType::Empty, 1 => BlockType::Result(dst_flat[0]), _ => { @@ -1004,7 +1404,7 @@ impl Compiler<'_, '_> { self.push_dst_addr(dst); self.instruction(I32Const(dst_i as i32)); match dst { - Destination::Stack(stack) => self.stack_set(&stack[..1], ValType::I32), + Destination::Stack(stack, _) => self.stack_set(&stack[..1], ValType::I32), Destination::Memory(mem) => match dst_disc_size { DiscriminantSize::Size1 => self.i32_store8(mem), DiscriminantSize::Size2 => self.i32_store16(mem), @@ -1024,8 +1424,8 @@ impl Compiler<'_, '_> { // variant. That's handled here by pushing remaining zeros after // accounting for the discriminant pushed as well as the results of // this individual payload. - if let Destination::Stack(payload_results) = dst_payload { - if let Destination::Stack(dst_results) = dst { + if let Destination::Stack(payload_results, _) = dst_payload { + if let Destination::Stack(dst_results, _) = dst { let remaining = &dst_results[1..][payload_results.len()..]; for ty in remaining { match ty { @@ -1087,7 +1487,7 @@ impl Compiler<'_, '_> { } self.instruction(LocalGet(memory.addr_local)); assert!(align.is_power_of_two()); - if memory.memory64 { + if memory.opts.memory64 { let mask = i64::try_from(align - 1).unwrap(); self.instruction(I64Const(mask)); self.instruction(I64And); @@ -1107,13 +1507,13 @@ impl Compiler<'_, '_> { if !self.module.debug { return; } - let align = self.module.align(ty); + let align = self.module.align(mem.opts, ty); if align == 1 { return; } assert!(align.is_power_of_two()); self.instruction(LocalGet(mem.addr_local)); - if mem.memory64 { + if mem.opts.memory64 { self.instruction(I64Const(i64::from(mem.offset))); self.instruction(I64Add); let mask = i64::try_from(align - 1).unwrap(); @@ -1133,32 +1533,41 @@ impl Compiler<'_, '_> { self.instruction(End); } - fn malloc(&mut self, opts: &Options, size: usize, align: usize) -> Memory { + fn malloc<'a>(&mut self, opts: &'a Options, size: MallocSize, align: usize) -> Memory<'a> { let addr_local = self.gen_local(opts.ptr()); let realloc = opts.realloc.unwrap(); if opts.memory64 { self.instruction(I64Const(0)); self.instruction(I64Const(0)); self.instruction(I64Const(i64::try_from(align).unwrap())); - self.instruction(I64Const(i64::try_from(size).unwrap())); + match size { + MallocSize::Const(size) => self.instruction(I64Const(i64::try_from(size).unwrap())), + MallocSize::Local(idx) => self.instruction(LocalGet(idx)), + } } else { self.instruction(I32Const(0)); self.instruction(I32Const(0)); self.instruction(I32Const(i32::try_from(align).unwrap())); - self.instruction(I32Const(i32::try_from(size).unwrap())); + match size { + MallocSize::Const(size) => self.instruction(I32Const(i32::try_from(size).unwrap())), + MallocSize::Local(idx) => self.instruction(LocalGet(idx)), + } } self.instruction(Call(realloc.as_u32())); self.instruction(LocalSet(addr_local)); self.memory_operand(opts, addr_local, align) } - fn memory_operand(&mut self, opts: &Options, addr_local: u32, align: usize) -> Memory { - let memory = opts.memory.unwrap(); + fn memory_operand<'a>( + &mut self, + opts: &'a Options, + addr_local: u32, + align: usize, + ) -> Memory<'a> { let ret = Memory { - memory64: opts.memory64, addr_local, offset: 0, - memory_idx: memory.as_u32(), + opts, }; self.verify_aligned(&ret, align); ret @@ -1328,6 +1737,14 @@ impl Compiler<'_, '_> { self.instruction(I64Load(mem.memarg(3))); } 
+ fn ptr_load(&mut self, mem: &Memory) { + if mem.opts.memory64 { + self.i64_load(mem); + } else { + self.i32_load(mem); + } + } + fn f32_load(&mut self, mem: &Memory) { self.instruction(LocalGet(mem.addr_local)); self.instruction(F32Load(mem.memarg(2))); @@ -1360,6 +1777,14 @@ impl Compiler<'_, '_> { self.instruction(I64Store(mem.memarg(3))); } + fn ptr_store(&mut self, mem: &Memory) { + if mem.opts.memory64 { + self.i64_store(mem); + } else { + self.i32_store(mem); + } + } + fn f32_store(&mut self, mem: &Memory) { self.instruction(F32Store(mem.memarg(2))); } @@ -1391,7 +1816,7 @@ impl<'a> Source<'a> { Source::Memory(mem) } Source::Stack(stack) => { - let cnt = module.flatten_types([ty]).len(); + let cnt = module.flatten_types(stack.opts, [ty]).len(); offset += cnt; Source::Stack(stack.slice(offset - cnt..offset)) } @@ -1407,7 +1832,7 @@ impl<'a> Source<'a> { ) -> Source<'a> { match self { Source::Stack(s) => { - let flat_len = module.flatten_types([*case]).len(); + let flat_len = module.flatten_types(s.opts, [*case]).len(); Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len)) } Source::Memory(mem) => { @@ -1416,6 +1841,13 @@ impl<'a> Source<'a> { } } } + + fn opts(&self) -> &'a Options { + match self { + Source::Stack(s) => s.opts, + Source::Memory(mem) => mem.opts, + } + } } impl<'a> Destination<'a> { @@ -1434,10 +1866,10 @@ impl<'a> Destination<'a> { let mem = next_field_offset(&mut offset, module, &ty, mem); Destination::Memory(mem) } - Destination::Stack(s) => { - let cnt = module.flatten_types([ty]).len(); + Destination::Stack(s, opts) => { + let cnt = module.flatten_types(opts, [ty]).len(); offset += cnt; - Destination::Stack(&s[offset - cnt..offset]) + Destination::Stack(&s[offset - cnt..offset], opts) } }) } @@ -1450,9 +1882,9 @@ impl<'a> Destination<'a> { case: &InterfaceType, ) -> Destination { match self { - Destination::Stack(s) => { - let flat_len = module.flatten_types([*case]).len(); - Destination::Stack(&s[1..][..flat_len]) + Destination::Stack(s, opts) => { + let flat_len = module.flatten_types(opts, [*case]).len(); + Destination::Stack(&s[1..][..flat_len], opts) } Destination::Memory(mem) => { let mem = payload_offset(size, module, case, mem); @@ -1460,30 +1892,37 @@ impl<'a> Destination<'a> { } } } + + fn opts(&self) -> &'a Options { + match self { + Destination::Stack(_, opts) => opts, + Destination::Memory(mem) => mem.opts, + } + } } -fn next_field_offset( +fn next_field_offset<'a>( offset: &mut usize, module: &Module, field: &InterfaceType, - mem: &Memory, -) -> Memory { - let (size, align) = module.size_align(field); + mem: &Memory<'a>, +) -> Memory<'a> { + let (size, align) = module.size_align(mem.opts, field); *offset = align_to(*offset, align) + size; mem.bump(*offset - size) } -fn payload_offset( +fn payload_offset<'a>( disc_size: DiscriminantSize, module: &Module, case: &InterfaceType, - mem: &Memory, -) -> Memory { - let align = module.align(case); + mem: &Memory<'a>, +) -> Memory<'a> { + let align = module.align(mem.opts, case); mem.bump(align_to(disc_size.into(), align)) } -impl Memory { +impl<'a> Memory<'a> { fn i32_offset(&self) -> i32 { self.offset as i32 } @@ -1492,15 +1931,14 @@ impl Memory { MemArg { offset: u64::from(self.offset), align, - memory_index: self.memory_idx, + memory_index: self.opts.memory.unwrap().as_u32(), } } - fn bump(&self, offset: usize) -> Memory { + fn bump(&self, offset: usize) -> Memory<'a> { Memory { - memory64: self.memory64, + opts: self.opts, addr_local: self.addr_local, - memory_idx: self.memory_idx, offset: 
 self.offset + u32::try_from(offset).unwrap(),
         }
     }
 }
@@ -1510,6 +1948,7 @@ impl<'a> Stack<'a> {
     fn slice(&self, range: Range<usize>) -> Stack<'a> {
         Stack {
             locals: &self.locals[range],
+            opts: self.opts,
         }
     }
 }
@@ -1520,3 +1959,8 @@ struct VariantCase<'a> {
     dst_i: u32,
     dst_ty: &'a InterfaceType,
 }
+
+enum MallocSize {
+    Const(usize),
+    Local(u32),
+}
diff --git a/crates/environ/src/fact/traps.rs b/crates/environ/src/fact/traps.rs
index ccccb569ca..393194f101 100644
--- a/crates/environ/src/fact/traps.rs
+++ b/crates/environ/src/fact/traps.rs
@@ -29,6 +29,7 @@ pub enum Trap {
     UnalignedPointer,
     InvalidDiscriminant,
     InvalidChar,
+    ListByteLengthOverflow,
     AssertFailed(&'static str),
 }
@@ -103,6 +104,7 @@ impl fmt::Display for Trap {
             Trap::UnalignedPointer => "pointer not aligned correctly".fmt(f),
             Trap::InvalidDiscriminant => "invalid variant discriminant".fmt(f),
             Trap::InvalidChar => "invalid char value specified".fmt(f),
+            Trap::ListByteLengthOverflow => "byte size of list too large for i32".fmt(f),
             Trap::AssertFailed(s) => write!(f, "assertion failure: {}", s),
         }
     }
diff --git a/crates/wasmtime/src/component/func/options.rs b/crates/wasmtime/src/component/func/options.rs
index f2df2bba5a..af939a1c9b 100644
--- a/crates/wasmtime/src/component/func/options.rs
+++ b/crates/wasmtime/src/component/func/options.rs
@@ -102,9 +102,13 @@ impl Options {
         let memory = self.memory_mut(store.0);
-        let result_slice = match memory.get_mut(result..).and_then(|s| s.get_mut(..new_size)) {
-            Some(end) => end,
-            None => bail!("realloc return: beyond end of memory"),
+        let result_slice = if new_size == 0 {
+            &mut []
+        } else {
+            match memory.get_mut(result..).and_then(|s| s.get_mut(..new_size)) {
+                Some(end) => end,
+                None => bail!("realloc return: beyond end of memory"),
+            }
         };
         Ok((result_slice, result))
diff --git a/crates/wasmtime/src/component/func/typed.rs b/crates/wasmtime/src/component/func/typed.rs
index 6e93f36d7d..9ac129444c 100644
--- a/crates/wasmtime/src/component/func/typed.rs
+++ b/crates/wasmtime/src/component/func/typed.rs
@@ -846,22 +846,26 @@ fn lower_string<T>(mem: &mut MemoryMut<'_, T>, string: &str) -> Result<(usize, usize)> {
     match mem.string_encoding() {
         StringEncoding::Utf8 => {
             let ptr = mem.realloc(0, 0, 1, string.len())?;
-            mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes());
+            if string.len() > 0 {
+                mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes());
+            }
             Ok((ptr, string.len()))
         }
         StringEncoding::Utf16 => {
             let size = string.len() * 2;
             let mut ptr = mem.realloc(0, 0, 2, size)?;
-            let bytes = &mut mem.as_slice_mut()[ptr..][..size];
             let mut copied = 0;
-            for (u, bytes) in string.encode_utf16().zip(bytes.chunks_mut(2)) {
-                let u_bytes = u.to_le_bytes();
-                bytes[0] = u_bytes[0];
-                bytes[1] = u_bytes[1];
-                copied += 1;
-            }
-            if (copied * 2) < size {
-                ptr = mem.realloc(ptr, size, 2, copied * 2)?;
+            if size > 0 {
+                let bytes = &mut mem.as_slice_mut()[ptr..][..size];
+                for (u, bytes) in string.encode_utf16().zip(bytes.chunks_mut(2)) {
+                    let u_bytes = u.to_le_bytes();
+                    bytes[0] = u_bytes[0];
+                    bytes[1] = u_bytes[1];
+                    copied += 1;
+                }
+                if (copied * 2) < size {
+                    ptr = mem.realloc(ptr, size, 2, copied * 2)?;
+                }
             }
             Ok((ptr, copied))
         }
diff --git a/tests/all/component_model/func.rs b/tests/all/component_model/func.rs
index 322d7a1c56..aa73a19ba9 100644
--- a/tests/all/component_model/func.rs
+++ b/tests/all/component_model/func.rs
@@ -1127,21 +1127,19 @@ fn some_traps() -> Result<()> {
             err,
         );
     }
-    let err = instance(&mut store)?
+    instance(&mut store)?
         .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")?
         .call(&mut store, (&[],))
-        .unwrap_err();
-    assert_oob(&err);
+        .unwrap();
     let err = instance(&mut store)?
         .get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")?
         .call(&mut store, (&[1],))
         .unwrap_err();
     assert_oob(&err);
-    let err = instance(&mut store)?
+    instance(&mut store)?
         .get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")?
         .call(&mut store, ("",))
-        .unwrap_err();
-    assert_oob(&err);
+        .unwrap();
     let err = instance(&mut store)?
         .get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")?
         .call(&mut store, ("x",))
@@ -1193,12 +1191,19 @@
     // For this function the first allocation, the space to store all the
     // arguments, is in-bounds but then all further allocations, such as for
    // each individual string, are all out of bounds.
-    let err = instance(&mut store)?
+    instance(&mut store)?
         .get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>(
             &mut store,
             "take-many-second-oob",
         )?
         .call(&mut store, ("", "", "", "", "", "", "", "", "", ""))
+        .unwrap();
+    let err = instance(&mut store)?
+        .get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>(
+            &mut store,
+            "take-many-second-oob",
+        )?
+        .call(&mut store, ("", "", "", "", "", "", "", "", "", "x"))
         .unwrap_err();
     assert_oob(&err);
     Ok(())
diff --git a/tests/misc_testsuite/component-model/adapter.wast b/tests/misc_testsuite/component-model/adapter.wast
index 0f2936951b..96eaf258b2 100644
--- a/tests/misc_testsuite/component-model/adapter.wast
+++ b/tests/misc_testsuite/component-model/adapter.wast
@@ -111,3 +111,23 @@
     ))
   )
   "degenerate component adapter called")
+
+;; fiddling with 0-sized lists
+(component $c
+  (core module $m
+    (func (export "x") (param i32 i32))
+    (func (export "realloc") (param i32 i32 i32 i32) (result i32)
+      i32.const -1)
+    (memory (export "memory") 0)
+  )
+  (core instance $m (instantiate $m))
+  (func $f (param (list unit))
+    (canon lift
+      (core func $m "x")
+      (realloc (func $m "realloc"))
+      (memory $m "memory")
+    )
+  )
+  (export "empty-list" (func $f))
+)
+(assert_return (invoke "empty-list" (list.const)) (unit.const))
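A hypothetical invocation of the `factc` example added above, for reference (assuming the Wasmtime workspace root as the working directory; `adapter.wat` is a stand-in name for any file containing a WebAssembly component, and `--features component-model` mirrors the `required-features` entry in `crates/environ/Cargo.toml`):

    $ cargo run -p wasmtime-environ --example factc --features component-model -- --text adapter.wat

This parses `adapter.wat`, generates one adapter per component function type it finds, validates the resulting core module (unless `--skip-validate` is passed), and prints it in the WebAssembly text format; without `--text` (`-t`), binary output is refused when standard output is a terminal.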