Implement fused adapters for (list T) types (#4558)

* Implement fused adapters for `(list T)` types

This commit implements one of the two remaining types for adapter
fusion: lists. The implementation is particularly tricky for a number
of reasons:

* Lists have a number of validity checks which need to be carefully
  implemented. For example, the byte length of the list passed to
  allocation in the destination module could overflow the 32-bit index
  space. Additionally, lists in 32-bit memories need a check that their
  final address is in-bounds in the address space.

* In the effort to support memory64 at the lowest layers, this is where
  much of the magic happens. Lists are naturally always stored in
  memory, and shifting between 64-bit and 32-bit address spaces is done
  here. This notably required plumbing an `Options` value through the
  flattening/size/alignment calculations, since the sizes and types of
  lists change depending on the memory configuration.
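As an illustration of the overflow concern above, here is a minimal sketch (plain Rust with hypothetical names, not the actual FACT codegen, which emits wasm instructions) of the check the generated adapter must perform when computing the destination allocation size:

```rust
/// Hypothetical model of the destination byte-length computation: the
/// element count from the source list is multiplied by the destination
/// element size, and the product must stay within the 32-bit index
/// space even when the destination uses a 64-bit memory.
fn dst_byte_len(src_len: u32, dst_elem_size: u32) -> Option<u32> {
    // `checked_mul` stands in for the trap-on-overflow check that the
    // generated adapter inserts before calling `realloc`.
    src_len.checked_mul(dst_elem_size)
}

fn main() {
    assert_eq!(dst_byte_len(10, 4), Some(40));
    // 2^30 elements of an 8-byte type would need 2^33 bytes: overflow.
    assert_eq!(dst_byte_len(1u32 << 30, 8), None);
}
```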

I've also added a small `factc` program in this commit which should
hopefully assist in exploring and debugging adapter modules. It takes a
component (text or binary format) as input and then generates an
adapter module for all component function signatures found internally.

This commit notably does not include tests for lists. I tried to figure
out a good way to add these, but there were too many cases to test and
the tests would otherwise have been extremely verbose. Instead I think
the best testing strategy for this commit will be through #4537, which
should be relatively extensible to testing adapters between modules in
addition to host-based lifting/lowering.

* Improve handling of lists of 0-size types

* Skip overflow checks on byte sizes for 0-size types
* Skip the copy loop entirely when src/dst are both 0
* Skip the increments of src/dst pointers if either is 0-size
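A minimal sketch of these special cases (a hypothetical helper, not the real code generator, which emits a wasm loop rather than running one in Rust):

```rust
/// Models how the generated copy loop treats zero-size element types.
/// Returns how far the src/dst pointers advanced in total.
fn copy_list(len: u32, src_size: u32, dst_size: u32) -> (u32, u32) {
    // When both element types are zero-sized there is nothing to
    // translate, so the entire loop is skipped.
    if src_size == 0 && dst_size == 0 {
        return (0, 0);
    }
    let mut src_off = 0;
    let mut dst_off = 0;
    for _ in 0..len {
        // ... per-element translation would happen here ...
        // A zero-size side never advances its pointer.
        if src_size != 0 {
            src_off += src_size;
        }
        if dst_size != 0 {
            dst_off += dst_size;
        }
    }
    (src_off, dst_off)
}

fn main() {
    assert_eq!(copy_list(5, 4, 8), (20, 40));
    assert_eq!(copy_list(5, 0, 0), (0, 0));
    assert_eq!(copy_list(5, 0, 4), (0, 20));
}
```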

* Update semantics for zero-sized lists/strings

When a list/string has a byte size of 0, the base pointer is no longer
verified to be in-bounds. This matches the intended adapter semantics:
no trap happens because no turn of the copy loop happens.
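A rough model of this bounds-check rule (assuming a flat linear-memory size in bytes; names are illustrative, not the adapter's actual helpers):

```rust
/// The base pointer of a list/string is only bounds-checked when its
/// byte size is non-zero, since a zero-size copy performs no loads or
/// stores at all.
fn base_in_bounds(ptr: u64, byte_len: u64, mem_bytes: u64) -> bool {
    if byte_len == 0 {
        // No trap: no turn of the loop happens.
        return true;
    }
    // Otherwise the whole range [ptr, ptr + byte_len) must fit.
    ptr.checked_add(byte_len)
        .map_or(false, |end| end <= mem_bytes)
}

fn main() {
    // A wildly out-of-bounds base pointer is fine for a 0-byte list.
    assert!(base_in_bounds(u64::MAX, 0, 65536));
    assert!(base_in_bounds(65532, 4, 65536));
    assert!(!base_in_bounds(65533, 4, 65536));
}
```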
Author: Alex Crichton, 2022-08-01 17:02:08 -05:00 (committed by GitHub)
parent 586ec95c11, commit fb59de15af
12 changed files with 824 additions and 118 deletions

Cargo.lock (generated)

@@ -3496,7 +3496,10 @@ name = "wasmtime-environ"
 version = "0.40.0"
 dependencies = [
  "anyhow",
+ "atty",
+ "clap 3.2.8",
  "cranelift-entity",
+ "env_logger 0.9.0",
  "gimli",
  "indexmap",
  "log",
@@ -3509,6 +3512,7 @@ dependencies = [
  "wasmprinter",
  "wasmtime-component-util",
  "wasmtime-types",
+ "wat",
 ]

 [[package]]


@@ -26,6 +26,16 @@ wasm-encoder = { version = "0.15.0", optional = true }
 wasmprinter = { version = "0.2.38", optional = true }
 wasmtime-component-util = { path = "../component-util", version = "=0.40.0", optional = true }
+
+[dev-dependencies]
+atty = "0.2.14"
+clap = { version = "3.2.8", features = ['derive'] }
+env_logger = "0.9.0"
+wat = "1.0.47"
+
+[[example]]
+name = "factc"
+required-features = ['component-model']

 [badges]
 maintenance = { status = "actively-developed" }


@@ -0,0 +1,193 @@
use anyhow::{bail, Context, Result};
use clap::Parser;
use std::path::PathBuf;
use std::str;
use wasmparser::{Payload, Validator, WasmFeatures};
use wasmtime_environ::component::*;
use wasmtime_environ::fact::Module;

/// A small helper utility to explore generated adapter modules from Wasmtime's
/// adapter fusion compiler.
///
/// This utility takes a `*.wat` file as input which is expected to be a valid
/// WebAssembly component. The component is parsed and any type definition for a
/// component function gets a generated adapter for it as if the caller/callee
/// used that type as the adapter.
///
/// For example with an input that looks like:
///
/// (component
///     (type (func (param u32) (result (list u8))))
/// )
///
/// This tool can be used to generate an adapter for that signature.
#[derive(Parser)]
struct Factc {
    /// Whether or not debug code is inserted into the generated adapter.
    #[clap(long)]
    debug: bool,

    /// Whether or not the lifting options (the callee of the exported adapter)
    /// uses a 64-bit memory as opposed to a 32-bit memory.
    #[clap(long)]
    lift64: bool,

    /// Whether or not the lowering options (the caller of the exported adapter)
    /// uses a 64-bit memory as opposed to a 32-bit memory.
    #[clap(long)]
    lower64: bool,

    /// Whether or not a call to a `post-return` configured function is enabled
    /// or not.
    #[clap(long)]
    post_return: bool,

    /// Whether or not to skip validation of the generated adapter module.
    #[clap(long)]
    skip_validate: bool,

    /// Where to place the generated adapter module. Standard output is used if
    /// this is not specified.
    #[clap(short, long)]
    output: Option<PathBuf>,

    /// Output the text format for WebAssembly instead of the binary format.
    #[clap(short, long)]
    text: bool,

    /// TODO
    input: PathBuf,
}

fn main() -> Result<()> {
    Factc::parse().execute()
}

impl Factc {
    fn execute(self) -> Result<()> {
        env_logger::init();

        let mut types = ComponentTypesBuilder::default();

        // Manufactures a unique `CoreDef` so all function imports get unique
        // function imports.
        let mut next_def = 0;
        let mut dummy_def = || {
            next_def += 1;
            CoreDef::Adapter(AdapterIndex::from_u32(next_def))
        };

        // Manufactures a `CoreExport` for a memory with the shape specified. Note
        // that we can't import as many memories as functions so these are
        // intentionally limited. Once a handful of memories are generated of each
        // type then they start getting reused.
        let mut next_memory = 0;
        let mut memories32 = Vec::new();
        let mut memories64 = Vec::new();
        let mut dummy_memory = |memory64: bool| {
            let dst = if memory64 {
                &mut memories64
            } else {
                &mut memories32
            };
            let idx = if dst.len() < 5 {
                next_memory += 1;
                dst.push(next_memory - 1);
                next_memory - 1
            } else {
                dst[0]
            };
            CoreExport {
                instance: RuntimeInstanceIndex::from_u32(idx),
                item: ExportItem::Name(String::new()),
            }
        };

        let mut adapters = Vec::new();
        let input = wat::parse_file(&self.input)?;
        types.push_type_scope();
        let mut validator = Validator::new_with_features(WasmFeatures {
            component_model: true,
            ..Default::default()
        });
        for payload in wasmparser::Parser::new(0).parse_all(&input) {
            let payload = payload?;
            validator.payload(&payload)?;
            let section = match payload {
                Payload::ComponentTypeSection(s) => s,
                _ => continue,
            };
            for ty in section {
                let ty = types.intern_component_type(&ty?)?;
                types.push_component_typedef(ty);
                let ty = match ty {
                    TypeDef::ComponentFunc(ty) => ty,
                    _ => continue,
                };
                adapters.push(Adapter {
                    lift_ty: ty,
                    lower_ty: ty,
                    lower_options: AdapterOptions {
                        instance: RuntimeComponentInstanceIndex::from_u32(0),
                        string_encoding: StringEncoding::Utf8,
                        memory64: self.lower64,
                        // Pessimistically assume that memory/realloc are going to be
                        // required for this trampoline and provide it. Avoids doing
                        // calculations to figure out whether they're necessary and
                        // simplifies the fuzzer here without reducing coverage within FACT
                        // itself.
                        memory: Some(dummy_memory(self.lower64)),
                        realloc: Some(dummy_def()),
                        // Lowering never allows `post-return`
                        post_return: None,
                    },
                    lift_options: AdapterOptions {
                        instance: RuntimeComponentInstanceIndex::from_u32(1),
                        string_encoding: StringEncoding::Utf8,
                        memory64: self.lift64,
                        memory: Some(dummy_memory(self.lift64)),
                        realloc: Some(dummy_def()),
                        post_return: if self.post_return {
                            Some(dummy_def())
                        } else {
                            None
                        },
                    },
                    func: dummy_def(),
                });
            }
        }
        types.pop_type_scope();
        let types = types.finish();

        let mut fact_module = Module::new(&types, self.debug);
        for (i, adapter) in adapters.iter().enumerate() {
            fact_module.adapt(&format!("adapter{i}"), adapter);
        }
        let wasm = fact_module.encode();

        Validator::new_with_features(WasmFeatures {
            multi_memory: true,
            memory64: true,
            ..WasmFeatures::default()
        })
        .validate_all(&wasm)
        .context("failed to validate generated module")?;

        let output = if self.text {
            wasmprinter::print_bytes(&wasm)
                .context("failed to convert binary wasm to text")?
                .into_bytes()
        } else if self.output.is_none() && atty::is(atty::Stream::Stdout) {
            bail!("cannot print binary wasm output to a terminal unless `-t` flag is passed")
        } else {
            wasm
        };

        match &self.output {
            Some(file) => std::fs::write(file, output).context("failed to write output file")?,
            None => println!("{}", str::from_utf8(&output).unwrap()),
        }

        Ok(())
    }
}


@@ -52,6 +52,7 @@ enum ValType {
     Float32,
     Float64,
     Char,
+    List(Box<ValType>),
     Record(Vec<ValType>),
     // Up to 65 flags to exercise up to 3 u32 values
     Flags(UsizeInRange<0, 65>),
@@ -230,6 +231,10 @@ fn intern(types: &mut ComponentTypesBuilder, ty: &ValType) -> InterfaceType {
         ValType::Float32 => InterfaceType::Float32,
         ValType::Float64 => InterfaceType::Float64,
         ValType::Char => InterfaceType::Char,
+        ValType::List(ty) => {
+            let ty = intern(types, ty);
+            InterfaceType::List(types.add_interface_type(ty))
+        }
         ValType::Record(tys) => {
             let ty = TypeRecord {
                 fields: tys


@@ -304,4 +304,12 @@ impl Options {
             ValType::I32
         }
     }
+
+    fn ptr_size(&self) -> u8 {
+        if self.memory64 {
+            8
+        } else {
+            4
+        }
+    }
 }
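Restated as a self-contained sketch (with `Options` pared down to the one field used here), the pointer-size rule above also determines the size/alignment of a list or string, which is stored as a (pointer, length) pair:

```rust
struct Options {
    memory64: bool,
}

impl Options {
    /// Pointer width in bytes for this memory configuration.
    fn ptr_size(&self) -> u8 {
        if self.memory64 {
            8
        } else {
            4
        }
    }
}

/// A list/string occupies two pointers and is aligned to one pointer,
/// mirroring the `size_align` change for `String`/`List` in this commit.
fn list_size_align(opts: &Options) -> (usize, usize) {
    ((2 * opts.ptr_size()).into(), opts.ptr_size().into())
}

fn main() {
    assert_eq!(list_size_align(&Options { memory64: false }), (8, 4));
    assert_eq!(list_size_align(&Options { memory64: true }), (16, 8));
}
```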


@@ -38,14 +38,14 @@ impl Module<'_> {
         let ty = &self.types[options.ty];
         let ptr_ty = options.ptr();
-        let mut params = self.flatten_types(ty.params.iter().map(|(_, ty)| *ty));
+        let mut params = self.flatten_types(options, ty.params.iter().map(|(_, ty)| *ty));
         let mut params_indirect = false;
         if params.len() > MAX_FLAT_PARAMS {
             params = vec![ptr_ty];
             params_indirect = true;
         }
-        let mut results = self.flatten_types([ty.result]);
+        let mut results = self.flatten_types(options, [ty.result]);
         let mut results_indirect = false;
         if results.len() > MAX_FLAT_RESULTS {
             results_indirect = true;
@@ -73,18 +73,19 @@ impl Module<'_> {
     /// Pushes the flat version of a list of component types into a final result
     /// list.
-    pub(crate) fn flatten_types(
+    pub(super) fn flatten_types(
         &self,
+        opts: &Options,
         tys: impl IntoIterator<Item = InterfaceType>,
     ) -> Vec<ValType> {
         let mut result = Vec::new();
         for ty in tys {
-            self.push_flat(&ty, &mut result);
+            self.push_flat(opts, &ty, &mut result);
         }
         result
     }

-    fn push_flat(&self, ty: &InterfaceType, dst: &mut Vec<ValType>) {
+    fn push_flat(&self, opts: &Options, ty: &InterfaceType, dst: &mut Vec<ValType>) {
         match ty {
             InterfaceType::Unit => {}
@@ -103,17 +104,17 @@ impl Module<'_> {
             InterfaceType::Float64 => dst.push(ValType::F64),
             InterfaceType::String | InterfaceType::List(_) => {
-                dst.push(ValType::I32);
-                dst.push(ValType::I32);
+                dst.push(opts.ptr());
+                dst.push(opts.ptr());
             }
             InterfaceType::Record(r) => {
                 for field in self.types[*r].fields.iter() {
-                    self.push_flat(&field.ty, dst);
+                    self.push_flat(opts, &field.ty, dst);
                 }
             }
             InterfaceType::Tuple(t) => {
                 for ty in self.types[*t].types.iter() {
-                    self.push_flat(ty, dst);
+                    self.push_flat(opts, ty, dst);
                 }
             }
             InterfaceType::Flags(f) => {
@@ -126,14 +127,14 @@ impl Module<'_> {
             InterfaceType::Enum(_) => dst.push(ValType::I32),
             InterfaceType::Option(t) => {
                 dst.push(ValType::I32);
-                self.push_flat(&self.types[*t], dst);
+                self.push_flat(opts, &self.types[*t], dst);
             }
             InterfaceType::Variant(t) => {
                 dst.push(ValType::I32);
                 let pos = dst.len();
                 let mut tmp = Vec::new();
                 for case in self.types[*t].cases.iter() {
-                    self.push_flat_variant(&case.ty, pos, &mut tmp, dst);
+                    self.push_flat_variant(opts, &case.ty, pos, &mut tmp, dst);
                 }
             }
             InterfaceType::Union(t) => {
@@ -141,7 +142,7 @@ impl Module<'_> {
                 let pos = dst.len();
                 let mut tmp = Vec::new();
                 for ty in self.types[*t].types.iter() {
-                    self.push_flat_variant(ty, pos, &mut tmp, dst);
+                    self.push_flat_variant(opts, ty, pos, &mut tmp, dst);
                 }
             }
             InterfaceType::Expected(t) => {
@@ -149,21 +150,22 @@ impl Module<'_> {
                 let e = &self.types[*t];
                 let pos = dst.len();
                 let mut tmp = Vec::new();
-                self.push_flat_variant(&e.ok, pos, &mut tmp, dst);
-                self.push_flat_variant(&e.err, pos, &mut tmp, dst);
+                self.push_flat_variant(opts, &e.ok, pos, &mut tmp, dst);
+                self.push_flat_variant(opts, &e.err, pos, &mut tmp, dst);
             }
         }
     }

     fn push_flat_variant(
         &self,
+        opts: &Options,
         ty: &InterfaceType,
         pos: usize,
         tmp: &mut Vec<ValType>,
         dst: &mut Vec<ValType>,
     ) {
         tmp.truncate(0);
-        self.push_flat(ty, tmp);
+        self.push_flat(opts, ty, tmp);
         for (i, a) in tmp.iter().enumerate() {
             match dst.get_mut(pos + i) {
                 Some(b) => join(*a, b),
@@ -182,8 +184,8 @@ impl Module<'_> {
         }
     }

-    pub(crate) fn align(&self, ty: &InterfaceType) -> usize {
-        self.size_align(ty).1
+    pub(super) fn align(&self, opts: &Options, ty: &InterfaceType) -> usize {
+        self.size_align(opts, ty).1
     }

     /// Returns a (size, align) pair corresponding to the byte-size and
@@ -191,7 +193,7 @@ impl Module<'_> {
     //
     // TODO: this is probably inefficient to entire recalculate at all phases,
     // seems like it would be best to intern this in some sort of map somewhere.
-    pub(crate) fn size_align(&self, ty: &InterfaceType) -> (usize, usize) {
+    pub(super) fn size_align(&self, opts: &Options, ty: &InterfaceType) -> (usize, usize) {
         match ty {
             InterfaceType::Unit => (0, 1),
             InterfaceType::Bool | InterfaceType::S8 | InterfaceType::U8 => (1, 1),
@@ -201,12 +203,14 @@ impl Module<'_> {
             | InterfaceType::Char
             | InterfaceType::Float32 => (4, 4),
             InterfaceType::S64 | InterfaceType::U64 | InterfaceType::Float64 => (8, 8),
-            InterfaceType::String | InterfaceType::List(_) => (8, 4),
+            InterfaceType::String | InterfaceType::List(_) => {
+                ((2 * opts.ptr_size()).into(), opts.ptr_size().into())
+            }
             InterfaceType::Record(r) => {
-                self.record_size_align(self.types[*r].fields.iter().map(|f| &f.ty))
+                self.record_size_align(opts, self.types[*r].fields.iter().map(|f| &f.ty))
             }
-            InterfaceType::Tuple(t) => self.record_size_align(self.types[*t].types.iter()),
+            InterfaceType::Tuple(t) => self.record_size_align(opts, self.types[*t].types.iter()),
             InterfaceType::Flags(f) => match self.types[*f].names.len() {
                 n if n <= 8 => (1, 1),
                 n if n <= 16 => (2, 2),
@@ -216,27 +220,28 @@ impl Module<'_> {
             InterfaceType::Enum(t) => self.discrim_size_align(self.types[*t].names.len()),
             InterfaceType::Option(t) => {
                 let ty = &self.types[*t];
-                self.variant_size_align([&InterfaceType::Unit, ty].into_iter())
+                self.variant_size_align(opts, [&InterfaceType::Unit, ty].into_iter())
             }
             InterfaceType::Variant(t) => {
-                self.variant_size_align(self.types[*t].cases.iter().map(|c| &c.ty))
+                self.variant_size_align(opts, self.types[*t].cases.iter().map(|c| &c.ty))
             }
-            InterfaceType::Union(t) => self.variant_size_align(self.types[*t].types.iter()),
+            InterfaceType::Union(t) => self.variant_size_align(opts, self.types[*t].types.iter()),
             InterfaceType::Expected(t) => {
                 let e = &self.types[*t];
-                self.variant_size_align([&e.ok, &e.err].into_iter())
+                self.variant_size_align(opts, [&e.ok, &e.err].into_iter())
             }
         }
     }

-    pub(crate) fn record_size_align<'a>(
+    pub(super) fn record_size_align<'a>(
         &self,
+        opts: &Options,
         fields: impl Iterator<Item = &'a InterfaceType>,
     ) -> (usize, usize) {
         let mut size = 0;
         let mut align = 1;
         for ty in fields {
-            let (fsize, falign) = self.size_align(ty);
+            let (fsize, falign) = self.size_align(opts, ty);
             size = align_to(size, falign) + fsize;
             align = align.max(falign);
         }
@@ -245,12 +250,13 @@ impl Module<'_> {
     fn variant_size_align<'a>(
         &self,
+        opts: &Options,
         cases: impl ExactSizeIterator<Item = &'a InterfaceType>,
     ) -> (usize, usize) {
         let (discrim_size, mut align) = self.discrim_size_align(cases.len());
         let mut payload_size = 0;
         for ty in cases {
-            let (csize, calign) = self.size_align(ty);
+            let (csize, calign) = self.size_align(opts, ty);
             payload_size = payload_size.max(csize);
             align = align.max(calign);
         }
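The `record_size_align` loop above relies on an `align_to` helper; a standalone sketch with a worked example (taking precomputed (size, align) pairs instead of interface types, so this is illustrative rather than the real signature) might look like:

```rust
/// Round `offset` up to the next multiple of `align` (assumed to be a
/// power of two in practice).
fn align_to(offset: usize, align: usize) -> usize {
    (offset + align - 1) / align * align
}

/// Worked model of the record layout loop: each field is placed at its
/// own alignment, and the record's alignment is the max field alignment.
fn record_size_align(fields: &[(usize, usize)]) -> (usize, usize) {
    let mut size = 0;
    let mut align = 1;
    for &(fsize, falign) in fields {
        size = align_to(size, falign) + fsize;
        align = align.max(falign);
    }
    (size, align)
}

fn main() {
    assert_eq!(align_to(5, 4), 8);
    // A record of (u8, u32): the u32 lands at offset 4, so size 8, align 4.
    assert_eq!(record_size_align(&[(1, 1), (4, 4)]), (8, 4));
}
```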


@@ -101,7 +101,7 @@ enum Source<'a> {
     /// This value is stored in linear memory described by the `Memory`
     /// structure.
-    Memory(Memory),
+    Memory(Memory<'a>),
 }

 /// Same as `Source` but for where values are translated into.
@@ -111,10 +111,10 @@ enum Destination<'a> {
     ///
     /// The types listed are the types that are expected to be on the stack at
     /// the end of translation.
-    Stack(&'a [ValType]),
+    Stack(&'a [ValType], &'a Options),

     /// This value is to be placed in linear memory described by `Memory`.
-    Memory(Memory),
+    Memory(Memory<'a>),
 }

 struct Stack<'a> {
@@ -124,21 +124,20 @@ struct Stack<'a> {
     /// up the component value. Each list has the index of the local being
     /// accessed as well as the type of the local itself.
     locals: &'a [(u32, ValType)],
+
+    /// The lifting/lowering options for where this stack of values comes from
+    opts: &'a Options,
 }

 /// Representation of where a value is going to be stored in linear memory.
-struct Memory {
-    /// Whether or not the `addr_local` is a 64-bit type.
-    memory64: bool,
+struct Memory<'a> {
+    /// The lifting/lowering options with memory configuration
+    opts: &'a Options,

     /// The index of the local that contains the base address of where the
     /// storage is happening.
     addr_local: u32,

     /// A "static" offset that will be baked into wasm instructions for where
     /// memory loads/stores happen.
     offset: u32,
-
-    /// The index of memory in the wasm module memory index space that this
-    /// memory is referring to.
-    memory_idx: u32,
 }

 impl Compiler<'_, '_> {
@@ -231,12 +230,17 @@ impl Compiler<'_, '_> {
         // TODO: handle subtyping
         assert_eq!(src_tys.len(), dst_tys.len());

-        let src_flat = self.module.flatten_types(src_tys.iter().copied());
-        let dst_flat = self.module.flatten_types(dst_tys.iter().copied());
+        let src_flat = self
+            .module
+            .flatten_types(&self.adapter.lower, src_tys.iter().copied());
+        let dst_flat = self
+            .module
+            .flatten_types(&self.adapter.lift, dst_tys.iter().copied());

         let src = if src_flat.len() <= MAX_FLAT_PARAMS {
             Source::Stack(Stack {
                 locals: &param_locals[..src_flat.len()],
+                opts: &self.adapter.lower,
             })
         } else {
             // If there are too many parameters then that means the parameters
@@ -246,18 +250,21 @@ impl Compiler<'_, '_> {
             assert_eq!(ty, self.adapter.lower.ptr());
             let align = src_tys
                 .iter()
-                .map(|t| self.module.align(t))
+                .map(|t| self.module.align(&self.adapter.lower, t))
                 .max()
                 .unwrap_or(1);
             Source::Memory(self.memory_operand(&self.adapter.lower, addr, align))
         };

         let dst = if dst_flat.len() <= MAX_FLAT_PARAMS {
-            Destination::Stack(&dst_flat)
+            Destination::Stack(&dst_flat, &self.adapter.lift)
         } else {
             // If there are too many parameters then space is allocated in the
             // destination module for the parameters via its `realloc` function.
-            let (size, align) = self.module.record_size_align(dst_tys.iter());
+            let (size, align) = self
+                .module
+                .record_size_align(&self.adapter.lift, dst_tys.iter());
+            let size = MallocSize::Const(size);
             Destination::Memory(self.malloc(&self.adapter.lift, size, align))
         };
@@ -287,19 +294,20 @@ impl Compiler<'_, '_> {
         let src_ty = self.module.types[self.adapter.lift.ty].result;
         let dst_ty = self.module.types[self.adapter.lower.ty].result;
-        let src_flat = self.module.flatten_types([src_ty]);
-        let dst_flat = self.module.flatten_types([dst_ty]);
+        let src_flat = self.module.flatten_types(&self.adapter.lift, [src_ty]);
+        let dst_flat = self.module.flatten_types(&self.adapter.lower, [dst_ty]);

         let src = if src_flat.len() <= MAX_FLAT_RESULTS {
             Source::Stack(Stack {
                 locals: result_locals,
+                opts: &self.adapter.lift,
             })
         } else {
             // The original results to read from in this case come from the
             // return value of the function itself. The imported function will
             // return a linear memory address at which the values can be read
             // from.
-            let align = self.module.align(&src_ty);
+            let align = self.module.align(&self.adapter.lift, &src_ty);
             assert_eq!(result_locals.len(), 1);
             let (addr, ty) = result_locals[0];
             assert_eq!(ty, self.adapter.lift.ptr());
@@ -307,12 +315,12 @@ impl Compiler<'_, '_> {
         };

         let dst = if dst_flat.len() <= MAX_FLAT_RESULTS {
-            Destination::Stack(&dst_flat)
+            Destination::Stack(&dst_flat, &self.adapter.lower)
         } else {
             // This is slightly different than `translate_params` where the
             // return pointer was provided by the caller of this function
             // meaning the last parameter local is a pointer into linear memory.
-            let align = self.module.align(&dst_ty);
+            let align = self.module.align(&self.adapter.lower, &dst_ty);
             let (addr, ty) = *param_locals.last().expect("no retptr");
             assert_eq!(ty, self.adapter.lower.ptr());
             Destination::Memory(self.memory_operand(&self.adapter.lower, addr, align))
@@ -348,6 +356,7 @@ impl Compiler<'_, '_> {
             InterfaceType::Float32 => self.translate_f32(src, dst_ty, dst),
             InterfaceType::Float64 => self.translate_f64(src, dst_ty, dst),
             InterfaceType::Char => self.translate_char(src, dst_ty, dst),
+            InterfaceType::List(t) => self.translate_list(*t, src, dst_ty, dst),
             InterfaceType::Record(t) => self.translate_record(*t, src, dst_ty, dst),
             InterfaceType::Flags(f) => self.translate_flags(*f, src, dst_ty, dst),
             InterfaceType::Tuple(t) => self.translate_tuple(*t, src, dst_ty, dst),
@@ -363,10 +372,6 @@ impl Compiler<'_, '_> {
                 drop(&self.adapter.lift.string_encoding);
                 unimplemented!("don't know how to translate strings")
             }
-
-            // TODO: this needs to be filled out for all the other interface
-            // types.
-            ty => unimplemented!("don't know how to translate {ty:?}"),
         }
     }
@@ -393,7 +398,7 @@ impl Compiler<'_, '_> {
         match dst {
             Destination::Memory(mem) => self.i32_store8(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -421,7 +426,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store8(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -438,7 +443,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store8(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -466,7 +471,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store16(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -483,7 +488,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store16(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -505,7 +510,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -519,7 +524,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i32_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
         }
     }
@@ -533,7 +538,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i64_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I64),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I64),
         }
     }
@@ -547,7 +552,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.i64_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::I64),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::I64),
         }
     }
@@ -561,7 +566,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.f32_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::F32),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::F32),
         }
     }
@@ -575,7 +580,7 @@ impl Compiler<'_, '_> {
         }

         match dst {
             Destination::Memory(mem) => self.f64_store(mem),
-            Destination::Stack(stack) => self.stack_set(stack, ValType::F64),
+            Destination::Stack(stack, _) => self.stack_set(stack, ValType::F64),
         }
     }
@@ -627,7 +632,402 @@ impl Compiler<'_, '_> {
Destination::Memory(mem) => { Destination::Memory(mem) => {
self.i32_store(mem); self.i32_store(mem);
} }
Destination::Stack(stack) => self.stack_set(stack, ValType::I32), Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32),
}
}
fn translate_list(
&mut self,
src_ty: TypeInterfaceIndex,
src: &Source<'_>,
dst_ty: &InterfaceType,
dst: &Destination,
) {
let src_element_ty = &self.module.types[src_ty];
let dst_element_ty = match dst_ty {
InterfaceType::List(r) => &self.module.types[*r],
_ => panic!("expected a list"),
};
let src_opts = src.opts();
let dst_opts = dst.opts();
let (src_size, src_align) = self.module.size_align(src_opts, src_element_ty);
let (dst_size, dst_align) = self.module.size_align(dst_opts, dst_element_ty);
// Load the pointer/length of this list into temporary locals. These
// will be referenced a good deal so this just makes it easier to deal
// with them consistently below rather than trying to reload from memory
// for example.
let src_ptr = self.gen_local(src_opts.ptr());
let src_len = self.gen_local(src_opts.ptr());
match src {
Source::Stack(s) => {
assert_eq!(s.locals.len(), 2);
self.stack_get(&s.slice(0..1), src_opts.ptr());
self.instruction(LocalSet(src_ptr));
self.stack_get(&s.slice(1..2), src_opts.ptr());
self.instruction(LocalSet(src_len));
}
Source::Memory(mem) => {
self.ptr_load(mem);
self.instruction(LocalSet(src_ptr));
self.ptr_load(&mem.bump(src_opts.ptr_size().into()));
self.instruction(LocalSet(src_len));
}
}
// Create a `Memory` operand which will internally assert that the
// `src_ptr` value is properly aligned.
let src_mem = self.memory_operand(src_opts, src_ptr, src_align);
// Next the byte size of the allocation in the destination is
// determined. Note that this is pretty tricky because pointer widths
// could be changing and otherwise everything must stay within the
// 32-bit size-space. This internally will ensure that `src_len *
// dst_size` doesn't overflow 32-bits and will place the final result in
// `dst_byte_len` where `dst_byte_len` has the appropriate type for the
// destination.
let dst_byte_len = self.gen_local(dst_opts.ptr());
self.calculate_dst_byte_len(
src_len,
dst_byte_len,
src_opts.ptr(),
dst_opts.ptr(),
dst_size,
);
// Here `realloc` is invoked (in a `malloc`-like fashion) to allocate
// space for the list in the destination memory. This will also
// internally insert checks that the returned pointer is aligned
// correctly for the destination.
let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), dst_align);
// At this point we have aligned pointers, a length, and a byte length
// for the destination. The spec also requires this translation to
// ensure that the range of memory within the source and destination
// memories are valid. This implementation, though, attempts to optimize
// that check somewhat. The thinking is that if we hit an out-of-bounds
// memory access during translation that's the same as a trap up-front.
// This means we can generally minimize up-front checks in favor of
// simply trying to load out-of-bounds memory.
//
// This doesn't mean we can avoid a check entirely though. One major
// worry here is integer overflow of the pointers in linear memory as
// they're incremented to move to the next element as part of
// translation. For example if the entire 32-bit address space were
// valid and the base pointer was `0xffff_fff0` where the size was 17
// that should not be a valid list but "simply defer to the loop below"
// would cause a wraparound to occur and no trap would be detected.
//
// To solve this a check is inserted here that the `base + byte_len`
// calculation doesn't overflow the 32-bit address space. Note though
// that this is only done for 32-bit memories, not 64-bit memories.
// Given the iteration of the loop below the only worry is when the
// address space is 100% mapped and wraparound is possible. Otherwise if
// anything in the address space is unmapped then we're guaranteed to
// hit a trap as we march from the base pointer to the end of the array.
// It's assumed that it's impossible for a 64-bit memory to have the
// entire address space mapped, so this isn't a concern for 64-bit
// memories.
//
// Technically this is only a concern for 32-bit memories if the entire
// address space is mapped, so `memory.size` could be used to skip most
// of the check here, but it's assumed that the `memory.size` check is
// probably more expensive than just checking for 32-bit overflow by
// using 64-bit arithmetic. This should hypothetically be tested though!
//
// TODO: the most-optimal thing here is to probably, once per adapter,
// call `memory.size` and put that in a local. If that is not the
// maximum for a 32-bit memory then this entire bounds-check here can be
// skipped.
if !src_opts.memory64 && src_size > 0 {
self.instruction(LocalGet(src_mem.addr_local));
self.instruction(I64ExtendI32U);
if src_size < dst_size {
// If the source byte size is less than the destination size
// then we can leverage the fact that `dst_byte_len` was already
// calculated and didn't overflow so this is also guaranteed to
// not overflow.
self.instruction(LocalGet(src_len));
self.instruction(I64ExtendI32U);
if src_size != 1 {
self.instruction(I64Const(i64::try_from(src_size).unwrap()));
self.instruction(I64Mul);
}
} else if src_size == dst_size {
// If the source byte size is the same as the destination byte
// size then that can be reused. Note that the destination byte
// size is already guaranteed to fit in 32 bits, even if it's
// stored in a 64-bit local.
self.instruction(LocalGet(dst_byte_len));
if dst_opts.ptr() == ValType::I32 {
self.instruction(I64ExtendI32U);
}
} else {
// Otherwise if the source byte size is larger than the
// destination byte size then the source byte size needs to be
// calculated fresh here. Note, though, that the result of this
// multiplication is not checked for overflow. The reason for
// that is that the result here flows into the check below about
// overflow and if this computation overflows it should be
// guaranteed to overflow the next computation.
//
// In general what's being checked here is:
//
// src_mem.addr_local + src_len * src_size
//
// These three values are all 32-bits originally and if they're
// all assumed to be `u32::MAX` then:
//
// let max = u64::from(u32::MAX);
// let result = max + max * max;
// assert_eq!(result, 0xffffffff00000000);
//
// This means that once an upper bit is set it's guaranteed to
// stay set as part of this computation, so the multiplication
// here is left unchecked to fall through into the addition
// below.
self.instruction(LocalGet(src_len));
self.instruction(I64ExtendI32U);
self.instruction(I64Const(i64::try_from(src_size).unwrap()));
self.instruction(I64Mul);
}
self.instruction(I64Add);
self.instruction(I64Const(32));
self.instruction(I64ShrU);
self.instruction(I32WrapI64);
self.instruction(If(BlockType::Empty));
self.trap(Trap::ListByteLengthOverflow);
self.instruction(End);
}
// If the destination is a 32-bit memory then its overflow check is
// relatively simple since we've already calculated the byte length of
// the destination above and can reuse that in this check.
if !dst_opts.memory64 && dst_size > 0 {
self.instruction(LocalGet(dst_mem.addr_local));
self.instruction(I64ExtendI32U);
self.instruction(LocalGet(dst_byte_len));
self.instruction(I64ExtendI32U);
self.instruction(I64Add);
self.instruction(I64Const(32));
self.instruction(I64ShrU);
self.instruction(I32WrapI64);
self.instruction(If(BlockType::Empty));
self.trap(Trap::ListByteLengthOverflow);
self.instruction(End);
}
// This is the main body of the loop to actually translate list types.
// Note that if both element sizes are 0 then this won't actually do
// anything so the loop is removed entirely.
if src_size > 0 || dst_size > 0 {
let cur_dst_ptr = self.gen_local(dst_opts.ptr());
let cur_src_ptr = self.gen_local(src_opts.ptr());
let remaining = self.gen_local(src_opts.ptr());
let iconst = |i: i32, ty: ValType| match ty {
ValType::I32 => I32Const(i32::try_from(i).unwrap()),
ValType::I64 => I64Const(i64::try_from(i).unwrap()),
_ => unreachable!(),
};
let src_add = if src_opts.memory64 { I64Add } else { I32Add };
let dst_add = if dst_opts.memory64 { I64Add } else { I32Add };
let src_eqz = if src_opts.memory64 { I64Eqz } else { I32Eqz };
// This block encompasses the entire loop and is used to exit before even
// entering the loop if the list length is zero.
self.instruction(Block(BlockType::Empty));
// Set the `remaining` local and only continue if it's > 0
self.instruction(LocalGet(src_len));
self.instruction(LocalTee(remaining));
self.instruction(src_eqz.clone());
self.instruction(BrIf(0));
// Initialize the two destination pointers to their initial values
self.instruction(LocalGet(src_mem.addr_local));
self.instruction(LocalSet(cur_src_ptr));
self.instruction(LocalGet(dst_mem.addr_local));
self.instruction(LocalSet(cur_dst_ptr));
self.instruction(Loop(BlockType::Empty));
// Translate the next element in the list
let element_src = Source::Memory(Memory {
opts: src_opts,
offset: 0,
addr_local: cur_src_ptr,
});
let element_dst = Destination::Memory(Memory {
opts: dst_opts,
offset: 0,
addr_local: cur_dst_ptr,
});
self.translate(src_element_ty, &element_src, dst_element_ty, &element_dst);
// Update the two loop pointers
if src_size > 0 {
let src_size = i32::try_from(src_size).unwrap();
self.instruction(LocalGet(cur_src_ptr));
self.instruction(iconst(src_size, src_opts.ptr()));
self.instruction(src_add.clone());
self.instruction(LocalSet(cur_src_ptr));
}
if dst_size > 0 {
let dst_size = i32::try_from(dst_size).unwrap();
self.instruction(LocalGet(cur_dst_ptr));
self.instruction(iconst(dst_size, dst_opts.ptr()));
self.instruction(dst_add.clone());
self.instruction(LocalSet(cur_dst_ptr));
}
// Update the remaining count, falling through to break out if it's zero
// now.
self.instruction(LocalGet(remaining));
self.instruction(iconst(-1, src_opts.ptr()));
self.instruction(src_add.clone());
self.instruction(LocalTee(remaining));
self.instruction(src_eqz.clone());
self.instruction(BrIf(0));
self.instruction(End); // end of loop
self.instruction(End); // end of block
}
// Store the ptr/length in the desired destination
match dst {
Destination::Stack(s, _) => {
self.instruction(LocalGet(dst_mem.addr_local));
self.stack_set(&s[..1], dst_opts.ptr());
self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr());
self.stack_set(&s[1..], dst_opts.ptr());
}
Destination::Memory(mem) => {
self.instruction(LocalGet(mem.addr_local));
self.instruction(LocalGet(dst_mem.addr_local));
self.ptr_store(mem);
self.instruction(LocalGet(mem.addr_local));
self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr());
self.ptr_store(&mem.bump(dst_opts.ptr_size().into()));
}
}
}
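The unchecked multiply in the source-size check above relies on an arithmetic property that's easy to verify on the host. This is a standalone sketch (names here are illustrative, not part of the compiler): even with all three 32-bit inputs at their maximum, the 64-bit `addr + len * size` sum keeps its upper 32 bits nonzero once overflow occurs, so the generated `i64.shr_u`/`i32.wrap_i64`/`if` sequence can never miss it.

```rust
// Host-side mirror of the generated overflow check: compute
// `addr + len * size` in 64 bits and test the upper half.
fn upper_bits_set(addr: u32, len: u32, size: u32) -> bool {
    let sum = u64::from(addr) + u64::from(len) * u64::from(size);
    (sum >> 32) != 0
}

fn main() {
    // The worst case quoted in the comment above: all inputs u32::MAX.
    let max = u64::from(u32::MAX);
    assert_eq!(max + max * max, 0xffff_ffff_0000_0000);
    // A list at 0xffff_fff0 spanning 17 bytes wraps the 32-bit address
    // space; the 64-bit check catches it.
    assert!(upper_bits_set(0xffff_fff0, 17, 1));
    // An ordinary in-bounds list passes.
    assert!(!upper_bits_set(0x1000, 4, 4));
}
```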
fn calculate_dst_byte_len(
&mut self,
src_len_local: u32,
dst_len_local: u32,
src_ptr_ty: ValType,
dst_ptr_ty: ValType,
dst_elt_size: usize,
) {
// Zero-size types are easy to handle here because the byte size of the
// destination is always zero.
if dst_elt_size == 0 {
if dst_ptr_ty == ValType::I64 {
self.instruction(I64Const(0));
} else {
self.instruction(I32Const(0));
}
self.instruction(LocalSet(dst_len_local));
return;
}
// For one-byte elements in the destination the check here can be a bit
// more optimal than the general case below. In these situations if the
// source pointer type is 32-bit then we're guaranteed to not overflow,
// so the source length is simply cast to the destination's type.
//
// If the source is 64-bit then all that needs to be checked is to
// ensure that it does not have the upper 32-bits set.
if dst_elt_size == 1 {
if let ValType::I64 = src_ptr_ty {
self.instruction(LocalGet(src_len_local));
self.instruction(I64Const(32));
self.instruction(I64ShrU);
self.instruction(I32WrapI64);
self.instruction(If(BlockType::Empty));
self.trap(Trap::ListByteLengthOverflow);
self.instruction(End);
}
self.convert_src_len_to_dst(src_len_local, src_ptr_ty, dst_ptr_ty);
self.instruction(LocalSet(dst_len_local));
return;
}
// The main check implemented by this function is to verify that
// `src_len_local` does not exceed the 32-bit range. Byte sizes for
// lists must always fit in 32-bits to get transferred to 32-bit
// memories.
self.instruction(Block(BlockType::Empty));
self.instruction(Block(BlockType::Empty));
self.instruction(LocalGet(src_len_local));
match src_ptr_ty {
// The source's list length is guaranteed to be less than 32-bits
// so simply extend it up to a 64-bit type for the multiplication
// below.
ValType::I32 => self.instruction(I64ExtendI32U),
// If the source is a 64-bit memory then if the list length doesn't
// fit in 32-bits the byte length definitely won't, so generate a
// branch to our overflow trap here if any of the upper 32-bits are set.
ValType::I64 => {
self.instruction(I64Const(32));
self.instruction(I64ShrU);
self.instruction(I32WrapI64);
self.instruction(BrIf(0));
self.instruction(LocalGet(src_len_local));
}
_ => unreachable!(),
}
// Next perform a 64-bit multiplication with the element byte size that
// is itself guaranteed to fit in 32-bits. The result is then checked
// to see if it overflowed the 32-bit space. The two input operands to
// the multiplication are guaranteed to be 32-bits at most which means
// that this multiplication can't overflow 64 bits.
//
// The result of the multiplication is saved into a local as well to
// get the result afterwards.
let tmp = if dst_ptr_ty != ValType::I64 {
self.gen_local(ValType::I64)
} else {
dst_len_local
};
self.instruction(I64Const(u32::try_from(dst_elt_size).unwrap().into()));
self.instruction(I64Mul);
self.instruction(LocalTee(tmp));
// Branch to success if the upper 32-bits are zero, otherwise
// fall-through to the trap.
self.instruction(I64Const(32));
self.instruction(I64ShrU);
self.instruction(I64Eqz);
self.instruction(BrIf(1));
self.instruction(End);
self.trap(Trap::ListByteLengthOverflow);
self.instruction(End);
// If a fresh local was used to store the result of the multiplication
// then convert it down to 32-bits which should be guaranteed to not
// lose information at this point.
if dst_ptr_ty != ValType::I64 {
self.instruction(LocalGet(tmp));
self.instruction(I32WrapI64);
self.instruction(LocalSet(dst_len_local));
}
}
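The rule that `calculate_dst_byte_len` enforces can be sketched host-side in a few lines (the function name and signature here are illustrative, not from the compiler): the destination byte size `src_len * dst_elt_size` must fit in 32 bits or the adapter traps with `ListByteLengthOverflow`, with zero-size elements skipping every check.

```rust
// Host-side analogue of the generated byte-length check.
fn dst_byte_len(src_len: u64, dst_elt_size: u64) -> Option<u32> {
    // Zero-size elements always produce a zero byte length.
    if dst_elt_size == 0 {
        return Some(0);
    }
    // Stand-in for the 64-bit multiply in the generated code.
    let bytes = src_len.checked_mul(dst_elt_size)?;
    // Equivalent of the `i64.shr_u 32; br_if` upper-half test.
    u32::try_from(bytes).ok()
}

fn main() {
    assert_eq!(dst_byte_len(10, 4), Some(40));
    assert_eq!(dst_byte_len(5, 0), Some(0));
    // A 64-bit source length with the upper half set must trap.
    assert_eq!(dst_byte_len(1 << 33, 1), None);
    // Overflowing the 32-bit byte-size space must trap too.
    assert_eq!(dst_byte_len(u64::from(u32::MAX), 2), None);
}
```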
fn convert_src_len_to_dst(
&mut self,
src_len_local: u32,
src_ptr_ty: ValType,
dst_ptr_ty: ValType,
) {
self.instruction(LocalGet(src_len_local));
match (src_ptr_ty, dst_ptr_ty) {
(ValType::I32, ValType::I64) => self.instruction(I64ExtendI32U),
(ValType::I64, ValType::I32) => self.instruction(I32WrapI64),
(src, dst) => assert_eq!(src, dst),
}
}
@@ -938,7 +1338,7 @@ impl Compiler<'_, '_> {
// The outermost block is special since it has the result type of the
// translation here. That will depend on the `dst`.
let outer_block_ty = match dst {
Destination::Stack(dst_flat, _) => match dst_flat.len() {
0 => BlockType::Empty,
1 => BlockType::Result(dst_flat[0]),
_ => {
@@ -1004,7 +1404,7 @@ impl Compiler<'_, '_> {
self.push_dst_addr(dst);
self.instruction(I32Const(dst_i as i32));
match dst {
Destination::Stack(stack, _) => self.stack_set(&stack[..1], ValType::I32),
Destination::Memory(mem) => match dst_disc_size {
DiscriminantSize::Size1 => self.i32_store8(mem),
DiscriminantSize::Size2 => self.i32_store16(mem),
@@ -1024,8 +1424,8 @@ impl Compiler<'_, '_> {
// variant. That's handled here by pushing remaining zeros after
// accounting for the discriminant pushed as well as the results of
// this individual payload.
if let Destination::Stack(payload_results, _) = dst_payload {
if let Destination::Stack(dst_results, _) = dst {
let remaining = &dst_results[1..][payload_results.len()..];
for ty in remaining {
match ty {
@@ -1087,7 +1487,7 @@ impl Compiler<'_, '_> {
}
self.instruction(LocalGet(memory.addr_local));
assert!(align.is_power_of_two());
if memory.opts.memory64 {
let mask = i64::try_from(align - 1).unwrap();
self.instruction(I64Const(mask));
self.instruction(I64And);
@@ -1107,13 +1507,13 @@ impl Compiler<'_, '_> {
if !self.module.debug {
return;
}
let align = self.module.align(mem.opts, ty);
if align == 1 {
return;
}
assert!(align.is_power_of_two());
self.instruction(LocalGet(mem.addr_local));
if mem.opts.memory64 {
self.instruction(I64Const(i64::from(mem.offset)));
self.instruction(I64Add);
let mask = i64::try_from(align - 1).unwrap();
@@ -1133,32 +1533,41 @@ impl Compiler<'_, '_> {
self.instruction(End);
}
fn malloc<'a>(&mut self, opts: &'a Options, size: MallocSize, align: usize) -> Memory<'a> {
let addr_local = self.gen_local(opts.ptr());
let realloc = opts.realloc.unwrap();
if opts.memory64 {
self.instruction(I64Const(0));
self.instruction(I64Const(0));
self.instruction(I64Const(i64::try_from(align).unwrap()));
match size {
MallocSize::Const(size) => self.instruction(I64Const(i64::try_from(size).unwrap())),
MallocSize::Local(idx) => self.instruction(LocalGet(idx)),
}
} else {
self.instruction(I32Const(0));
self.instruction(I32Const(0));
self.instruction(I32Const(i32::try_from(align).unwrap()));
match size {
MallocSize::Const(size) => self.instruction(I32Const(i32::try_from(size).unwrap())),
MallocSize::Local(idx) => self.instruction(LocalGet(idx)),
}
}
self.instruction(Call(realloc.as_u32()));
self.instruction(LocalSet(addr_local));
self.memory_operand(opts, addr_local, align)
}
fn memory_operand<'a>(
&mut self,
opts: &'a Options,
addr_local: u32,
align: usize,
) -> Memory<'a> {
let ret = Memory {
addr_local,
offset: 0,
opts,
};
self.verify_aligned(&ret, align);
ret
@@ -1328,6 +1737,14 @@ impl Compiler<'_, '_> {
self.instruction(I64Load(mem.memarg(3)));
}
fn ptr_load(&mut self, mem: &Memory) {
if mem.opts.memory64 {
self.i64_load(mem);
} else {
self.i32_load(mem);
}
}
fn f32_load(&mut self, mem: &Memory) {
self.instruction(LocalGet(mem.addr_local));
self.instruction(F32Load(mem.memarg(2)));
@@ -1360,6 +1777,14 @@ impl Compiler<'_, '_> {
self.instruction(I64Store(mem.memarg(3)));
}
fn ptr_store(&mut self, mem: &Memory) {
if mem.opts.memory64 {
self.i64_store(mem);
} else {
self.i32_store(mem);
}
}
fn f32_store(&mut self, mem: &Memory) {
self.instruction(F32Store(mem.memarg(2)));
}
@@ -1391,7 +1816,7 @@ impl<'a> Source<'a> {
Source::Memory(mem)
}
Source::Stack(stack) => {
let cnt = module.flatten_types(stack.opts, [ty]).len();
offset += cnt;
Source::Stack(stack.slice(offset - cnt..offset))
}
@@ -1407,7 +1832,7 @@ impl<'a> Source<'a> {
) -> Source<'a> {
match self {
Source::Stack(s) => {
let flat_len = module.flatten_types(s.opts, [*case]).len();
Source::Stack(s.slice(1..s.locals.len()).slice(0..flat_len))
}
Source::Memory(mem) => {
@@ -1416,6 +1841,13 @@ impl<'a> Source<'a> {
}
}
}
fn opts(&self) -> &'a Options {
match self {
Source::Stack(s) => s.opts,
Source::Memory(mem) => mem.opts,
}
}
}
impl<'a> Destination<'a> {
@@ -1434,10 +1866,10 @@ impl<'a> Destination<'a> {
let mem = next_field_offset(&mut offset, module, &ty, mem);
Destination::Memory(mem)
}
Destination::Stack(s, opts) => {
let cnt = module.flatten_types(opts, [ty]).len();
offset += cnt;
Destination::Stack(&s[offset - cnt..offset], opts)
}
})
}
@@ -1450,9 +1882,9 @@ impl<'a> Destination<'a> {
case: &InterfaceType,
) -> Destination {
match self {
Destination::Stack(s, opts) => {
let flat_len = module.flatten_types(opts, [*case]).len();
Destination::Stack(&s[1..][..flat_len], opts)
}
Destination::Memory(mem) => {
let mem = payload_offset(size, module, case, mem);
@@ -1460,30 +1892,37 @@ impl<'a> Destination<'a> {
}
}
}
fn opts(&self) -> &'a Options {
match self {
Destination::Stack(_, opts) => opts,
Destination::Memory(mem) => mem.opts,
}
}
}
fn next_field_offset<'a>(
offset: &mut usize,
module: &Module,
field: &InterfaceType,
mem: &Memory<'a>,
) -> Memory<'a> {
let (size, align) = module.size_align(mem.opts, field);
*offset = align_to(*offset, align) + size;
mem.bump(*offset - size)
}
fn payload_offset<'a>(
disc_size: DiscriminantSize,
module: &Module,
case: &InterfaceType,
mem: &Memory<'a>,
) -> Memory<'a> {
let align = module.align(mem.opts, case);
mem.bump(align_to(disc_size.into(), align))
}
impl<'a> Memory<'a> {
fn i32_offset(&self) -> i32 {
self.offset as i32
}
@@ -1492,15 +1931,14 @@ impl Memory {
MemArg {
offset: u64::from(self.offset),
align,
memory_index: self.opts.memory.unwrap().as_u32(),
}
}
fn bump(&self, offset: usize) -> Memory<'a> {
Memory {
opts: self.opts,
addr_local: self.addr_local,
offset: self.offset + u32::try_from(offset).unwrap(),
}
}
@@ -1510,6 +1948,7 @@ impl<'a> Stack<'a> {
fn slice(&self, range: Range<usize>) -> Stack<'a> {
Stack {
locals: &self.locals[range],
opts: self.opts,
}
}
}
@@ -1520,3 +1959,8 @@ struct VariantCase<'a> {
dst_i: u32,
dst_ty: &'a InterfaceType,
}
enum MallocSize {
Const(usize),
Local(u32),
}
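The `verify_aligned` debug check earlier in this file hinges on a standard power-of-two trick that the emitted `I32And`/`I64And` sequence implements. A host-side sketch of the same predicate (the function name is illustrative):

```rust
// A pointer is aligned to a power-of-two `align` iff masking it with
// `align - 1` leaves no low bits set.
fn is_aligned(ptr: u64, align: u64) -> bool {
    assert!(align.is_power_of_two());
    ptr & (align - 1) == 0
}

fn main() {
    assert!(is_aligned(0, 8));
    assert!(is_aligned(16, 4));
    assert!(!is_aligned(6, 4));
    // `align == 1` accepts every pointer, which is why the emitter
    // returns early in that case instead of generating a check.
    assert!(is_aligned(7, 1));
}
```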

View File

@@ -29,6 +29,7 @@ pub enum Trap {
UnalignedPointer,
InvalidDiscriminant,
InvalidChar,
ListByteLengthOverflow,
AssertFailed(&'static str),
}
@@ -103,6 +104,7 @@ impl fmt::Display for Trap {
Trap::UnalignedPointer => "pointer not aligned correctly".fmt(f),
Trap::InvalidDiscriminant => "invalid variant discriminant".fmt(f),
Trap::InvalidChar => "invalid char value specified".fmt(f),
Trap::ListByteLengthOverflow => "byte size of list too large for i32".fmt(f),
Trap::AssertFailed(s) => write!(f, "assertion failure: {}", s),
}
}

View File

@@ -102,9 +102,13 @@ impl Options {
let memory = self.memory_mut(store.0);
let result_slice = if new_size == 0 {
&mut []
} else {
match memory.get_mut(result..).and_then(|s| s.get_mut(..new_size)) {
Some(end) => end,
None => bail!("realloc return: beyond end of memory"),
}
};
Ok((result_slice, result))
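The special case above encodes the updated semantics from this PR: a zero-byte allocation's base pointer is never dereferenced, so it need not point into linear memory at all. A host-side sketch of the rule (the function name and signature are illustrative, not the actual host API):

```rust
// Mirror of the bounds rule in the realloc-return path: skip the
// bounds check entirely when no bytes were requested.
fn check_realloc_result(
    memory_len: usize,
    result: usize,
    new_size: usize,
) -> Result<(), &'static str> {
    if new_size == 0 {
        // Matches the `&mut []` arm: no slice is taken, no check made.
        return Ok(());
    }
    match result.checked_add(new_size) {
        Some(end) if end <= memory_len => Ok(()),
        _ => Err("realloc return: beyond end of memory"),
    }
}

fn main() {
    // An out-of-range pointer is fine for an empty allocation...
    assert!(check_realloc_result(10, usize::MAX, 0).is_ok());
    // ...but not once any bytes are requested.
    assert!(check_realloc_result(10, 8, 4).is_err());
    assert!(check_realloc_result(10, 6, 4).is_ok());
}
```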

View File

@@ -846,22 +846,26 @@ fn lower_string<T>(mem: &mut MemoryMut<'_, T>, string: &str) -> Result<(usize, u
match mem.string_encoding() {
StringEncoding::Utf8 => {
let ptr = mem.realloc(0, 0, 1, string.len())?;
if string.len() > 0 {
mem.as_slice_mut()[ptr..][..string.len()].copy_from_slice(string.as_bytes());
}
Ok((ptr, string.len()))
}
StringEncoding::Utf16 => {
let size = string.len() * 2;
let mut ptr = mem.realloc(0, 0, 2, size)?;
let mut copied = 0;
if size > 0 {
let bytes = &mut mem.as_slice_mut()[ptr..][..size];
for (u, bytes) in string.encode_utf16().zip(bytes.chunks_mut(2)) {
let u_bytes = u.to_le_bytes();
bytes[0] = u_bytes[0];
bytes[1] = u_bytes[1];
copied += 1;
}
if (copied * 2) < size {
ptr = mem.realloc(ptr, size, 2, copied * 2)?;
}
}
Ok((ptr, copied))
}
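The `len * 2` over-allocation in the UTF-16 branch is always sufficient because a `&str` never encodes to more UTF-16 code units than it has UTF-8 bytes (a 4-byte UTF-8 scalar becomes at most a 2-unit surrogate pair, and everything shorter becomes 1 unit). A small sketch verifying that property, which is why the copy loop stays inside the allocation and the trailing `realloc` only ever shrinks:

```rust
fn main() {
    for s in ["", "hello", "héllo", "日本語", "🦀"] {
        // UTF-16 code units never exceed UTF-8 byte length.
        let units = s.encode_utf16().count();
        assert!(units <= s.len());
        // So the shrunken size handed back to `realloc` fits.
        assert!(units * 2 <= s.len() * 2);
    }
}
```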

View File

@@ -1127,21 +1127,19 @@ fn some_traps() -> Result<()> {
err,
);
}
instance(&mut store)?
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")?
.call(&mut store, (&[],))
.unwrap();
let err = instance(&mut store)?
.get_typed_func::<(&[u8],), (), _>(&mut store, "take-list-base-oob")?
.call(&mut store, (&[1],))
.unwrap_err();
assert_oob(&err);
instance(&mut store)?
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")?
.call(&mut store, ("",))
.unwrap();
let err = instance(&mut store)?
.get_typed_func::<(&str,), (), _>(&mut store, "take-string-base-oob")?
.call(&mut store, ("x",))
@@ -1193,12 +1191,20 @@ fn some_traps() -> Result<()> {
// For this function the first allocation, the space to store all the
// arguments, is in-bounds but then all further allocations, such as for
// each individual string, are all out of bounds.
instance(&mut store)?
.get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>(
&mut store,
"take-many-second-oob",
)?
.call(&mut store, ("", "", "", "", "", "", "", "", "", ""))
.unwrap();
let err = instance(&mut store)?
.get_typed_func::<(&str, &str, &str, &str, &str, &str, &str, &str, &str, &str), (), _>(
&mut store,
"take-many-second-oob",
)?
.call(&mut store, ("", "", "", "", "", "", "", "", "", "x"))
.unwrap_err();
assert_oob(&err);
Ok(())

View File

@@ -111,3 +111,23 @@
))
)
"degenerate component adapter called") "degenerate component adapter called")
;; fiddling with 0-sized lists
(component $c
(core module $m
(func (export "x") (param i32 i32))
(func (export "realloc") (param i32 i32 i32 i32) (result i32)
i32.const -1)
(memory (export "memory") 0)
)
(core instance $m (instantiate $m))
(func $f (param (list unit))
(canon lift
(core func $m "x")
(realloc (func $m "realloc"))
(memory $m "memory")
)
)
(export "empty-list" (func $f))
)
(assert_return (invoke "empty-list" (list.const)) (unit.const))