From 0d1c4f32903192282da7c418f50b8076964a41fb Mon Sep 17 00:00:00 2001 From: Jef Date: Wed, 12 Dec 2018 11:52:48 +0100 Subject: [PATCH] Allow calling functions with any signature --- examples/test.rs | 2 +- src/module.rs | 39 +++++++++++++++++++++++++++++++++------ src/tests.rs | 47 ++++++++++++++++++++++++++++++++++++++++------- 3 files changed, 74 insertions(+), 14 deletions(-) diff --git a/examples/test.rs b/examples/test.rs index 3295df8ae6..163dda0142 100644 --- a/examples/test.rs +++ b/examples/test.rs @@ -22,7 +22,7 @@ fn read_to_end>(path: P) -> io::Result> { fn maybe_main() -> Result<(), String> { let data = read_to_end("test.wasm").map_err(|e| e.to_string())?; let translated = translate(&data).map_err(|e| e.to_string())?; - let result = translated.execute_func(0, 5, 3); + let result: u32 = unsafe { translated.execute_func(0, (5u32, 3u32)) }; println!("f(5, 3) = {}", result); Ok(()) diff --git a/src/module.rs b/src/module.rs index 52f0d61009..a38aac0940 100644 --- a/src/module.rs +++ b/src/module.rs @@ -4,6 +4,37 @@ use std::mem; use translate_sections; use wasmparser::{FuncType, ModuleReader, SectionCode}; +pub trait FunctionArgs { + unsafe fn call(self, start: *const u8) -> T; +} + +macro_rules! impl_function_args { + ($first:ident $(, $rest:ident)*) => { + impl<$first, $($rest),*> FunctionArgs for ($first, $($rest),*) { + #[allow(non_snake_case)] + unsafe fn call(self, start: *const u8) -> T { + let func = mem::transmute::<_, extern "sysv64" fn($first, $($rest),*) -> T>(start); + { + let ($first, $($rest),*) = self; + func($first, $($rest),*) + } + } + } + + impl_function_args!($($rest),*); + }; + () => { + impl FunctionArgs for () { + unsafe fn call(self, start: *const u8) -> T { + let func = mem::transmute::<_, extern "sysv64" fn() -> T>(start); + func() + } + } + }; +} + +impl_function_args!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S); + #[derive(Default)] pub struct TranslatedModule { translated_code_section: Option, @@ -11,19 +42,15 @@ pub struct TranslatedModule { impl TranslatedModule { // For testing only. - // Assume signature is (i32, i32) -> i32 for now. // TODO: Handle generic signatures. - pub fn execute_func(&self, func_idx: u32, a: usize, b: usize) -> usize { + pub unsafe fn execute_func(&self, func_idx: u32, args: Args) -> T { let code_section = self .translated_code_section .as_ref() .expect("no code section"); let start_buf = code_section.func_start(func_idx as usize); - unsafe { - let func = mem::transmute::<_, extern "sysv64" fn(usize, usize) -> usize>(start_buf); - func(a, b) - } + args.call(start_buf) } } diff --git a/src/tests.rs b/src/tests.rs index add6ee96b7..0e434c7233 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -8,9 +8,9 @@ fn translate_wat(wat: &str) -> TranslatedModule { } /// Execute the first function in the module. -fn execute_wat(wat: &str, a: usize, b: usize) -> usize { +fn execute_wat(wat: &str, a: u32, b: u32) -> u32 { let translated = translate_wat(wat); - translated.execute_func(0, a, b) + unsafe { translated.execute_func(0, (a, b)) } } #[test] @@ -20,7 +20,7 @@ fn empty() { #[test] fn adds() { - const CASES: &[(usize, usize, usize)] = &[(5, 3, 8), (0, 228, 228), (usize::max_value(), 1, 0)]; + const CASES: &[(u32, u32, u32)] = &[(5, 3, 8), (0, 228, 228), (u32::max_value(), 1, 0)]; let code = r#" (module @@ -34,7 +34,7 @@ fn adds() { #[test] fn relop_eq() { - const CASES: &[(usize, usize, usize)] = &[ + const CASES: &[(u32, u32, u32)] = &[ (0, 0, 1), (0, 1, 0), (1, 0, 0), @@ -56,7 +56,7 @@ fn relop_eq() { #[test] fn if_then_else() { - const CASES: &[(usize, usize, usize)] = &[ + const CASES: &[(u32, u32, u32)] = &[ (0, 1, 1), (0, 0, 0), (1, 0, 0), @@ -129,6 +129,39 @@ fn function_call() { assert_eq!(execute_wat(code, 2, 0), 2); } +#[test] +fn large_function_call() { + let code = r#" +(module + (func (param i32) (param i32) (param i32) (param i32) + (param i32) (param i32) + (result i32) + + (call $assert_zero + (get_local 5) + ) + (get_local 0) + ) + + (func $assert_zero (param $v i32) + (local i32) + (if (get_local $v) + (unreachable) + ) + ) +) + "#; + + assert_eq!( + { + let translated = translate_wat(code); + let out: u32 = unsafe { translated.execute_func(0, (5, 4, 3, 2, 1, 0)) }; + out + }, + 5 + ); +} + #[test] fn literals() { let code = r#" @@ -192,10 +225,10 @@ fn fib() { "#; // fac(x) = y <=> (x, y) - const FIB_SEQ: &[usize] = &[1, 1, 2, 3, 5, 8, 13, 21, 34, 55]; + const FIB_SEQ: &[u32] = &[1, 1, 2, 3, 5, 8, 13, 21, 34, 55]; for x in 0..10 { - assert_eq!(execute_wat(code, x, 0), FIB_SEQ[x]); + assert_eq!(execute_wat(code, x, 0), FIB_SEQ[x as usize]); } }