Allow for multiple legalization patterns for the same opcode.
Each input pattern can have a predicate in addition to an opcode being matched. When an opcode has multiple patterns, execute the first pattern with a true predicate. The predicates can be type checks or instruction predicates checking immediate fields.
This commit is contained in:
@@ -9,6 +9,7 @@ the input instruction.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from srcgen import Formatter
|
||||
from collections import defaultdict
|
||||
from base import instructions
|
||||
from cdsl.ast import Var
|
||||
from cdsl.ti import ti_rtl, TypeEnv, get_type_env, TypesEqual,\
|
||||
@@ -18,7 +19,7 @@ from gen_instr import gen_typesets_table
|
||||
from cdsl.typevar import TypeVar
|
||||
|
||||
try:
|
||||
from typing import Sequence, List, Dict, Set # noqa
|
||||
from typing import Sequence, List, Dict, Set, DefaultDict # noqa
|
||||
from cdsl.isa import TargetISA # noqa
|
||||
from cdsl.ast import Def # noqa
|
||||
from cdsl.xform import XForm, XFormGroup # noqa
|
||||
@@ -78,6 +79,11 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
# type: (TypeConstraint, Formatter, UniqueTable) -> None
|
||||
"""
|
||||
Emit rust code for the given check.
|
||||
|
||||
The emitted code is a statement redefining the `predicate` variable like
|
||||
this:
|
||||
|
||||
let predicate = predicate && ...
|
||||
"""
|
||||
def build_derived_expr(tv):
|
||||
# type: (TypeVar) -> str
|
||||
@@ -116,33 +122,26 @@ def emit_runtime_typecheck(check, fmt, type_sets):
|
||||
if check.ts not in type_sets.index:
|
||||
type_sets.add(check.ts)
|
||||
ts = type_sets.index[check.ts]
|
||||
|
||||
fmt.comment("{} must belong to {}".format(tv, check.ts))
|
||||
with fmt.indented('if !TYPE_SETS[{}].contains({}) {{'.format(ts, tv),
|
||||
'};'):
|
||||
fmt.line('return false;')
|
||||
fmt.format(
|
||||
'let predicate = predicate && TYPE_SETS[{}].contains({});',
|
||||
ts, tv)
|
||||
elif (isinstance(check, TypesEqual)):
|
||||
with fmt.indented('{', '};'):
|
||||
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||
|
||||
fmt.comment('On overflow constraint doesn\'t appply')
|
||||
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||
fmt.line('return false;')
|
||||
|
||||
with fmt.indented('if a != b {', '};'):
|
||||
fmt.line('return false;')
|
||||
with fmt.indented(
|
||||
'let predicate = predicate && match ({}, {}) {{'
|
||||
.format(build_derived_expr(check.tv1),
|
||||
build_derived_expr(check.tv2)), '};'):
|
||||
fmt.line('(Some(a), Some(b)) => a == b,')
|
||||
fmt.comment('On overflow, constraint doesn\'t appply')
|
||||
fmt.line('_ => false,')
|
||||
elif (isinstance(check, WiderOrEq)):
|
||||
with fmt.indented('{', '};'):
|
||||
fmt.line('let a = {};'.format(build_derived_expr(check.tv1)))
|
||||
fmt.line('let b = {};'.format(build_derived_expr(check.tv2)))
|
||||
|
||||
fmt.comment('On overflow constraint doesn\'t appply')
|
||||
with fmt.indented('if a.is_none() || b.is_none() {', '};'):
|
||||
fmt.line('return false;')
|
||||
|
||||
with fmt.indented('if !a.wider_or_equal(b) {', '};'):
|
||||
fmt.line('return false;')
|
||||
with fmt.indented(
|
||||
'let predicate = predicate && match ({}, {}) {{'
|
||||
.format(build_derived_expr(check.tv1),
|
||||
build_derived_expr(check.tv2)), '};'):
|
||||
fmt.line('(Some(a), Some(b)) => a.wider_or_equal(b),')
|
||||
fmt.comment('On overflow, constraint doesn\'t appply')
|
||||
fmt.line('_ => false,')
|
||||
else:
|
||||
assert False, "Unknown check {}".format(check)
|
||||
|
||||
@@ -216,14 +215,12 @@ def unwrap_inst(iref, node, fmt):
|
||||
replace_inst = True
|
||||
else:
|
||||
# Boring case: Detach the result values, capture them in locals.
|
||||
fmt.comment('Detaching results.')
|
||||
for d in node.defs:
|
||||
fmt.line('let {};'.format(d))
|
||||
with fmt.indented('{', '}'):
|
||||
fmt.line('let r = dfg.inst_results(inst);')
|
||||
for i in range(len(node.defs)):
|
||||
fmt.line('{} = r[{}];'.format(node.defs[i], i))
|
||||
fmt.line('dfg.clear_results(inst);')
|
||||
for d in node.defs:
|
||||
if d.has_free_typevar():
|
||||
fmt.line(
|
||||
@@ -312,7 +309,7 @@ def emit_dst_inst(node, fmt):
|
||||
def gen_xform(xform, fmt, type_sets):
|
||||
# type: (XForm, Formatter, UniqueTable) -> None
|
||||
"""
|
||||
Emit code for `xform`, assuming the the opcode of xform's root instruction
|
||||
Emit code for `xform`, assuming that the opcode of xform's root instruction
|
||||
has already been matched.
|
||||
|
||||
`inst: Inst` is the variable to be replaced. It is pointed to by `pos:
|
||||
@@ -323,24 +320,33 @@ def gen_xform(xform, fmt, type_sets):
|
||||
# variables.
|
||||
replace_inst = unwrap_inst('inst', xform.src.rtl[0], fmt)
|
||||
|
||||
# We could support instruction predicates, but not yet. Should we just
|
||||
# return false if it fails? What about multiple patterns with different
|
||||
# predicates for the same opcode?
|
||||
# Check instruction predicate and emit type checks.
|
||||
instp = xform.src.rtl[0].expr.inst_predicate()
|
||||
assert instp is None, "Instruction predicates not supported in legalizer"
|
||||
# TODO: The instruction predicate should be evaluated with all the inst
|
||||
# immediate fields available. Probably by unwrap_inst().
|
||||
fmt.format('let predicate = {};',
|
||||
instp.rust_predicate(0) if instp else 'true')
|
||||
|
||||
# Emit any runtime checks.
|
||||
for check in get_runtime_typechecks(xform):
|
||||
emit_runtime_typecheck(check, fmt, type_sets)
|
||||
|
||||
# Emit the destination pattern.
|
||||
for dst in xform.dst.rtl:
|
||||
emit_dst_inst(dst, fmt)
|
||||
# Guard the actual expansion by `predicate`.
|
||||
with fmt.indented('if predicate {', '}'):
|
||||
# If we're going to delete `inst`, we need to detach its results first
|
||||
# so they can be reattached during pattern expansion.
|
||||
if not replace_inst:
|
||||
fmt.line('dfg.clear_results(inst);')
|
||||
|
||||
# Delete the original instruction if we didn't have an opportunity to
|
||||
# replace it.
|
||||
if not replace_inst:
|
||||
fmt.line('assert_eq!(pos.remove_inst(), inst);')
|
||||
# Emit the destination pattern.
|
||||
for dst in xform.dst.rtl:
|
||||
emit_dst_inst(dst, fmt)
|
||||
|
||||
# Delete the original instruction if we didn't have an opportunity to
|
||||
# replace it.
|
||||
if not replace_inst:
|
||||
fmt.line('assert_eq!(pos.remove_inst(), inst);')
|
||||
fmt.line('return true;')
|
||||
|
||||
|
||||
def gen_xform_group(xgrp, fmt, type_sets):
|
||||
@@ -358,19 +364,27 @@ def gen_xform_group(xgrp, fmt, type_sets):
|
||||
# pointing at an instruction.
|
||||
fmt.line('let inst = pos.current_inst().expect("need instruction");')
|
||||
|
||||
# Group the xforms by opcode so we can generate a big switch.
|
||||
# Preserve ordering.
|
||||
xforms = defaultdict(list) # type: DefaultDict[str, List[XForm]]
|
||||
for xform in xgrp.xforms:
|
||||
inst = xform.src.rtl[0].expr.inst
|
||||
xforms[inst.camel_name].append(xform)
|
||||
|
||||
with fmt.indented('match dfg[inst].opcode() {', '}'):
|
||||
for xform in xgrp.xforms:
|
||||
inst = xform.src.rtl[0].expr.inst
|
||||
for camel_name in sorted(xforms.keys()):
|
||||
with fmt.indented(
|
||||
'ir::Opcode::{} => {{'.format(inst.camel_name), '}'):
|
||||
gen_xform(xform, fmt, type_sets)
|
||||
'ir::Opcode::{} => {{'.format(camel_name), '}'):
|
||||
for xform in xforms[camel_name]:
|
||||
gen_xform(xform, fmt, type_sets)
|
||||
# We'll assume there are uncovered opcodes.
|
||||
if xgrp.chain:
|
||||
fmt.format('_ => return {}(dfg, cfg, pos),',
|
||||
xgrp.chain.rust_name())
|
||||
else:
|
||||
fmt.line('_ => return false,')
|
||||
fmt.line('true')
|
||||
fmt.line('_ => {},')
|
||||
|
||||
# If we fall through, nothing was expanded. Call the chain if any.
|
||||
if xgrp.chain:
|
||||
fmt.format('{}(dfg, cfg, pos)', xgrp.chain.rust_name())
|
||||
else:
|
||||
fmt.line('false')
|
||||
|
||||
|
||||
def gen_isa(isa, fmt, shared_groups):
|
||||
|
||||
@@ -49,39 +49,27 @@ def typeset_check(v, ts):
|
||||
# type: (Var, TypeSet) -> CheckProducer
|
||||
return lambda typesets: format_check(
|
||||
typesets,
|
||||
'if !TYPE_SETS[{}].contains(typeof_{}) ' +
|
||||
'{{\n return false;\n}};\n', ts, v)
|
||||
'let predicate = predicate && TYPE_SETS[{}].contains(typeof_{});\n',
|
||||
ts, v)
|
||||
|
||||
|
||||
def equiv_check(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||
# type: (str, str) -> CheckProducer
|
||||
return lambda typesets: format_check(
|
||||
typesets,
|
||||
'{{\n' +
|
||||
' let a = {};\n' +
|
||||
' let b = {};\n' +
|
||||
' if a.is_none() || b.is_none() {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
' if a != b {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
'let predicate = predicate && match ({}, {}) {{\n'
|
||||
' (Some(a), Some(b)) => a == b,\n'
|
||||
' _ => false,\n'
|
||||
'}};\n', tv1, tv2)
|
||||
|
||||
|
||||
def wider_check(tv1, tv2):
|
||||
# type: (TypeVar, TypeVar) -> CheckProducer
|
||||
# type: (str, str) -> CheckProducer
|
||||
return lambda typesets: format_check(
|
||||
typesets,
|
||||
'{{\n' +
|
||||
' let a = {};\n' +
|
||||
' let b = {};\n' +
|
||||
' if a.is_none() || b.is_none() {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
' if !a.wider_or_equal(b) {{\n' +
|
||||
' return false;\n' +
|
||||
' }};\n' +
|
||||
'let predicate = predicate && match ({}, {}) {{\n'
|
||||
' (Some(a), Some(b)) => a.wider_or_equal(b),\n'
|
||||
' _ => false,\n'
|
||||
'}};\n', tv1, tv2)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user