From c816a52746c7b31cae9b7f6a27a63879e9010a1f Mon Sep 17 00:00:00 2001 From: Alex Crichton Date: Mon, 8 Aug 2022 16:18:04 -0500 Subject: [PATCH] Reuse locals in adapter trampolines (#4646) This commit implements a scheme I've been meaning to work on in the adapter compiler where instead of always generating a fresh local for all operations locals may now be reused. Locals generated are explicitly freed when their lexical scope has ended, allowing reuse in translation of later types in the adapter. This also implements a new scheme for initializing locals where previously a local could simply be generated, but now the local must be fused with its initializer where a `local.{tee,set}` instruction is always generated. This should help prevent a bug I ran into with strings where one usage of a local was forgotten to be initialized which meant that when it was used during a loop it may have had a stale value from before. Modeling this in Rust isn't possible at compile time unfortunately so I opted for the next best thing, runtime panics. If a local is accidentally not released back to the pool of free locals then it will panic. The fuzzer for simply generating and validating adapter modules should be good at exercising this and it weeded out a few forgotten frees and should be good now. --- crates/environ/src/fact/trampoline.rs | 663 +++++++++++++++----------- 1 file changed, 373 insertions(+), 290 deletions(-) diff --git a/crates/environ/src/fact/trampoline.rs b/crates/environ/src/fact/trampoline.rs index bc9095129f..fafbcc54e7 100644 --- a/crates/environ/src/fact/trampoline.rs +++ b/crates/environ/src/fact/trampoline.rs @@ -45,6 +45,9 @@ struct Compiler<'a, 'b> { /// Total number of locals generated so far. nlocals: u32, + /// Locals partitioned by type which are not currently in use. + free_locals: HashMap<ValType, Vec<u32>>, + /// Metadata about all `unreachable` trap instructions in this function and /// what the trap represents. 
The offset within `self.code` is recorded as /// well. @@ -71,6 +74,7 @@ pub(super) fn compile(module: &mut Module<'_>, adapter: &AdapterData) { module, code: Vec::new(), nlocals: lower_sig.params.len() as u32, + free_locals: HashMap::new(), traps: Vec::new(), result, top_level_translate: true, @@ -111,6 +115,7 @@ fn compile_translate_mem( module, code: Vec::new(), nlocals: 2, + free_locals: HashMap::new(), traps: Vec::new(), result, top_level_translate: true, @@ -122,13 +127,13 @@ fn compile_translate_mem( &src, &Source::Memory(Memory { opts: src_opts, - addr_local: 0, + addr: TempLocal::new(0, src_opts.ptr()), offset: 0, }), &dst, &Destination::Memory(Memory { opts: dst_opts, - addr_local: 1, + addr: TempLocal::new(1, dst_opts.ptr()), offset: 0, }), ); @@ -181,7 +186,7 @@ struct Memory<'a> { opts: &'a Options, /// The index of the local that contains the base address of where the /// storage is happening. - addr_local: u32, + addr: TempLocal, /// A "static" offset that will be baked into wasm instructions for where /// memory loads/stores happen. offset: u32, @@ -238,10 +243,11 @@ impl Compiler<'_, '_> { // into locals for result translation afterwards. 
self.instruction(Call(adapter.callee.as_u32())); let mut result_locals = Vec::with_capacity(lift_sig.results.len()); + let mut temps = Vec::new(); for ty in lift_sig.results.iter().rev() { - let local = self.gen_local(*ty); - self.instruction(LocalSet(local)); - result_locals.push((local, *ty)); + let local = self.local_set_new_tmp(*ty); + result_locals.push((local.idx, *ty)); + temps.push(local); } result_locals.reverse(); @@ -269,6 +275,10 @@ impl Compiler<'_, '_> { self.set_flag(adapter.lift.flags, FLAG_MAY_ENTER, true); } + for tmp in temps { + self.free_temp_local(tmp); + } + self.finish() } @@ -304,7 +314,7 @@ impl Compiler<'_, '_> { .map(|t| self.types.align(lower_opts, t)) .max() .unwrap_or(1); - Source::Memory(self.memory_operand(lower_opts, addr, align)) + Source::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align)) }; let dst = if dst_flat.len() <= MAX_FLAT_PARAMS { @@ -331,7 +341,8 @@ impl Compiler<'_, '_> { // actual parameter that we're passing is the address of the values // stored, so ensure that's happening in the wasm body here. 
if let Destination::Memory(mem) = dst { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); + self.free_temp_local(mem.addr); } } @@ -363,7 +374,7 @@ impl Compiler<'_, '_> { assert_eq!(result_locals.len(), 1); let (addr, ty) = result_locals[0]; assert_eq!(ty, lift_opts.ptr()); - Source::Memory(self.memory_operand(lift_opts, addr, align)) + Source::Memory(self.memory_operand(lift_opts, TempLocal::new(addr, ty), align)) }; let dst = if dst_flat.len() <= MAX_FLAT_RESULTS { @@ -375,7 +386,7 @@ impl Compiler<'_, '_> { let align = self.types.align(lower_opts, &dst_ty); let (addr, ty) = *param_locals.last().expect("no retptr"); assert_eq!(ty, lower_opts.ptr()); - Destination::Memory(self.memory_operand(lower_opts, addr, align)) + Destination::Memory(self.memory_operand(lower_opts, TempLocal::new(addr, ty), align)) }; self.translate(&src_ty, &src, &dst_ty, &dst); @@ -479,12 +490,12 @@ impl Compiler<'_, '_> { compile_translate_mem(self.module, *src_ty, src.opts, *dst_ty, dst.opts); // TODO: overflow checks? - self.instruction(LocalGet(src.addr_local)); + self.instruction(LocalGet(src.addr.idx)); if src.offset != 0 { self.ptr_uconst(src.opts, src.offset); self.ptr_add(src.opts); } - self.instruction(LocalGet(dst.addr_local)); + self.instruction(LocalGet(dst.addr.idx)); if dst.offset != 0 { self.ptr_uconst(dst.opts, dst.offset); self.ptr_add(dst.opts); @@ -735,12 +746,11 @@ impl Compiler<'_, '_> { fn translate_char(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { assert!(matches!(dst_ty, InterfaceType::Char)); - let local = self.gen_local(ValType::I32); match src { Source::Memory(mem) => self.i32_load(mem), Source::Stack(stack) => self.stack_get(stack, ValType::I32), } - self.instruction(LocalSet(local)); + let local = self.local_set_new_tmp(ValType::I32); // This sequence is copied from the output of LLVM for: // @@ -759,7 +769,7 @@ impl Compiler<'_, '_> { // ... 
but I don't know how it works other than "well I trust LLVM" self.instruction(Block(BlockType::Empty)); self.instruction(Block(BlockType::Empty)); - self.instruction(LocalGet(local)); + self.instruction(LocalGet(local.idx)); self.instruction(I32Const(0xd800)); self.instruction(I32Xor); self.instruction(I32Const(-0x110000)); @@ -767,7 +777,7 @@ impl Compiler<'_, '_> { self.instruction(I32Const(-0x10f800)); self.instruction(I32LtU); self.instruction(BrIf(0)); - self.instruction(LocalGet(local)); + self.instruction(LocalGet(local.idx)); self.instruction(I32Const(0x110000)); self.instruction(I32Ne); self.instruction(BrIf(1)); @@ -776,13 +786,15 @@ impl Compiler<'_, '_> { self.instruction(End); self.push_dst_addr(dst); - self.instruction(LocalGet(local)); + self.instruction(LocalGet(local.idx)); match dst { Destination::Memory(mem) => { self.i32_store(mem); } Destination::Stack(stack, _) => self.stack_set(stack, ValType::I32), } + + self.free_temp_local(local); } fn translate_string(&mut self, src: &Source<'_>, dst_ty: &InterfaceType, dst: &Destination) { @@ -794,24 +806,20 @@ impl Compiler<'_, '_> { // will be referenced a good deal so this just makes it easier to deal // with them consistently below rather than trying to reload from memory // for example. 
- let src_ptr = self.gen_local(src_opts.ptr()); - let src_len = self.gen_local(src_opts.ptr()); match src { Source::Stack(s) => { assert_eq!(s.locals.len(), 2); self.stack_get(&s.slice(0..1), src_opts.ptr()); - self.instruction(LocalSet(src_ptr)); self.stack_get(&s.slice(1..2), src_opts.ptr()); - self.instruction(LocalSet(src_len)); } Source::Memory(mem) => { self.ptr_load(mem); - self.instruction(LocalSet(src_ptr)); self.ptr_load(&mem.bump(src_opts.ptr_size().into())); - self.instruction(LocalSet(src_len)); } } - let src_str = &WasmString { + let src_len = self.local_set_new_tmp(src_opts.ptr()); + let src_ptr = self.local_set_new_tmp(src_opts.ptr()); + let src_str = WasmString { ptr: src_ptr, len: src_len, opts: src_opts, @@ -819,32 +827,34 @@ impl Compiler<'_, '_> { let dst_str = match src_opts.string_encoding { StringEncoding::Utf8 => match dst_opts.string_encoding { - StringEncoding::Utf8 => self.string_copy(src_str, FE::Utf8, dst_opts, FE::Utf8), - StringEncoding::Utf16 => self.string_utf8_to_utf16(src_str, dst_opts), - StringEncoding::CompactUtf16 => self.string_to_compact(src_str, FE::Utf8, dst_opts), + StringEncoding::Utf8 => self.string_copy(&src_str, FE::Utf8, dst_opts, FE::Utf8), + StringEncoding::Utf16 => self.string_utf8_to_utf16(&src_str, dst_opts), + StringEncoding::CompactUtf16 => { + self.string_to_compact(&src_str, FE::Utf8, dst_opts) + } }, StringEncoding::Utf16 => { - self.verify_aligned(src_opts, src_ptr, 2); + self.verify_aligned(src_opts, src_str.ptr.idx, 2); match dst_opts.string_encoding { StringEncoding::Utf8 => { - self.string_deflate_to_utf8(src_str, FE::Utf16, dst_opts) + self.string_deflate_to_utf8(&src_str, FE::Utf16, dst_opts) } StringEncoding::Utf16 => { - self.string_copy(src_str, FE::Utf16, dst_opts, FE::Utf16) + self.string_copy(&src_str, FE::Utf16, dst_opts, FE::Utf16) } StringEncoding::CompactUtf16 => { - self.string_to_compact(src_str, FE::Utf16, dst_opts) + self.string_to_compact(&src_str, FE::Utf16, dst_opts) } } } 
StringEncoding::CompactUtf16 => { - self.verify_aligned(src_opts, src_ptr, 2); + self.verify_aligned(src_opts, src_str.ptr.idx, 2); // Test the tag bit to see if this is a utf16 or a latin1 string // at runtime... - self.instruction(LocalGet(src_len)); + self.instruction(LocalGet(src_str.len.idx)); self.ptr_uconst(src_opts, UTF16_TAG); self.ptr_and(src_opts); self.ptr_if(src_opts, BlockType::Empty); @@ -852,19 +862,19 @@ impl Compiler<'_, '_> { // In the utf16 block unset the upper bit from the length local // so further calculations have the right value. Afterwards the // string transcode proceeds assuming utf16. - self.instruction(LocalGet(src_len)); + self.instruction(LocalGet(src_str.len.idx)); self.ptr_uconst(src_opts, UTF16_TAG); self.ptr_xor(src_opts); - self.instruction(LocalSet(src_len)); + self.instruction(LocalSet(src_str.len.idx)); let s1 = match dst_opts.string_encoding { StringEncoding::Utf8 => { - self.string_deflate_to_utf8(src_str, FE::Utf16, dst_opts) + self.string_deflate_to_utf8(&src_str, FE::Utf16, dst_opts) } StringEncoding::Utf16 => { - self.string_copy(src_str, FE::Utf16, dst_opts, FE::Utf16) + self.string_copy(&src_str, FE::Utf16, dst_opts, FE::Utf16) } StringEncoding::CompactUtf16 => { - self.string_compact_utf16_to_compact(src_str, dst_opts) + self.string_compact_utf16_to_compact(&src_str, dst_opts) } }; @@ -875,22 +885,24 @@ impl Compiler<'_, '_> { // happen. 
let s2 = match dst_opts.string_encoding { StringEncoding::Utf16 => { - self.string_copy(src_str, FE::Latin1, dst_opts, FE::Utf16) + self.string_copy(&src_str, FE::Latin1, dst_opts, FE::Utf16) } StringEncoding::Utf8 => { - self.string_deflate_to_utf8(src_str, FE::Latin1, dst_opts) + self.string_deflate_to_utf8(&src_str, FE::Latin1, dst_opts) } StringEncoding::CompactUtf16 => { - self.string_copy(src_str, FE::Latin1, dst_opts, FE::Latin1) + self.string_copy(&src_str, FE::Latin1, dst_opts, FE::Latin1) } }; // Set our `s1` generated locals to the `s2` generated locals // as the resulting pointer of this transcode. - self.instruction(LocalGet(s2.ptr)); - self.instruction(LocalSet(s1.ptr)); - self.instruction(LocalGet(s2.len)); - self.instruction(LocalSet(s1.len)); + self.instruction(LocalGet(s2.ptr.idx)); + self.instruction(LocalSet(s1.ptr.idx)); + self.instruction(LocalGet(s2.len.idx)); + self.instruction(LocalSet(s1.len.idx)); self.instruction(End); + self.free_temp_local(s2.ptr); + self.free_temp_local(s2.len); s1 } }; @@ -898,20 +910,25 @@ impl Compiler<'_, '_> { // Store the ptr/length in the desired destination match dst { Destination::Stack(s, _) => { - self.instruction(LocalGet(dst_str.ptr)); + self.instruction(LocalGet(dst_str.ptr.idx)); self.stack_set(&s[..1], dst_opts.ptr()); - self.instruction(LocalGet(dst_str.len)); + self.instruction(LocalGet(dst_str.len.idx)); self.stack_set(&s[1..], dst_opts.ptr()); } Destination::Memory(mem) => { - self.instruction(LocalGet(mem.addr_local)); - self.instruction(LocalGet(dst_str.ptr)); + self.instruction(LocalGet(mem.addr.idx)); + self.instruction(LocalGet(dst_str.ptr.idx)); self.ptr_store(mem); - self.instruction(LocalGet(mem.addr_local)); - self.instruction(LocalGet(dst_str.len)); + self.instruction(LocalGet(mem.addr.idx)); + self.instruction(LocalGet(dst_str.len.idx)); self.ptr_store(&mem.bump(dst_opts.ptr_size().into())); } } + + self.free_temp_local(src_str.ptr); + self.free_temp_local(src_str.len); + 
self.free_temp_local(dst_str.ptr); + self.free_temp_local(dst_str.len); } // Corresponding function for `store_string_copy` in the spec. @@ -939,41 +956,41 @@ impl Compiler<'_, '_> { // Calculate the source byte length given the size of each code // unit. Note that this shouldn't overflow given // `validate_string_length` above. + let mut src_byte_len_tmp = None; let src_byte_len = if src_enc.width() == 1 { - src.len + src.len.idx } else { assert_eq!(src_enc.width(), 2); - let tmp = self.gen_local(src.opts.ptr()); - self.instruction(LocalGet(src.len)); + self.instruction(LocalGet(src.len.idx)); self.ptr_uconst(src.opts, 1); self.ptr_shl(src.opts); - self.instruction(LocalSet(tmp)); - tmp + let tmp = self.local_set_new_tmp(src.opts.ptr()); + let ret = tmp.idx; + src_byte_len_tmp = Some(tmp); + ret }; // Convert the source code units length to the destination byte // length type. - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); - let dst_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalTee(dst_len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); + let dst_len = self.local_tee_new_tmp(dst_opts.ptr()); if dst_enc.width() > 1 { assert_eq!(dst_enc.width(), 2); self.ptr_uconst(dst_opts, 1); self.ptr_shl(dst_opts); } - let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalSet(dst_byte_len)); + let dst_byte_len = self.local_set_new_tmp(dst_opts.ptr()); // Allocate space in the destination using the calculated byte // length. let dst = { let dst_mem = self.malloc( dst_opts, - MallocSize::Local(dst_byte_len), + MallocSize::Local(dst_byte_len.idx), dst_enc.width().into(), ); WasmString { - ptr: dst_mem.addr_local, + ptr: dst_mem.addr, len: dst_len, opts: dst_opts, } @@ -984,7 +1001,7 @@ impl Compiler<'_, '_> { // is done by loading the last byte of the string and if that // doesn't trap then it's known valid. 
self.validate_string_inbounds(src, src_byte_len); - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(&dst, dst_byte_len.idx); // If the validations pass then the host `transcode` intrinsic // is invoked. This will either raise a trap or otherwise succeed @@ -997,11 +1014,16 @@ impl Compiler<'_, '_> { Transcode::Latin1ToUtf16 }; let transcode = self.transcoder(src, &dst, op); - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(dst.ptr)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(dst.ptr.idx)); self.instruction(Call(transcode.as_u32())); + self.free_temp_local(dst_byte_len); + if let Some(tmp) = src_byte_len_tmp { + self.free_temp_local(tmp); + } + dst } // Corresponding function for `store_string_to_utf8` in the spec. @@ -1027,36 +1049,36 @@ impl Compiler<'_, '_> { // Optimistically assume that the code unit length of the source is // all that's needed in the destination. Perform that allocation // here and proceed to transcoding below. 
- self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); - let dst_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalTee(dst_len)); - let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalSet(dst_byte_len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); + let dst_len = self.local_tee_new_tmp(dst_opts.ptr()); + let dst_byte_len = self.local_set_new_tmp(dst_opts.ptr()); let dst = { - let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), 1); + let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), 1); WasmString { - ptr: dst_mem.addr_local, + ptr: dst_mem.addr, len: dst_len, opts: dst_opts, } }; // Ensure buffers are all in-bounds + let mut src_byte_len_tmp = None; let src_byte_len = match src_enc { - FE::Latin1 => src.len, + FE::Latin1 => src.len.idx, FE::Utf16 => { - let tmp = self.gen_local(src.opts.ptr()); - self.instruction(LocalGet(src.len)); + self.instruction(LocalGet(src.len.idx)); self.ptr_uconst(src.opts, 1); self.ptr_shl(src.opts); - self.instruction(LocalSet(tmp)); - tmp + let tmp = self.local_set_new_tmp(src.opts.ptr()); + let ret = tmp.idx; + src_byte_len_tmp = Some(tmp); + ret } FE::Utf8 => unreachable!(), }; self.validate_string_inbounds(src, src_byte_len); - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(&dst, dst_byte_len.idx); // Perform the initial transcode let op = match src_enc { @@ -1065,28 +1087,27 @@ impl Compiler<'_, '_> { FE::Utf8 => unreachable!(), }; let transcode = self.transcoder(src, &dst, op); - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(dst.ptr)); - self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(dst.ptr.idx)); + self.instruction(LocalGet(dst_byte_len.idx)); self.instruction(Call(transcode.as_u32())); - 
self.instruction(LocalSet(dst.len)); - let src_len_tmp = self.gen_local(src.opts.ptr()); - self.instruction(LocalSet(src_len_tmp)); + self.instruction(LocalSet(dst.len.idx)); + let src_len_tmp = self.local_set_new_tmp(src.opts.ptr()); // Test if the source was entirely transcoded by comparing // `src_len_tmp`, the number of code units transcoded from the // source, with `src_len`, the original number of code units. - self.instruction(LocalGet(src_len_tmp)); - self.instruction(LocalGet(src.len)); + self.instruction(LocalGet(src_len_tmp.idx)); + self.instruction(LocalGet(src.len.idx)); self.ptr_ne(src.opts); self.instruction(If(BlockType::Empty)); // Here a worst-case reallocation is performed to grow `dst_mem`. // In-line a check is also performed that the worst-case byte size // fits within the maximum size of strings. - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 1); // align let factor = match src_enc { FE::Latin1 => 2, @@ -1094,50 +1115,50 @@ impl Compiler<'_, '_> { _ => unreachable!(), }; self.validate_string_length_u8(src, factor); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); self.ptr_uconst(dst_opts, factor.into()); self.ptr_mul(dst_opts); - self.instruction(LocalTee(dst_byte_len)); + self.instruction(LocalTee(dst_byte_len.idx)); self.instruction(Call(dst_opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); // Verify that the destination is still in-bounds - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(&dst, dst_byte_len.idx); // Perform another round of transcoding that should be guaranteed // to succeed. 
Note that all the parameters here are offset by the // results of the first transcoding to only perform the remaining // transcode on the final units. - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src_len_tmp)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src_len_tmp.idx)); if let FE::Utf16 = src_enc { self.ptr_uconst(src.opts, 1); self.ptr_shl(src.opts); } self.ptr_add(src.opts); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(src_len_tmp)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(src_len_tmp.idx)); self.ptr_sub(src.opts); - self.instruction(LocalGet(dst.ptr)); - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.ptr.idx)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_add(dst.opts); - self.instruction(LocalGet(dst_byte_len)); - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst_byte_len.idx)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_sub(dst.opts); self.instruction(Call(transcode.as_u32())); // Add the second result, the amount of destination units encoded, // to `dst_len` so it's an accurate reflection of the final size of // the destination buffer. - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_add(dst.opts); - self.instruction(LocalSet(dst.len)); + self.instruction(LocalSet(dst.len.idx)); // In debug mode verify the first result consumed the entire string, // otherwise simply discard it. 
if self.module.debug { - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(src_len_tmp)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(src_len_tmp.idx)); self.ptr_sub(src.opts); self.ptr_ne(src.opts); self.instruction(If(BlockType::Empty)); @@ -1148,16 +1169,16 @@ impl Compiler<'_, '_> { } // Perform a downsizing if the worst-case size was too large - self.instruction(LocalGet(dst.len)); - self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(dst.len.idx)); + self.instruction(LocalGet(dst_byte_len.idx)); self.ptr_ne(dst.opts); self.instruction(If(BlockType::Empty)); - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 1); // align - self.instruction(LocalGet(dst.len)); // new_size + self.instruction(LocalGet(dst.len.idx)); // new_size self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); self.instruction(End); // If the first transcode was enough then assert that the returned @@ -1165,8 +1186,8 @@ impl Compiler<'_, '_> { if self.module.debug { self.instruction(Else); - self.instruction(LocalGet(dst.len)); - self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(dst.len.idx)); + self.instruction(LocalGet(dst_byte_len.idx)); self.ptr_ne(dst_opts); self.instruction(If(BlockType::Empty)); self.trap(Trap::AssertFailed("should have finished encoding")); @@ -1175,6 +1196,12 @@ impl Compiler<'_, '_> { self.instruction(End); // end of "first transcode not enough" + self.free_temp_local(src_len_tmp); + self.free_temp_local(dst_byte_len); + if let Some(tmp) = src_byte_len_tmp { + self.free_temp_local(tmp); + } + dst } @@ -1198,31 +1225,29 @@ impl Compiler<'_, '_> { dst_opts: &'a Options, ) -> WasmString<'a> { 
self.validate_string_length(src, FE::Utf16); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); - let dst_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalTee(dst_len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); + let dst_len = self.local_tee_new_tmp(dst_opts.ptr()); self.ptr_uconst(dst_opts, 1); self.ptr_shl(dst_opts); - let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalSet(dst_byte_len)); + let dst_byte_len = self.local_set_new_tmp(dst_opts.ptr()); let dst = { - let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), 2); + let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), 2); WasmString { - ptr: dst_mem.addr_local, + ptr: dst_mem.addr, len: dst_len, opts: dst_opts, } }; - self.validate_string_inbounds(src, src.len); - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(src, src.len.idx); + self.validate_string_inbounds(&dst, dst_byte_len.idx); let transcode = self.transcoder(src, &dst, Transcode::Utf8ToUtf16); - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(dst.ptr)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(dst.ptr.idx)); self.instruction(Call(transcode.as_u32())); - self.instruction(LocalSet(dst.len)); + self.instruction(LocalSet(dst.len.idx)); // If the number of code units returned by transcode is not // equal to the original number of code units then @@ -1231,20 +1256,22 @@ impl Compiler<'_, '_> { // Note that the byte length of the final allocation we // want is twice the code unit length returned by the // transcoding function. 
- self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); - self.instruction(LocalGet(dst.len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst.opts.ptr()); + self.instruction(LocalGet(dst.len.idx)); self.ptr_ne(dst_opts); self.instruction(If(BlockType::Empty)); - self.instruction(LocalGet(dst.ptr)); - self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(dst.ptr.idx)); + self.instruction(LocalGet(dst_byte_len.idx)); self.ptr_uconst(dst.opts, 2); - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_uconst(dst.opts, 1); self.ptr_shl(dst.opts); self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); self.instruction(End); // end of shrink-to-fit + self.free_temp_local(dst_byte_len); + dst } @@ -1267,43 +1294,40 @@ impl Compiler<'_, '_> { dst_opts: &'a Options, ) -> WasmString<'a> { self.validate_string_length(src, FE::Utf16); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); - let dst_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalTee(dst_len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); + let dst_len = self.local_tee_new_tmp(dst_opts.ptr()); self.ptr_uconst(dst_opts, 1); self.ptr_shl(dst_opts); - let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalSet(dst_byte_len)); + let dst_byte_len = self.local_set_new_tmp(dst_opts.ptr()); let dst = { - let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), 2); + let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), 2); WasmString { - ptr: dst_mem.addr_local, + ptr: dst_mem.addr, len: dst_len, opts: dst_opts, } }; - let src_byte_len = self.gen_local(src.opts.ptr()); - self.convert_src_len_to_dst(dst_byte_len, dst.opts.ptr(), src.opts.ptr()); - self.instruction(LocalSet(src_byte_len)); + self.convert_src_len_to_dst(dst_byte_len.idx, 
dst.opts.ptr(), src.opts.ptr()); + let src_byte_len = self.local_set_new_tmp(src.opts.ptr()); - self.validate_string_inbounds(src, src.len); - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(src, src_byte_len.idx); + self.validate_string_inbounds(&dst, dst_byte_len.idx); let transcode = self.transcoder(src, &dst, Transcode::Utf16ToCompactProbablyUtf16); - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(dst.ptr)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(dst.ptr.idx)); self.instruction(Call(transcode.as_u32())); - self.instruction(LocalSet(dst.len)); + self.instruction(LocalSet(dst.len.idx)); // Assert that the untagged code unit length is the same as the // source code unit length. if self.module.debug { - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_uconst(dst.opts, !UTF16_TAG); self.ptr_and(dst.opts); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst.opts.ptr()); self.ptr_ne(dst.opts); self.instruction(If(BlockType::Empty)); self.trap(Trap::AssertFailed("expected equal code units")); @@ -1313,18 +1337,21 @@ impl Compiler<'_, '_> { // If the UTF16_TAG is set then utf16 was used and the destination // should be appropriately sized. Bail out of the "is this string // empty" block and fall through otherwise to resizing. 
- self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_uconst(dst.opts, UTF16_TAG); self.ptr_and(dst.opts); self.ptr_br_if(dst.opts, 0); // Here `realloc` is used to downsize the string - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 2); // align - self.instruction(LocalGet(dst.len)); // new_size + self.instruction(LocalGet(dst.len.idx)); // new_size self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); + + self.free_temp_local(dst_byte_len); + self.free_temp_local(src_byte_len); dst } @@ -1342,22 +1369,20 @@ impl Compiler<'_, '_> { dst_opts: &'a Options, ) -> WasmString<'a> { self.validate_string_length(src, src_enc); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst_opts.ptr()); - let dst_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalTee(dst_len)); - let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.instruction(LocalSet(dst_byte_len)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst_opts.ptr()); + let dst_len = self.local_tee_new_tmp(dst_opts.ptr()); + let dst_byte_len = self.local_set_new_tmp(dst_opts.ptr()); let dst = { - let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), 2); + let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), 2); WasmString { - ptr: dst_mem.addr_local, + ptr: dst_mem.addr, len: dst_len, opts: dst_opts, } }; - self.validate_string_inbounds(src, src.len); - self.validate_string_inbounds(&dst, dst_byte_len); + self.validate_string_inbounds(src, src.len.idx); + self.validate_string_inbounds(&dst, dst_byte_len.idx); // Perform the initial latin1 transcode. 
This returns the number of // source code units consumed and the number of destination code @@ -1369,34 +1394,33 @@ impl Compiler<'_, '_> { }; let transcode_latin1 = self.transcoder(src, &dst, latin1); let transcode_utf16 = self.transcoder(src, &dst, utf16); - self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(dst.ptr)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(dst.ptr.idx)); self.instruction(Call(transcode_latin1.as_u32())); - self.instruction(LocalSet(dst.len)); - let src_len_tmp = self.gen_local(src.opts.ptr()); - self.instruction(LocalSet(src_len_tmp)); + self.instruction(LocalSet(dst.len.idx)); + let src_len_tmp = self.local_set_new_tmp(src.opts.ptr()); // If the source was entirely consumed then the transcode completed // and all that's necessary is to optionally shrink the buffer. - self.instruction(LocalGet(src_len_tmp)); - self.instruction(LocalGet(src.len)); + self.instruction(LocalGet(src_len_tmp.idx)); + self.instruction(LocalGet(src.len.idx)); self.ptr_eq(src.opts); self.instruction(If(BlockType::Empty)); // if latin1-or-utf16 block // Test if the original byte length of the allocation is the same as // the number of written bytes, and if not then shrink the buffer // with a call to `realloc`. 
- self.instruction(LocalGet(dst_byte_len)); - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst_byte_len.idx)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_ne(dst.opts); self.instruction(If(BlockType::Empty)); - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 2); // align - self.instruction(LocalGet(dst.len)); // new_size + self.instruction(LocalGet(dst.len.idx)); // new_size self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); self.instruction(End); // In this block the latin1 encoding failed. The host transcode @@ -1413,34 +1437,34 @@ impl Compiler<'_, '_> { // Reallocate the buffer with twice the source code units in byte // size. - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 2); // align - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst.opts.ptr()); self.ptr_uconst(dst.opts, 1); self.ptr_shl(dst.opts); - self.instruction(LocalTee(dst_byte_len)); + self.instruction(LocalTee(dst_byte_len.idx)); self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); // Call the host utf16 transcoding function. This will inflate the // prior latin1 bytes and then encode the rest of the source string // as utf16 into the remaining space in the destination buffer. 
- self.instruction(LocalGet(src.ptr)); - self.instruction(LocalGet(src_len_tmp)); + self.instruction(LocalGet(src.ptr.idx)); + self.instruction(LocalGet(src_len_tmp.idx)); if let FE::Utf16 = src_enc { self.ptr_uconst(src.opts, 1); self.ptr_shl(src.opts); } self.ptr_add(src.opts); - self.instruction(LocalGet(src.len)); - self.instruction(LocalGet(src_len_tmp)); + self.instruction(LocalGet(src.len.idx)); + self.instruction(LocalGet(src_len_tmp.idx)); self.ptr_sub(src.opts); - self.instruction(LocalGet(dst.ptr)); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.ptr.idx)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst.opts.ptr()); + self.instruction(LocalGet(dst.len.idx)); self.instruction(Call(transcode_utf16.as_u32())); - self.instruction(LocalSet(dst.len)); + self.instruction(LocalSet(dst.len.idx)); // If the returned number of code units written to the destination // is not equal to the size of the allocation then the allocation is @@ -1449,28 +1473,31 @@ impl Compiler<'_, '_> { // Note that the byte size desired is `2*dst_len` and the current // byte buffer size is `2*src_len` so the `2` factor isn't checked // here, just the lengths. 
- self.instruction(LocalGet(dst.len)); - self.convert_src_len_to_dst(src.len, src.opts.ptr(), dst.opts.ptr()); + self.instruction(LocalGet(dst.len.idx)); + self.convert_src_len_to_dst(src.len.idx, src.opts.ptr(), dst.opts.ptr()); self.ptr_ne(dst.opts); self.instruction(If(BlockType::Empty)); - self.instruction(LocalGet(dst.ptr)); // old_ptr - self.instruction(LocalGet(dst_byte_len)); // old_size + self.instruction(LocalGet(dst.ptr.idx)); // old_ptr + self.instruction(LocalGet(dst_byte_len.idx)); // old_size self.ptr_uconst(dst.opts, 2); // align - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_uconst(dst.opts, 1); self.ptr_shl(dst.opts); self.instruction(Call(dst.opts.realloc.unwrap().as_u32())); - self.instruction(LocalSet(dst.ptr)); + self.instruction(LocalSet(dst.ptr.idx)); self.instruction(End); // Tag the returned pointer as utf16 - self.instruction(LocalGet(dst.len)); + self.instruction(LocalGet(dst.len.idx)); self.ptr_uconst(dst.opts, UTF16_TAG); self.ptr_or(dst.opts); - self.instruction(LocalSet(dst.len)); + self.instruction(LocalSet(dst.len.idx)); self.instruction(End); // end latin1-or-utf16 block + self.free_temp_local(src_len_tmp); + self.free_temp_local(dst_byte_len); + dst } @@ -1481,7 +1508,7 @@ impl Compiler<'_, '_> { fn validate_string_length_u8(&mut self, s: &WasmString<'_>, dst: u8) { // Check to see if the source byte length is out of bounds in // which case a trap is generated. - self.instruction(LocalGet(s.len)); + self.instruction(LocalGet(s.len.idx)); let max = MAX_STRING_BYTE_LENGTH / u32::from(dst); self.ptr_uconst(s.opts, max); self.ptr_ge_u(s.opts); @@ -1528,18 +1555,18 @@ impl Compiler<'_, '_> { // base pointer to the byte length. For 32-bit memories there's no need // to check for overflow since everything is extended to 64-bit, but for // 64-bit memories overflow is checked. 
- self.instruction(LocalGet(s.ptr)); + self.instruction(LocalGet(s.ptr.idx)); extend_to_64(self); self.instruction(LocalGet(byte_len)); extend_to_64(self); self.instruction(I64Add); if s.opts.memory64 { - let tmp = self.gen_local(ValType::I64); - self.instruction(LocalTee(tmp)); - self.instruction(LocalGet(s.ptr)); + let tmp = self.local_tee_new_tmp(ValType::I64); + self.instruction(LocalGet(s.ptr.idx)); self.ptr_lt_u(s.opts); self.instruction(BrIf(0)); - self.instruction(LocalGet(tmp)); + self.instruction(LocalGet(tmp.idx)); + self.free_temp_local(tmp); } // If the byte size of memory is greater than the final address of the @@ -1574,23 +1601,19 @@ impl Compiler<'_, '_> { // will be referenced a good deal so this just makes it easier to deal // with them consistently below rather than trying to reload from memory // for example. - let src_ptr = self.gen_local(src_opts.ptr()); - let src_len = self.gen_local(src_opts.ptr()); match src { Source::Stack(s) => { assert_eq!(s.locals.len(), 2); self.stack_get(&s.slice(0..1), src_opts.ptr()); - self.instruction(LocalSet(src_ptr)); self.stack_get(&s.slice(1..2), src_opts.ptr()); - self.instruction(LocalSet(src_len)); } Source::Memory(mem) => { self.ptr_load(mem); - self.instruction(LocalSet(src_ptr)); self.ptr_load(&mem.bump(src_opts.ptr_size().into())); - self.instruction(LocalSet(src_len)); } } + let src_len = self.local_set_new_tmp(src_opts.ptr()); + let src_ptr = self.local_set_new_tmp(src_opts.ptr()); // Create a `Memory` operand which will internally assert that the // `src_ptr` value is properly aligned. @@ -1603,20 +1626,14 @@ impl Compiler<'_, '_> { // dst_size` doesn't overflow 32-bits and will place the final result in // `dst_byte_len` where `dst_byte_len` has the appropriate type for the // destination. 
- let dst_byte_len = self.gen_local(dst_opts.ptr()); - self.calculate_dst_byte_len( - src_len, - dst_byte_len, - src_opts.ptr(), - dst_opts.ptr(), - dst_size, - ); + let dst_byte_len = + self.calculate_dst_byte_len(src_len.idx, src_opts.ptr(), dst_opts.ptr(), dst_size); // Here `realloc` is invoked (in a `malloc`-like fashion) to allocate // space for the list in the destination memory. This will also // internally insert checks that the returned pointer is aligned // correctly for the destination. - let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len), dst_align); + let dst_mem = self.malloc(dst_opts, MallocSize::Local(dst_byte_len.idx), dst_align); // At this point we have aligned pointers, a length, and a byte length // for the destination. The spec also requires this translation to @@ -1657,14 +1674,14 @@ impl Compiler<'_, '_> { // maximum for a 32-bit memory then this entire bounds-check here can be // skipped. if !src_opts.memory64 && src_size > 0 { - self.instruction(LocalGet(src_mem.addr_local)); + self.instruction(LocalGet(src_mem.addr.idx)); self.instruction(I64ExtendI32U); if src_size < dst_size { // If the source byte size is less than the destination size // then we can leverage the fact that `dst_byte_len` was already // calculated and didn't overflow so this is also guaranteed to // not overflow. - self.instruction(LocalGet(src_len)); + self.instruction(LocalGet(src_len.idx)); self.instruction(I64ExtendI32U); if src_size != 1 { self.instruction(I64Const(i64::try_from(src_size).unwrap())); @@ -1675,7 +1692,7 @@ impl Compiler<'_, '_> { // size then that can be reused. Note that the destination byte // size is already guaranteed to fit in 32 bits, even if it's // store in a 64-bit local. 
- self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(dst_byte_len.idx)); if dst_opts.ptr() == ValType::I32 { self.instruction(I64ExtendI32U); } @@ -1703,7 +1720,7 @@ impl Compiler<'_, '_> { // stay set as part of this computation, so the multiplication // here is left unchecked to fall through into the addition // below. - self.instruction(LocalGet(src_len)); + self.instruction(LocalGet(src_len.idx)); self.instruction(I64ExtendI32U); self.instruction(I64Const(i64::try_from(src_size).unwrap())); self.instruction(I64Mul); @@ -1721,9 +1738,9 @@ impl Compiler<'_, '_> { // relatively simple since we've already calculated the byte length of // the destination above and can reuse that in this check. if !dst_opts.memory64 && dst_size > 0 { - self.instruction(LocalGet(dst_mem.addr_local)); + self.instruction(LocalGet(dst_mem.addr.idx)); self.instruction(I64ExtendI32U); - self.instruction(LocalGet(dst_byte_len)); + self.instruction(LocalGet(dst_byte_len.idx)); self.instruction(I64ExtendI32U); self.instruction(I64Add); self.instruction(I64Const(32)); @@ -1734,29 +1751,27 @@ impl Compiler<'_, '_> { self.instruction(End); } + self.free_temp_local(dst_byte_len); + // This is the main body of the loop to actually translate list types. // Note that if both element sizes are 0 then this won't actually do // anything so the loop is removed entirely. if src_size > 0 || dst_size > 0 { - let cur_dst_ptr = self.gen_local(dst_opts.ptr()); - let cur_src_ptr = self.gen_local(src_opts.ptr()); - let remaining = self.gen_local(src_opts.ptr()); - // This block encompasses the entire loop and is use to exit before even // entering the loop if the list size is zero. 
self.instruction(Block(BlockType::Empty)); // Set the `remaining` local and only continue if it's > 0 - self.instruction(LocalGet(src_len)); - self.instruction(LocalTee(remaining)); + self.instruction(LocalGet(src_len.idx)); + let remaining = self.local_tee_new_tmp(src_opts.ptr()); self.ptr_eqz(src_opts); self.instruction(BrIf(0)); // Initialize the two destination pointers to their initial values - self.instruction(LocalGet(src_mem.addr_local)); - self.instruction(LocalSet(cur_src_ptr)); - self.instruction(LocalGet(dst_mem.addr_local)); - self.instruction(LocalSet(cur_dst_ptr)); + self.instruction(LocalGet(src_mem.addr.idx)); + let cur_src_ptr = self.local_set_new_tmp(src_opts.ptr()); + self.instruction(LocalGet(dst_mem.addr.idx)); + let cur_dst_ptr = self.local_set_new_tmp(dst_opts.ptr()); self.instruction(Loop(BlockType::Empty)); @@ -1764,67 +1779,74 @@ impl Compiler<'_, '_> { let element_src = Source::Memory(Memory { opts: src_opts, offset: 0, - addr_local: cur_src_ptr, + addr: TempLocal::new(cur_src_ptr.idx, cur_src_ptr.ty), }); let element_dst = Destination::Memory(Memory { opts: dst_opts, offset: 0, - addr_local: cur_dst_ptr, + addr: TempLocal::new(cur_dst_ptr.idx, cur_dst_ptr.ty), }); self.translate(src_element_ty, &element_src, dst_element_ty, &element_dst); // Update the two loop pointers if src_size > 0 { - self.instruction(LocalGet(cur_src_ptr)); + self.instruction(LocalGet(cur_src_ptr.idx)); self.ptr_uconst(src_opts, u32::try_from(src_size).unwrap()); self.ptr_add(src_opts); - self.instruction(LocalSet(cur_src_ptr)); + self.instruction(LocalSet(cur_src_ptr.idx)); } if dst_size > 0 { - self.instruction(LocalGet(cur_dst_ptr)); + self.instruction(LocalGet(cur_dst_ptr.idx)); self.ptr_uconst(dst_opts, u32::try_from(dst_size).unwrap()); self.ptr_add(dst_opts); - self.instruction(LocalSet(cur_dst_ptr)); + self.instruction(LocalSet(cur_dst_ptr.idx)); } // Update the remaining count, falling through to break out if it's zero // now. 
- self.instruction(LocalGet(remaining)); + self.instruction(LocalGet(remaining.idx)); self.ptr_iconst(src_opts, -1); self.ptr_add(src_opts); - self.instruction(LocalTee(remaining)); + self.instruction(LocalTee(remaining.idx)); self.ptr_br_if(src_opts, 0); self.instruction(End); // end of loop self.instruction(End); // end of block + + self.free_temp_local(cur_dst_ptr); + self.free_temp_local(cur_src_ptr); + self.free_temp_local(remaining); } // Store the ptr/length in the desired destination match dst { Destination::Stack(s, _) => { - self.instruction(LocalGet(dst_mem.addr_local)); + self.instruction(LocalGet(dst_mem.addr.idx)); self.stack_set(&s[..1], dst_opts.ptr()); - self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr()); + self.convert_src_len_to_dst(src_len.idx, src_opts.ptr(), dst_opts.ptr()); self.stack_set(&s[1..], dst_opts.ptr()); } Destination::Memory(mem) => { - self.instruction(LocalGet(mem.addr_local)); - self.instruction(LocalGet(dst_mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); + self.instruction(LocalGet(dst_mem.addr.idx)); self.ptr_store(mem); - self.instruction(LocalGet(mem.addr_local)); - self.convert_src_len_to_dst(src_len, src_opts.ptr(), dst_opts.ptr()); + self.instruction(LocalGet(mem.addr.idx)); + self.convert_src_len_to_dst(src_len.idx, src_opts.ptr(), dst_opts.ptr()); self.ptr_store(&mem.bump(dst_opts.ptr_size().into())); } } + + self.free_temp_local(src_len); + self.free_temp_local(src_mem.addr); + self.free_temp_local(dst_mem.addr); } fn calculate_dst_byte_len( &mut self, src_len_local: u32, - dst_len_local: u32, src_ptr_ty: ValType, dst_ptr_ty: ValType, dst_elt_size: usize, - ) { + ) -> TempLocal { // Zero-size types are easy to handle here because the byte size of the // destination is always zero. 
if dst_elt_size == 0 { @@ -1833,8 +1855,7 @@ impl Compiler<'_, '_> { } else { self.instruction(I32Const(0)); } - self.instruction(LocalSet(dst_len_local)); - return; + return self.local_set_new_tmp(dst_ptr_ty); } // For one-byte elements in the destination the check here can be a bit @@ -1855,8 +1876,7 @@ impl Compiler<'_, '_> { self.instruction(End); } self.convert_src_len_to_dst(src_len_local, src_ptr_ty, dst_ptr_ty); - self.instruction(LocalSet(dst_len_local)); - return; + return self.local_set_new_tmp(dst_ptr_ty); } // The main check implemented by this function is to verify that @@ -1894,14 +1914,9 @@ impl Compiler<'_, '_> { // // The result of the multiplication is saved into a local as well to // get the result afterwards. - let tmp = if dst_ptr_ty != ValType::I64 { - self.gen_local(ValType::I64) - } else { - dst_len_local - }; self.instruction(I64Const(u32::try_from(dst_elt_size).unwrap().into())); self.instruction(I64Mul); - self.instruction(LocalTee(tmp)); + let tmp = self.local_tee_new_tmp(ValType::I64); // Branch to success if the upper 32-bits are zero, otherwise // fall-through to the trap. self.instruction(I64Const(32)); @@ -1915,10 +1930,13 @@ impl Compiler<'_, '_> { // If a fresh local was used to store the result of the multiplication // then convert it down to 32-bits which should be guaranteed to not // lose information at this point. 
- if dst_ptr_ty != ValType::I64 { - self.instruction(LocalGet(tmp)); + if dst_ptr_ty == ValType::I64 { + tmp + } else { + self.instruction(LocalGet(tmp.idx)); self.instruction(I32WrapI64); - self.instruction(LocalSet(dst_len_local)); + self.free_temp_local(tmp); + self.local_set_new_tmp(ValType::I32) } } @@ -2426,7 +2444,7 @@ impl Compiler<'_, '_> { return; } assert!(align.is_power_of_two()); - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.ptr_uconst(mem.opts, mem.offset); self.ptr_add(mem.opts); self.ptr_uconst(mem.opts, u32::try_from(align - 1).unwrap()); @@ -2437,7 +2455,6 @@ impl Compiler<'_, '_> { } fn malloc<'a>(&mut self, opts: &'a Options, size: MallocSize, align: usize) -> Memory<'a> { - let addr_local = self.gen_local(opts.ptr()); let realloc = opts.realloc.unwrap(); self.ptr_uconst(opts, 0); self.ptr_uconst(opts, 0); @@ -2447,35 +2464,77 @@ impl Compiler<'_, '_> { MallocSize::Local(idx) => self.instruction(LocalGet(idx)), } self.instruction(Call(realloc.as_u32())); - self.instruction(LocalSet(addr_local)); - self.memory_operand(opts, addr_local, align) + let addr = self.local_set_new_tmp(opts.ptr()); + self.memory_operand(opts, addr, align) } fn memory_operand<'a>( &mut self, opts: &'a Options, - addr_local: u32, + addr: TempLocal, align: usize, ) -> Memory<'a> { let ret = Memory { - addr_local, + addr, offset: 0, opts, }; - self.verify_aligned(opts, ret.addr_local, align); + self.verify_aligned(opts, ret.addr.idx, align); ret } - fn gen_local(&mut self, ty: ValType) -> u32 { - // TODO: see if local reuse is necessary, right now this always - // generates a new local. + /// Generates a new local in this function of the `ty` specified, + /// initializing it with the top value on the current wasm stack. + /// + /// The returned `TempLocal` must be freed after it is finished with + /// `free_temp_local`. 
+ fn local_tee_new_tmp(&mut self, ty: ValType) -> TempLocal { + self.gen_temp_local(ty, LocalTee) + } + + /// Same as `local_tee_new_tmp` but initializes the local with `LocalSet` + /// instead of `LocalTee`. + fn local_set_new_tmp(&mut self, ty: ValType) -> TempLocal { + self.gen_temp_local(ty, LocalSet) + } + + fn gen_temp_local(&mut self, ty: ValType, insn: fn(u32) -> Instruction<'static>) -> TempLocal { + // First check to see if any locals are available in this function which + // were previously generated but are no longer in use. + if let Some(idx) = self.free_locals.get_mut(&ty).and_then(|v| v.pop()) { + self.instruction(insn(idx)); + return TempLocal { + ty, + idx, + needs_free: true, + }; + } + + // Failing that generate a fresh new local. let locals = &mut self.module.funcs[self.result].locals; match locals.last_mut() { Some((cnt, prev_ty)) if ty == *prev_ty => *cnt += 1, _ => locals.push((1, ty)), } self.nlocals += 1; - self.nlocals - 1 + let idx = self.nlocals - 1; + self.instruction(insn(idx)); + TempLocal { + ty, + idx, + needs_free: true, + } + } + + /// Used to release a `TempLocal` from a particular lexical scope to allow + /// its possible reuse in later scopes. 
+ fn free_temp_local(&mut self, mut local: TempLocal) { + assert!(local.needs_free); + self.free_locals + .entry(local.ty) + .or_insert(Vec::new()) + .push(local.idx); + local.needs_free = false; } fn instruction(&mut self, instr: Instruction) { @@ -2604,32 +2663,32 @@ impl Compiler<'_, '_> { } fn i32_load8u(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I32Load8_U(mem.memarg(0))); } fn i32_load8s(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I32Load8_S(mem.memarg(0))); } fn i32_load16u(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I32Load16_U(mem.memarg(1))); } fn i32_load16s(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I32Load16_S(mem.memarg(1))); } fn i32_load(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I32Load(mem.memarg(2))); } fn i64_load(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(I64Load(mem.memarg(3))); } @@ -2770,18 +2829,18 @@ impl Compiler<'_, '_> { } fn f32_load(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(F32Load(mem.memarg(2))); } fn f64_load(&mut self, mem: &Memory) { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); self.instruction(F64Load(mem.memarg(3))); } fn push_dst_addr(&mut self, dst: &Destination) { if let Destination::Memory(mem) = dst { - self.instruction(LocalGet(mem.addr_local)); + self.instruction(LocalGet(mem.addr.idx)); } } @@ -2972,7 +3031,7 @@ impl<'a> Memory<'a> { fn bump(&self, 
offset: usize) -> Memory<'a> { Memory { opts: self.opts, - addr_local: self.addr_local, + addr: TempLocal::new(self.addr.idx, self.addr.ty), offset: self.offset + u32::try_from(offset).unwrap(), } } @@ -3000,7 +3059,31 @@ enum MallocSize { } struct WasmString<'a> { - ptr: u32, - len: u32, + ptr: TempLocal, + len: TempLocal, opts: &'a Options, } + +struct TempLocal { + idx: u32, + ty: ValType, + needs_free: bool, +} + +impl TempLocal { + fn new(idx: u32, ty: ValType) -> TempLocal { + TempLocal { + idx, + ty, + needs_free: false, + } + } +} + +impl std::ops::Drop for TempLocal { + fn drop(&mut self) { + if self.needs_free { + panic!("temporary local not free'd"); + } + } +}