Emit runtime type checks in legalizer.rs (#112)

* Emit runtime type checks in legalizer.rs
This commit is contained in:
d1m0
2017-07-10 15:28:32 -07:00
committed by Jakob Stoklund Olesen
parent 464f2625d4
commit 98f822f347
9 changed files with 494 additions and 69 deletions

View File

@@ -11,16 +11,116 @@ from __future__ import absolute_import
from srcgen import Formatter
from base import legalize, instructions
from cdsl.ast import Var
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
ConstrainTVInTypeset
from unique_table import UniqueTable
from gen_instr import gen_typesets_table
from cdsl.typevar import TypeVar
try:
from typing import Sequence # noqa
from typing import Sequence, List, Dict # noqa
from cdsl.isa import TargetISA # noqa
from cdsl.ast import Def # noqa
from cdsl.xform import XForm, XFormGroup # noqa
from cdsl.typevar import TypeSet # noqa
from cdsl.ti import TypeConstraint # noqa
except ImportError:
pass
def get_runtime_typechecks(xform):
# type: (XForm) -> List[TypeConstraint]
"""
Given a XForm build a list of runtime type checks neccessary to determine
if it applies. We have 2 types of runtime checks:
1) typevar tv belongs to typeset T - needed for free tvs whose
typeset is constrainted by their use in the dst pattern
2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification
of non-bijective functions
"""
check_l = [] # type: List[TypeConstraint]
# 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
# different inferred type in src, compared to the type inferred for both
# src and dst.
symtab = {} # type: Dict[Var, Var]
src_copy = xform.src.copy(symtab)
src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))
for v in xform.ti.vars:
if not v.has_free_typevar():
continue
# In rust the local variable containing a free TV associated with var v
# has name typeof_v. We rely on the python TVs having the same name.
assert "typeof_{}".format(v) == xform.ti[v].name
if v not in symtab:
# We can have singleton vars defined only on dst. Ignore them
assert v.get_typevar().singleton_type() is not None
continue
src_ts = src_typenv[symtab[v]].get_typeset()
xform_ts = xform.ti[v].get_typeset()
assert xform_ts.issubset(src_ts)
if src_ts != xform_ts:
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
# 2,3) Add any constraints that appear in xform.ti
check_l.extend(xform.ti.constraints)
return check_l
def emit_runtime_typecheck(check, fmt, type_sets):
# type: (TypeConstraint, Formatter, UniqueTable) -> None
"""
Emit rust code for the given check.
"""
def build_derived_expr(tv):
# type: (TypeVar) -> str
if not tv.is_derived:
assert tv.name.startswith('typeof_')
return "Some({})".format(tv.name)
base_exp = build_derived_expr(tv.base)
if (tv.derived_func == TypeVar.LANEOF):
return "{}.map(|t: Type| -> t.lane_type())".format(base_exp)
elif (tv.derived_func == TypeVar.ASBOOL):
return "{}.map(|t: Type| -> t.as_bool())".format(base_exp)
elif (tv.derived_func == TypeVar.HALFWIDTH):
return "{}.and_then(|t: Type| -> t.half_width())".format(base_exp)
elif (tv.derived_func == TypeVar.DOUBLEWIDTH):
return "{}.and_then(|t: Type| -> t.double_width())"\
.format(base_exp)
elif (tv.derived_func == TypeVar.HALFVECTOR):
return "{}.and_then(|t: Type| -> t.half_vector())".format(base_exp)
elif (tv.derived_func == TypeVar.DOUBLEVECTOR):
return "{}.and_then(|t: Type| -> t.by(2))".format(base_exp)
else:
assert False, "Unknown derived function {}".format(tv.derived_func)
if (isinstance(check, ConstrainTVInTypeset)):
tv = check.tv.name
if check.ts not in type_sets.index:
type_sets.add(check.ts)
ts = type_sets.index[check.ts]
fmt.comment("{} must belong to {}".format(tv, check.ts))
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
'};'):
fmt.line('return false;')
elif (isinstance(check, ConstrainTVsEqual)):
tv1 = build_derived_expr(check.tv1)
tv2 = build_derived_expr(check.tv2)
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
fmt.line('return false;')
else:
assert False, "Unknown check {}".format(check)
def unwrap_inst(iref, node, fmt):
# type: (str, Def, Formatter) -> bool
"""
@@ -183,8 +283,8 @@ def emit_dst_inst(node, fmt):
fmt.line('pos.next_inst();')
def gen_xform(xform, fmt):
# type: (XForm, Formatter) -> None
def gen_xform(xform, fmt, type_sets):
# type: (XForm, Formatter, UniqueTable) -> None
"""
Emit code for `xform`, assuming the the opcode of xform's root instruction
has already been matched.
@@ -203,6 +303,10 @@ def gen_xform(xform, fmt):
instp = xform.src.rtl[0].expr.inst_predicate()
assert instp is None, "Instruction predicates not supported in legalizer"
# Emit any runtime checks.
for check in get_runtime_typechecks(xform):
emit_runtime_typecheck(check, fmt, type_sets)
# Emit the destination pattern.
for dst in xform.dst.rtl:
emit_dst_inst(dst, fmt)
@@ -213,8 +317,8 @@ def gen_xform(xform, fmt):
fmt.line('assert_eq!(pos.remove_inst(), inst);')
def gen_xform_group(xgrp, fmt):
# type: (XFormGroup, Formatter) -> None
def gen_xform_group(xgrp, fmt, type_sets):
# type: (XFormGroup, Formatter, UniqueTable) -> None
fmt.doc_comment("Legalize the instruction pointed to by `pos`.")
fmt.line('#[allow(unused_variables,unused_assignments)]')
with fmt.indented(
@@ -231,7 +335,7 @@ def gen_xform_group(xgrp, fmt):
inst = xform.src.rtl[0].expr.inst
with fmt.indented(
'Opcode::{} => {{'.format(inst.camel_name), '}'):
gen_xform(xform, fmt)
gen_xform(xform, fmt, type_sets)
# We'll assume there are uncovered opcodes.
fmt.line('_ => return false,')
fmt.line('true')
@@ -240,6 +344,11 @@ def gen_xform_group(xgrp, fmt):
def generate(isas, out_dir):
# type: (Sequence[TargetISA], str) -> None
fmt = Formatter()
gen_xform_group(legalize.narrow, fmt)
gen_xform_group(legalize.expand, fmt)
# Table of TypeSet instances
type_sets = UniqueTable()
gen_xform_group(legalize.narrow, fmt, type_sets)
gen_xform_group(legalize.expand, fmt, type_sets)
gen_typesets_table(fmt, type_sets)
fmt.update_file('legalizer.rs', out_dir)