Make calling functions safe

This commit is contained in:
Jef
2019-01-14 18:45:14 +01:00
parent 1eebc65c9e
commit 8312730377
5 changed files with 201 additions and 95 deletions

View File

@@ -22,7 +22,7 @@ fn read_to_end<P: AsRef<Path>>(path: P) -> io::Result<Vec<u8>> {
fn maybe_main() -> Result<(), String> { fn maybe_main() -> Result<(), String> {
let data = read_to_end("test.wasm").map_err(|e| e.to_string())?; let data = read_to_end("test.wasm").map_err(|e| e.to_string())?;
let translated = translate(&data).map_err(|e| e.to_string())?; let translated = translate(&data).map_err(|e| e.to_string())?;
let result: u32 = unsafe { translated.execute_func(0, (5u32, 3u32)) }; let result: u32 = translated.execute_func(0, (5u32, 3u32)).unwrap();
println!("f(5, 3) = {}", result); println!("f(5, 3) = {}", result);
Ok(()) Ok(())

View File

@@ -1,6 +1,6 @@
use backend::*; use backend::*;
use error::Error; use error::Error;
use module::TranslationContext; use module::FuncTyStore;
use wasmparser::{FunctionBody, Operator, Type}; use wasmparser::{FunctionBody, Operator, Type};
// TODO: Use own declared `Type` enum. // TODO: Use own declared `Type` enum.
@@ -92,7 +92,7 @@ impl ControlFrame {
pub fn translate<T: Memory>( pub fn translate<T: Memory>(
session: &mut CodeGenSession<T>, session: &mut CodeGenSession<T>,
translation_ctx: &TranslationContext, translation_ctx: &FuncTyStore,
func_idx: u32, func_idx: u32,
body: &FunctionBody, body: &FunctionBody,
) -> Result<(), Error> ) -> Result<(), Error>

View File

@@ -1,8 +1,55 @@
use backend::TranslatedCodeSection; use backend::TranslatedCodeSection;
use error::Error; use error::Error;
use std::borrow::Cow;
use std::mem; use std::mem;
use translate_sections; use translate_sections;
use wasmparser::{FuncType, ModuleReader, SectionCode}; use wasmparser::{FuncType, ModuleReader, SectionCode, Type};
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Signature {
params: Cow<'static, [Type]>,
returns: Cow<'static, [Type]>,
}
impl PartialEq<FuncType> for Signature {
fn eq(&self, other: &FuncType) -> bool {
&self.params[..] == &other.params[..] && &self.returns[..] == &other.returns[..]
}
}
pub trait AsValueType {
const TYPE: Type;
}
pub trait TypeList {
const TYPE_LIST: &'static [Type];
}
impl<T> TypeList for T
where
T: AsValueType,
{
const TYPE_LIST: &'static [Type] = &[T::TYPE];
}
impl AsValueType for i32 {
const TYPE: Type = Type::I32;
}
impl AsValueType for i64 {
const TYPE: Type = Type::I64;
}
impl AsValueType for u32 {
const TYPE: Type = Type::I32;
}
impl AsValueType for u64 {
const TYPE: Type = Type::I64;
}
impl AsValueType for f32 {
const TYPE: Type = Type::F32;
}
impl AsValueType for f64 {
const TYPE: Type = Type::F64;
}
pub trait FunctionArgs { pub trait FunctionArgs {
unsafe fn call<T>(self, start: *const u8) -> T; unsafe fn call<T>(self, start: *const u8) -> T;
@@ -21,6 +68,10 @@ macro_rules! impl_function_args {
} }
} }
impl<$first: AsValueType, $($rest: AsValueType),*> TypeList for ($first, $($rest),*) {
const TYPE_LIST: &'static [Type] = &[$first::TYPE, $($rest::TYPE),*];
}
impl_function_args!($($rest),*); impl_function_args!($($rest),*);
}; };
() => { () => {
@@ -30,28 +81,57 @@ macro_rules! impl_function_args {
func() func()
} }
} }
impl TypeList for () {
const TYPE_LIST: &'static [Type] = &[];
}
}; };
} }
impl_function_args!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S); impl_function_args!(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S);
#[derive(Default)] #[derive(Default, Debug)]
pub struct TranslatedModule { pub struct TranslatedModule {
translated_code_section: Option<TranslatedCodeSection>, translated_code_section: Option<TranslatedCodeSection>,
types: FuncTyStore,
// Note: This vector should never be deallocated or reallocated or the pointer
// to its contents otherwise invalidated while the JIT'd code is still
// callable.
memory: Option<Vec<u8>>, memory: Option<Vec<u8>>,
} }
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ExecutionError {
FuncIndexOutOfBounds,
TypeMismatch,
}
impl TranslatedModule { impl TranslatedModule {
// For testing only. // For testing only.
// TODO: Handle generic signatures. // TODO: Handle generic signatures.
pub unsafe fn execute_func<Args: FunctionArgs, T>(&self, func_idx: u32, args: Args) -> T { pub fn execute_func<Args: FunctionArgs + TypeList, T: TypeList>(
&self,
func_idx: u32,
args: Args,
) -> Result<T, ExecutionError> {
let code_section = self let code_section = self
.translated_code_section .translated_code_section
.as_ref() .as_ref()
.expect("no code section"); .expect("no code section");
if func_idx as usize >= self.types.func_ty_indicies.len() {
return Err(ExecutionError::FuncIndexOutOfBounds);
}
let type_ = self.types.func_type(func_idx);
if (&type_.params[..], &type_.returns[..]) != (Args::TYPE_LIST, T::TYPE_LIST) {
return Err(ExecutionError::TypeMismatch);
}
let start_buf = code_section.func_start(func_idx as usize); let start_buf = code_section.func_start(func_idx as usize);
args.call(start_buf) Ok(unsafe { args.call(start_buf) })
} }
pub fn disassemble(&self) { pub fn disassemble(&self) {
@@ -62,13 +142,13 @@ impl TranslatedModule {
} }
} }
#[derive(Default)] #[derive(Default, Debug)]
pub struct TranslationContext { pub struct FuncTyStore {
types: Vec<FuncType>, types: Vec<FuncType>,
func_ty_indicies: Vec<u32>, func_ty_indicies: Vec<u32>,
} }
impl TranslationContext { impl FuncTyStore {
pub fn func_type(&self, func_idx: u32) -> &FuncType { pub fn func_type(&self, func_idx: u32) -> &FuncType {
// TODO: This assumes that there is no imported functions. // TODO: This assumes that there is no imported functions.
let func_ty_idx = self.func_ty_indicies[func_idx as usize]; let func_ty_idx = self.func_ty_indicies[func_idx as usize];
@@ -89,11 +169,9 @@ pub fn translate(data: &[u8]) -> Result<TranslatedModule, Error> {
} }
let mut section = reader.read()?; let mut section = reader.read()?;
let mut ctx = TranslationContext::default();
if let SectionCode::Type = section.code { if let SectionCode::Type = section.code {
let types_reader = section.get_type_section_reader()?; let types_reader = section.get_type_section_reader()?;
ctx.types = translate_sections::type_(types_reader)?; output.types.types = translate_sections::type_(types_reader)?;
reader.skip_custom_sections()?; reader.skip_custom_sections()?;
if reader.eof() { if reader.eof() {
@@ -115,7 +193,7 @@ pub fn translate(data: &[u8]) -> Result<TranslatedModule, Error> {
if let SectionCode::Function = section.code { if let SectionCode::Function = section.code {
let functions = section.get_function_section_reader()?; let functions = section.get_function_section_reader()?;
ctx.func_ty_indicies = translate_sections::function(functions)?; output.types.func_ty_indicies = translate_sections::function(functions)?;
reader.skip_custom_sections()?; reader.skip_custom_sections()?;
if reader.eof() { if reader.eof() {
@@ -205,7 +283,7 @@ pub fn translate(data: &[u8]) -> Result<TranslatedModule, Error> {
let code = section.get_code_section_reader()?; let code = section.get_code_section_reader()?;
output.translated_code_section = Some(translate_sections::code( output.translated_code_section = Some(translate_sections::code(
code, code,
&ctx, &output.types,
output.memory.as_mut().map(|m| &mut m[..]), output.memory.as_mut().map(|m| &mut m[..]),
)?); )?);
@@ -221,5 +299,7 @@ pub fn translate(data: &[u8]) -> Result<TranslatedModule, Error> {
translate_sections::data(data)?; translate_sections::data(data)?;
} }
assert!(reader.eof());
Ok(output) Ok(output)
} }

View File

@@ -1,4 +1,4 @@
use super::{translate, TranslatedModule}; use super::{module::ExecutionError, translate, TranslatedModule};
use wabt; use wabt;
fn translate_wat(wat: &str) -> TranslatedModule { fn translate_wat(wat: &str) -> TranslatedModule {
@@ -10,7 +10,7 @@ fn translate_wat(wat: &str) -> TranslatedModule {
/// Execute the first function in the module. /// Execute the first function in the module.
fn execute_wat(wat: &str, a: u32, b: u32) -> u32 { fn execute_wat(wat: &str, a: u32, b: u32) -> u32 {
let translated = translate_wat(wat); let translated = translate_wat(wat);
unsafe { translated.execute_func(0, (a, b)) } translated.execute_func(0, (a, b)).unwrap()
} }
#[test] #[test]
@@ -30,15 +30,18 @@ mod op32 {
const OP: &str = stringify!($op); const OP: &str = stringify!($op);
lazy_static! { lazy_static! {
static ref AS_PARAMS: TranslatedModule = translate_wat(&format!(" static ref AS_PARAMS: TranslatedModule = translate_wat(&format!(
"
(module (func (param i32) (param i32) (result i32) (module (func (param i32) (param i32) (result i32)
(i32.{op} (get_local 0) (get_local 1)))) (i32.{op} (get_local 0) (get_local 1))))
", op = OP)); ",
op = OP
));
} }
quickcheck! { quickcheck! {
fn as_params(a: i32, b: i32) -> bool { fn as_params(a: i32, b: i32) -> bool {
unsafe { AS_PARAMS.execute_func::<(i32, i32), i32>(0, (a, b)) == $func(a, b) } AS_PARAMS.execute_func::<(i32, i32), i32>(0, (a, b)) == Ok($func(a, b))
} }
fn lit_lit(a: i32, b: i32) -> bool { fn lit_lit(a: i32, b: i32) -> bool {
@@ -48,9 +51,8 @@ mod op32 {
", op = OP, left = a, right = b)); ", op = OP, left = a, right = b));
static ONCE: Once = Once::new(); static ONCE: Once = Once::new();
ONCE.call_once(|| translated.disassemble()); ONCE.call_once(|| translated.disassemble());
unsafe {
translated.execute_func::<(), i32>(0, ()) == $func(a, b) translated.execute_func::<(), i32>(0, ()) == Ok($func(a, b))
}
} }
fn lit_reg(a: i32, b: i32) -> bool { fn lit_reg(a: i32, b: i32) -> bool {
@@ -60,9 +62,8 @@ mod op32 {
", op = OP, left = a)); ", op = OP, left = a));
static ONCE: Once = Once::new(); static ONCE: Once = Once::new();
ONCE.call_once(|| translated.disassemble()); ONCE.call_once(|| translated.disassemble());
unsafe {
translated.execute_func::<(i32,), i32>(0, (b,)) == $func(a, b) translated.execute_func::<(i32,), i32>(0, (b,)) == Ok($func(a, b))
}
} }
fn reg_lit(a: i32, b: i32) -> bool { fn reg_lit(a: i32, b: i32) -> bool {
@@ -72,9 +73,8 @@ mod op32 {
", op = OP, right = b)); ", op = OP, right = b));
static ONCE: Once = Once::new(); static ONCE: Once = Once::new();
ONCE.call_once(|| translated.disassemble()); ONCE.call_once(|| translated.disassemble());
unsafe {
translated.execute_func::<(i32,), i32>(0, (a,)) == $func(a, b) translated.execute_func::<(i32,), i32>(0, (a,)) == Ok($func(a, b))
}
} }
} }
} }
@@ -122,16 +122,14 @@ mod op64 {
quickcheck! { quickcheck! {
fn as_params(a: i64, b: i64) -> bool { fn as_params(a: i64, b: i64) -> bool {
unsafe { AS_PARAMS.execute_func::<(i64, i64), $retty>(0, (a, b)) == ($func(a, b) as $retty) } AS_PARAMS.execute_func::<(i64, i64), $retty>(0, (a, b)) == Ok($func(a, b) as $retty)
} }
fn lit_lit(a: i64, b: i64) -> bool { fn lit_lit(a: i64, b: i64) -> bool {
unsafe {
translate_wat(&format!(" translate_wat(&format!("
(module (func (result {retty}) (module (func (result {retty})
(i64.{op} (i64.const {left}) (i64.const {right})))) (i64.{op} (i64.const {left}) (i64.const {right}))))
", retty = RETTY, op = OP, left = a, right = b)).execute_func::<(), $retty>(0, ()) == ($func(a, b) as $retty) ", retty = RETTY, op = OP, left = a, right = b)).execute_func::<(), $retty>(0, ()) == Ok($func(a, b) as $retty)
}
} }
fn lit_reg(a: i64, b: i64) -> bool { fn lit_reg(a: i64, b: i64) -> bool {
@@ -143,18 +141,15 @@ mod op64 {
", retty = RETTY, op = OP, left = a)); ", retty = RETTY, op = OP, left = a));
static ONCE: Once = Once::new(); static ONCE: Once = Once::new();
ONCE.call_once(|| translated.disassemble()); ONCE.call_once(|| translated.disassemble());
unsafe {
translated.execute_func::<(i64,), $retty>(0, (b,)) == ($func(a, b) as $retty) translated.execute_func::<(i64,), $retty>(0, (b,)) == Ok($func(a, b) as $retty)
}
} }
fn reg_lit(a: i64, b: i64) -> bool { fn reg_lit(a: i64, b: i64) -> bool {
unsafe {
translate_wat(&format!(" translate_wat(&format!("
(module (func (param i64) (result {retty}) (module (func (param i64) (result {retty})
(i64.{op} (get_local 0) (i64.const {right})))) (i64.{op} (get_local 0) (i64.const {right}))))
", retty = RETTY, op = OP, right = b)).execute_func::<(i64,), $retty>(0, (a,)) == ($func(a, b) as $retty) ", retty = RETTY, op = OP, right = b)).execute_func::<(i64,), $retty>(0, (a,)) == Ok($func(a, b) as $retty)
}
} }
} }
} }
@@ -207,7 +202,7 @@ quickcheck! {
static ref TRANSLATED: TranslatedModule = translate_wat(CODE); static ref TRANSLATED: TranslatedModule = translate_wat(CODE);
} }
let out = unsafe { TRANSLATED.execute_func::<(u32, u32), u32>(0, (a, b)) }; let out = TRANSLATED.execute_func::<(u32, u32), u32>(0, (a, b)).unwrap();
(a == b) == (out == 1) (a == b) == (out == 1)
} }
@@ -234,9 +229,9 @@ quickcheck! {
static ref TRANSLATED: TranslatedModule = translate_wat(CODE); static ref TRANSLATED: TranslatedModule = translate_wat(CODE);
} }
let out = unsafe { TRANSLATED.execute_func::<(u32, u32), u32>(0, (a, b)) }; let out = TRANSLATED.execute_func::<(u32, u32), u32>(0, (a, b));
out == (if a == b { a } else { b }) out == Ok(if a == b { a } else { b })
} }
} }
#[test] #[test]
@@ -310,10 +305,10 @@ fn large_function() {
{ {
let translated = translate_wat(code); let translated = translate_wat(code);
translated.disassemble(); translated.disassemble();
let out: u32 = unsafe { translated.execute_func(0, (5, 4, 3, 2, 1, 0)) }; let out: Result<u32, _> = translated.execute_func(0, (5, 4, 3, 2, 1, 0));
out out
}, },
5 Ok(5)
); );
} }
@@ -344,12 +339,9 @@ fn function_read_args_spill_to_stack() {
{ {
let translated = translate_wat(code); let translated = translate_wat(code);
translated.disassemble(); translated.disassemble();
let out: u32 = unsafe {
translated.execute_func(0, (7u32, 6u32, 5u32, 4u32, 3u32, 2u32, 1u32, 0u32)) translated.execute_func(0, (7u32, 6u32, 5u32, 4u32, 3u32, 2u32, 1u32, 0u32))
};
out
}, },
7 Ok(7u32)
); );
} }
@@ -408,18 +400,16 @@ macro_rules! mk_function_write_args_spill_to_stack {
{ {
let translated = translate_wat(&code); let translated = translate_wat(&code);
translated.disassemble(); translated.disassemble();
let out: $typ = unsafe { let out: Result<$typ, _> = translated.execute_func(
translated.execute_func(
0, 0,
( (
11 as $typ, 10 as $typ, 9 as $typ, 8 as $typ, 7 as $typ, 6 as $typ, 11 as $typ, 10 as $typ, 9 as $typ, 8 as $typ, 7 as $typ, 6 as $typ,
5 as $typ, 4 as $typ, 3 as $typ, 2 as $typ, 1 as $typ, 0 as $typ, 5 as $typ, 4 as $typ, 3 as $typ, 2 as $typ, 1 as $typ, 0 as $typ,
), ),
) );
};
out out
}, },
11 Ok(11)
); );
} }
}; };
@@ -461,7 +451,10 @@ fn br_block() {
let translated = translate_wat(code); let translated = translate_wat(code);
translated.disassemble(); translated.disassemble();
assert_eq!(unsafe { translated.execute_func::<(i32, i32), i32>(0, (5, 7)) }, 12); assert_eq!(
translated.execute_func::<(i32, i32), i32>(0, (5, 7)),
Ok(12)
);
} }
// Tests discarding values on the value stack, while // Tests discarding values on the value stack, while
@@ -543,7 +536,7 @@ fn spec_loop() {
let translated = translate_wat(code); let translated = translate_wat(code);
translated.disassemble(); translated.disassemble();
unsafe { translated.execute_func::<(), ()>(0, ()) } translated.execute_func::<(), ()>(0, ()).unwrap();
} }
quickcheck! { quickcheck! {
@@ -620,9 +613,8 @@ quickcheck! {
} }
let n = n as i32; let n = n as i32;
unsafe {
TRANSLATED.execute_func::<(i32,), i32>(0, (n,)) == fac(n) TRANSLATED.execute_func::<(i32,), i32>(0, (n,)) == Ok(fac(n))
}
} }
} }
@@ -659,6 +651,44 @@ fn literals() {
assert_eq!(execute_wat(code, 0, 0), 228); assert_eq!(execute_wat(code, 0, 0), 228);
} }
#[test]
fn wrong_type() {
let code = r#"
(module
(func (param i32) (param i64) (result i32)
(i32.const 228)
)
)
"#;
let translated = translate_wat(code);
assert_eq!(
translated
.execute_func::<_, ()>(0, (0u32, 0u32))
.unwrap_err(),
ExecutionError::TypeMismatch
);
}
#[test]
fn wrong_index() {
let code = r#"
(module
(func (param i32) (param i64) (result i32)
(i32.const 228)
)
)
"#;
let translated = translate_wat(code);
assert_eq!(
translated
.execute_func::<_, ()>(10, (0u32, 0u32))
.unwrap_err(),
ExecutionError::FuncIndexOutOfBounds
);
}
const FIBONACCI: &str = r#" const FIBONACCI: &str = r#"
(module (module
(func $fib (param $n i32) (result i32) (func $fib (param $n i32) (result i32)
@@ -722,15 +752,13 @@ fn fib() {
translated.disassemble(); translated.disassemble();
for x in 0..30 { for x in 0..30 {
unsafe {
assert_eq!( assert_eq!(
translated.execute_func::<_, u32>(0, (x,)), translated.execute_func::<_, u32>(0, (x,)),
fib(x), Ok(fib(x)),
"Failed for x={}", "Failed for x={}",
x x
); );
} }
}
} }
#[test] #[test]
@@ -766,7 +794,7 @@ fn storage() {
let translated = translate_wat(CODE); let translated = translate_wat(CODE);
translated.disassemble(); translated.disassemble();
assert_eq!(unsafe { translated.execute_func::<(), i32>(0, ()) }, 1); assert_eq!(translated.execute_func::<(), i32>(0, ()), Ok(1));
} }
#[bench] #[bench]
@@ -781,7 +809,7 @@ fn bench_fibonacci_run(b: &mut test::Bencher) {
let wasm = wabt::wat2wasm(FIBONACCI).unwrap(); let wasm = wabt::wat2wasm(FIBONACCI).unwrap();
let module = translate(&wasm).unwrap(); let module = translate(&wasm).unwrap();
b.iter(|| unsafe { module.execute_func::<_, u32>(0, (20,)) }); b.iter(|| module.execute_func::<_, u32>(0, (20,)));
} }
#[bench] #[bench]

View File

@@ -1,7 +1,7 @@
use backend::{CodeGenSession, TranslatedCodeSection}; use backend::{CodeGenSession, TranslatedCodeSection};
use error::Error; use error::Error;
use function_body; use function_body;
use module::TranslationContext; use module::FuncTyStore;
#[allow(unused_imports)] // for now #[allow(unused_imports)] // for now
use wasmparser::{ use wasmparser::{
CodeSectionReader, Data, DataSectionReader, Element, ElementSectionReader, Export, CodeSectionReader, Data, DataSectionReader, Element, ElementSectionReader, Export,
@@ -12,11 +12,10 @@ use wasmparser::{
/// Parses the Type section of the wasm module. /// Parses the Type section of the wasm module.
pub fn type_(types_reader: TypeSectionReader) -> Result<Vec<FuncType>, Error> { pub fn type_(types_reader: TypeSectionReader) -> Result<Vec<FuncType>, Error> {
let mut types = vec![]; types_reader
for entry in types_reader { .into_iter()
types.push(entry?); .map(|r| r.map_err(Into::into))
} .collect()
Ok(types)
} }
/// Parses the Import section of the wasm module. /// Parses the Import section of the wasm module.
@@ -29,11 +28,10 @@ pub fn import(imports: ImportSectionReader) -> Result<(), Error> {
/// Parses the Function section of the wasm module. /// Parses the Function section of the wasm module.
pub fn function(functions: FunctionSectionReader) -> Result<Vec<u32>, Error> { pub fn function(functions: FunctionSectionReader) -> Result<Vec<u32>, Error> {
let mut func_ty_indicies = vec![]; functions
for entry in functions { .into_iter()
func_ty_indicies.push(entry?); .map(|r| r.map_err(Into::into))
} .collect()
Ok(func_ty_indicies)
} }
/// Parses the Table section of the wasm module. /// Parses the Table section of the wasm module.
@@ -85,7 +83,7 @@ pub fn element(elements: ElementSectionReader) -> Result<(), Error> {
/// Parses the Code section of the wasm module. /// Parses the Code section of the wasm module.
pub fn code( pub fn code(
code: CodeSectionReader, code: CodeSectionReader,
translation_ctx: &TranslationContext, translation_ctx: &FuncTyStore,
memory: Option<&mut [u8]>, memory: Option<&mut [u8]>,
) -> Result<TranslatedCodeSection, Error> { ) -> Result<TranslatedCodeSection, Error> {
let func_count = code.get_count(); let func_count = code.get_count();