Emit runtime type checks in legalizer.rs (#112)
* Emit runtime type checks in legalizer.rs
This commit is contained in:
committed by
Jakob Stoklund Olesen
parent
464f2625d4
commit
98f822f347
@@ -11,16 +11,116 @@ from __future__ import absolute_import
|
||||
from srcgen import Formatter
|
||||
from base import legalize, instructions
|
||||
from cdsl.ast import Var
|
||||
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, ConstrainTVsEqual,\
|
||||
ConstrainTVInTypeset
|
||||
from unique_table import UniqueTable
|
||||
from gen_instr import gen_typesets_table
|
||||
from cdsl.typevar import TypeVar
|
||||
|
||||
try:
|
||||
from typing import Sequence # noqa
|
||||
from typing import Sequence, List, Dict # noqa
|
||||
from cdsl.isa import TargetISA # noqa
|
||||
from cdsl.ast import Def # noqa
|
||||
from cdsl.xform import XForm, XFormGroup # noqa
|
||||
from cdsl.typevar import TypeSet # noqa
|
||||
from cdsl.ti import TypeConstraint # noqa
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def get_runtime_typechecks(xform):
|
||||
# type: (XForm) -> List[TypeConstraint]
|
||||
"""
|
||||
Given a XForm build a list of runtime type checks neccessary to determine
|
||||
if it applies. We have 2 types of runtime checks:
|
||||
1) typevar tv belongs to typeset T - needed for free tvs whose
|
||||
typeset is constrainted by their use in the dst pattern
|
||||
|
||||
2) tv1 == tv2 where tv1 and tv2 are derived TVs - caused by unification
|
||||
of non-bijective functions
|
||||
"""
|
||||
check_l = [] # type: List[TypeConstraint]
|
||||
|
||||
# 1) Perform ti only on the source RTL. Accumulate any free tvs that have a
|
||||
# different inferred type in src, compared to the type inferred for both
|
||||
# src and dst.
|
||||
symtab = {} # type: Dict[Var, Var]
|
||||
src_copy = xform.src.copy(symtab)
|
||||
src_typenv = get_type_env(ti_rtl(src_copy, TypeEnv()))
|
||||
|
||||
for v in xform.ti.vars:
|
||||
if not v.has_free_typevar():
|
||||
continue
|
||||
|
||||
# In rust the local variable containing a free TV associated with var v
|
||||
# has name typeof_v. We rely on the python TVs having the same name.
|
||||
assert "typeof_{}".format(v) == xform.ti[v].name
|
||||
|
||||
if v not in symtab:
|
||||
# We can have singleton vars defined only on dst. Ignore them
|
||||
assert v.get_typevar().singleton_type() is not None
|
||||
continue
|
||||
|
||||
src_ts = src_typenv[symtab[v]].get_typeset()
|
||||
xform_ts = xform.ti[v].get_typeset()
|
||||
|
||||
assert xform_ts.issubset(src_ts)
|
||||
if src_ts != xform_ts:
|
||||
check_l.append(ConstrainTVInTypeset(xform.ti[v], xform_ts))
|
||||
|
||||
# 2,3) Add any constraints that appear in xform.ti
|
||||
check_l.extend(xform.ti.constraints)
|
||||
|
||||
return check_l
|
||||
|
||||
|
||||
def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
# type: (TypeConstraint, Formatter, UniqueTable) -> None
|
||||
"""
|
||||
Emit rust code for the given check.
|
||||
"""
|
||||
def build_derived_expr(tv):
|
||||
# type: (TypeVar) -> str
|
||||
if not tv.is_derived:
|
||||
assert tv.name.startswith('typeof_')
|
||||
return "Some({})".format(tv.name)
|
||||
|
||||
base_exp = build_derived_expr(tv.base)
|
||||
if (tv.derived_func == TypeVar.LANEOF):
|
||||
return "{}.map(|t: Type| -> t.lane_type())".format(base_exp)
|
||||
elif (tv.derived_func == TypeVar.ASBOOL):
|
||||
return "{}.map(|t: Type| -> t.as_bool())".format(base_exp)
|
||||
elif (tv.derived_func == TypeVar.HALFWIDTH):
|
||||
return "{}.and_then(|t: Type| -> t.half_width())".format(base_exp)
|
||||
elif (tv.derived_func == TypeVar.DOUBLEWIDTH):
|
||||
return "{}.and_then(|t: Type| -> t.double_width())"\
|
||||
.format(base_exp)
|
||||
elif (tv.derived_func == TypeVar.HALFVECTOR):
|
||||
return "{}.and_then(|t: Type| -> t.half_vector())".format(base_exp)
|
||||
elif (tv.derived_func == TypeVar.DOUBLEVECTOR):
|
||||
return "{}.and_then(|t: Type| -> t.by(2))".format(base_exp)
|
||||
else:
|
||||
assert False, "Unknown derived function {}".format(tv.derived_func)
|
||||
|
||||
if (isinstance(check, ConstrainTVInTypeset)):
|
||||
tv = check.tv.name
|
||||
if check.ts not in type_sets.index:
|
||||
type_sets.add(check.ts)
|
||||
ts = type_sets.index[check.ts]
|
||||
|
||||
fmt.comment("{} must belong to {}".format(tv, check.ts))
|
||||
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
||||
'};'):
|
||||
fmt.line('return false;')
|
||||
elif (isinstance(check, ConstrainTVsEqual)):
|
||||
tv1 = build_derived_expr(check.tv1)
|
||||
tv2 = build_derived_expr(check.tv2)
|
||||
with fmt.indented('if {} != {} {{'.format(tv1, tv2), '};'):
|
||||
fmt.line('return false;')
|
||||
else:
|
||||
assert False, "Unknown check {}".format(check)
|
||||
|
||||
|
||||
def unwrap_inst(iref, node, fmt):
|
||||
# type: (str, Def, Formatter) -> bool
|
||||
"""
|
||||
@@ -183,8 +283,8 @@ def emit_dst_inst(node, fmt):
|
||||
fmt.line('pos.next_inst();')
|
||||
|
||||
|
||||
def gen_xform(xform, fmt):
|
||||
# type: (XForm, Formatter) -> None
|
||||
def gen_xform(xform, fmt, type_sets):
|
||||
# type: (XForm, Formatter, UniqueTable) -> None
|
||||
"""
|
||||
Emit code for `xform`, assuming the the opcode of xform's root instruction
|
||||
has already been matched.
|
||||
@@ -203,6 +303,10 @@ def gen_xform(xform, fmt):
|
||||
instp = xform.src.rtl[0].expr.inst_predicate()
|
||||
assert instp is None, "Instruction predicates not supported in legalizer"
|
||||
|
||||
# Emit any runtime checks.
|
||||
for check in get_runtime_typechecks(xform):
|
||||
emit_runtime_typecheck(check, fmt, type_sets)
|
||||
|
||||
# Emit the destination pattern.
|
||||
for dst in xform.dst.rtl:
|
||||
emit_dst_inst(dst, fmt)
|
||||
@@ -213,8 +317,8 @@ def gen_xform(xform, fmt):
|
||||
fmt.line('assert_eq!(pos.remove_inst(), inst);')
|
||||
|
||||
|
||||
def gen_xform_group(xgrp, fmt):
|
||||
# type: (XFormGroup, Formatter) -> None
|
||||
def gen_xform_group(xgrp, fmt, type_sets):
|
||||
# type: (XFormGroup, Formatter, UniqueTable) -> None
|
||||
fmt.doc_comment("Legalize the instruction pointed to by `pos`.")
|
||||
fmt.line('#[allow(unused_variables,unused_assignments)]')
|
||||
with fmt.indented(
|
||||
@@ -231,7 +335,7 @@ def gen_xform_group(xgrp, fmt):
|
||||
inst = xform.src.rtl[0].expr.inst
|
||||
with fmt.indented(
|
||||
'Opcode::{} => {{'.format(inst.camel_name), '}'):
|
||||
gen_xform(xform, fmt)
|
||||
gen_xform(xform, fmt, type_sets)
|
||||
# We'll assume there are uncovered opcodes.
|
||||
fmt.line('_ => return false,')
|
||||
fmt.line('true')
|
||||
@@ -240,6 +344,11 @@ def gen_xform_group(xgrp, fmt):
|
||||
def generate(isas, out_dir):
|
||||
# type: (Sequence[TargetISA], str) -> None
|
||||
fmt = Formatter()
|
||||
gen_xform_group(legalize.narrow, fmt)
|
||||
gen_xform_group(legalize.expand, fmt)
|
||||
# Table of TypeSet instances
|
||||
type_sets = UniqueTable()
|
||||
|
||||
gen_xform_group(legalize.narrow, fmt, type_sets)
|
||||
gen_xform_group(legalize.expand, fmt, type_sets)
|
||||
|
||||
gen_typesets_table(fmt, type_sets)
|
||||
fmt.update_file('legalizer.rs', out_dir)
|
||||
|
||||
Reference in New Issue
Block a user