cranelift-isle: Add "partial" flag for constructors (#5392)

* cranelift-isle: Add "partial" flag for constructors

Instead of tying fallibility of constructors to whether they're either
internal or pure, this commit assumes all constructors are infallible
unless tagged otherwise with a "partial" flag.

Internal constructors without the "partial" flag are not allowed to use
constructors which have the "partial" flag on the right-hand side of any
rules, because they have no way to report last-minute match failures.

Multi-constructors should never be "partial"; they report match failures
with an empty iterator instead. In turn this means you can't use partial
constructors on the right-hand side of internal multi-constructor rules.
However, you can use the same constructors on the left-hand side with
`if` or `if-let` instead.

In many cases, ISLE can already trivially prove that an internal
constructor always returns `Some`. With this commit, those cases are
largely unchanged, except for removing all the `Option`s and `Some`s
from the generated code for those terms.

However, for internal non-partial constructors where ISLE could not
prove that, it now emits an `unreachable!` panic as the last-resort,
instead of returning `None` like it used to do. Among the existing
backends, here's how many constructors have these panic cases:

- x64: 14% (53/374)
- aarch64: 15% (41/277)
- riscv64: 23% (26/114)
- s390x: 47% (268/567)

It's often possible to rewrite rules so that ISLE can tell the panic can
never be hit. Just ensure that there's a lowest-priority rule which has
no constraints on the left-hand side.

But in many of these constructors, it's difficult to statically prove
the unhandled cases are unreachable because that's only down to
knowledge about how they're called or other preconditions.

So this commit does not try to enforce that all terms have a last-resort
fallback rule.

* Check term flags while translating expressions

Instead of doing it in a separate pass afterward.

This involved threading all the term flags (pure, multi, partial)
through the recursive `translate_expr` calls, so I extracted the flags
to a new struct so they can all be passed together.

* Validate multi-term usage

Now that I've threaded the flags through `translate_expr`, it's easy to
check this case too, so let's just do it.

* Extract `ReturnKind` to use in `ExternalSig`

There are only three legal states for the combination of `multi` and
`infallible`, so replace those fields of `ExternalSig` with a
three-state enum.

* Remove `Option` wrapper from multi-extractors too

If we'd had any external multi-constructors this would correct their
signatures as well.

* Update ISLE tests

* Tag prelude constructors as pure where appropriate

I believe the only reason these weren't marked `pure` before was because
that would have implied that they're also partial. Now that those two
states are specified separately we apply this flag more places.

* Fix my changes to aarch64 `lower_bmask` and `imm` terms
This commit is contained in:
Jamey Sharp
2022-12-07 17:16:03 -08:00
committed by GitHub
parent c9527e0af6
commit 8726eeefb3
26 changed files with 433 additions and 358 deletions

View File

@@ -4,7 +4,7 @@
(decl get_a (A) u32)
(extern extractor get_a get_a)
(decl pure u32_pure (u32) u32)
(decl pure partial u32_pure (u32) u32)
(extern constructor u32_pure u32_pure)
(decl entry (u32) u32)

View File

@@ -1,12 +1,12 @@
(type u32 (primitive u32))
(decl pure A (u32 u32) u32)
(decl pure partial A (u32 u32) u32)
(extern constructor A A)
(decl B (u32 u32) u32)
(extern extractor B B)
(decl C (u32 u32 u32 u32) u32)
(decl partial C (u32 u32 u32 u32) u32)
(decl pure predicate () u32)
(rule (predicate) 1)

View File

@@ -25,22 +25,28 @@ impl multi_constructor::ContextIter for It {
impl multi_constructor::Context for Context {
type etor_C_iter = It;
fn etor_C(&mut self, value: u32) -> Option<It> {
Some(It { i: 0, limit: value })
fn etor_C(&mut self, value: u32) -> It {
It { i: 0, limit: value }
}
type ctor_B_iter = multi_constructor::ContextIterWrapper<u32, std::vec::IntoIter<u32>, Context>;
fn ctor_B(&mut self, value: u32) -> Option<Self::ctor_B_iter> {
Some((0..value).rev().collect::<Vec<_>>().into_iter().into())
fn ctor_B(&mut self, value: u32) -> Self::ctor_B_iter {
(0..value).rev().collect::<Vec<_>>().into_iter().into()
}
}
struct IterWithContext<'a, Item, I: multi_constructor::ContextIter<Output = Item, Context = Context>> {
struct IterWithContext<
'a,
Item,
I: multi_constructor::ContextIter<Output = Item, Context = Context>,
> {
ctx: &'a mut Context,
it: I,
}
impl<'a, Item, I: multi_constructor::ContextIter<Output = Item, Context = Context>> Iterator for IterWithContext<'a, Item, I> {
impl<'a, Item, I: multi_constructor::ContextIter<Output = Item, Context = Context>> Iterator
for IterWithContext<'a, Item, I>
{
type Item = Item;
fn next(&mut self) -> Option<Self::Item> {
self.it.next(self.ctx)
@@ -49,9 +55,17 @@ impl<'a, Item, I: multi_constructor::ContextIter<Output = Item, Context = Contex
fn main() {
let mut ctx = Context;
let l1 = multi_constructor::constructor_A(&mut ctx, 10).unwrap();
let l2 = multi_constructor::constructor_D(&mut ctx, 5).unwrap();
let l1 = IterWithContext { ctx: &mut ctx, it: l1 }.collect::<Vec<_>>();
let l2 = IterWithContext { ctx: &mut ctx, it: l2 }.collect::<Vec<_>>();
let l1 = multi_constructor::constructor_A(&mut ctx, 10);
let l2 = multi_constructor::constructor_D(&mut ctx, 5);
let l1 = IterWithContext {
ctx: &mut ctx,
it: l1,
}
.collect::<Vec<_>>();
let l2 = IterWithContext {
ctx: &mut ctx,
it: l2,
}
.collect::<Vec<_>>();
println!("l1 = {:?} l2 = {:?}", l1, l2);
}

View File

@@ -33,8 +33,8 @@ impl multi_extractor::ContextIter for It {
struct Context;
impl multi_extractor::Context for Context {
type e1_etor_iter = It;
fn e1_etor(&mut self, arg0: u32) -> Option<It> {
Some(It { i: 0, arg: arg0 })
fn e1_etor(&mut self, arg0: u32) -> It {
It { i: 0, arg: arg0 }
}
}

View File

@@ -1,6 +1,6 @@
(type i64 (primitive i64))
(decl X (i64) i64)
(decl partial X (i64) i64)
(rule (X -1) -2)
(rule (X -2) -3)
(rule (X 0x7fff_ffff_ffff_ffff) 0x8000_0000_0000_0000)
@@ -8,7 +8,7 @@
(type i128 (primitive i128))
(decl Y (i128) i128)
(decl partial Y (i128) i128)
(rule (Y 0x1000_0000_0000_0000_1234_5678_9abc_def0) -1)
(rule (Y 0xffff_ffff_ffff_ffff_ffff_ffff_ffff_ffff) 3)

View File

@@ -7,21 +7,21 @@ impl let_shadowing::Context for Context {}
fn main() {
let mut ctx = Context;
assert_eq!(Some(20), let_shadowing::constructor_test1(&mut ctx, 20));
assert_eq!(Some(97), let_shadowing::constructor_test1(&mut ctx, 97));
assert_eq!(20, let_shadowing::constructor_test1(&mut ctx, 20));
assert_eq!(97, let_shadowing::constructor_test1(&mut ctx, 97));
assert_eq!(Some(20), let_shadowing::constructor_test2(&mut ctx, 20));
assert_eq!(Some(97), let_shadowing::constructor_test2(&mut ctx, 97));
assert_eq!(20, let_shadowing::constructor_test2(&mut ctx, 20));
assert_eq!(97, let_shadowing::constructor_test2(&mut ctx, 97));
assert_eq!(Some(20), let_shadowing::constructor_test3(&mut ctx, 20));
assert_eq!(Some(97), let_shadowing::constructor_test3(&mut ctx, 97));
assert_eq!(20, let_shadowing::constructor_test3(&mut ctx, 20));
assert_eq!(97, let_shadowing::constructor_test3(&mut ctx, 97));
assert_eq!(Some(23), let_shadowing::constructor_test4(&mut ctx, 20));
assert_eq!(Some(23), let_shadowing::constructor_test4(&mut ctx, 97));
assert_eq!(23, let_shadowing::constructor_test4(&mut ctx, 20));
assert_eq!(23, let_shadowing::constructor_test4(&mut ctx, 97));
assert_eq!(Some(20), let_shadowing::constructor_test5(&mut ctx, 20));
assert_eq!(Some(97), let_shadowing::constructor_test5(&mut ctx, 97));
assert_eq!(20, let_shadowing::constructor_test5(&mut ctx, 20));
assert_eq!(97, let_shadowing::constructor_test5(&mut ctx, 97));
assert_eq!(Some(20), let_shadowing::constructor_test6(&mut ctx, 20));
assert_eq!(Some(97), let_shadowing::constructor_test6(&mut ctx, 97));
assert_eq!(20, let_shadowing::constructor_test6(&mut ctx, 20));
assert_eq!(97, let_shadowing::constructor_test6(&mut ctx, 97));
}

View File

@@ -83,6 +83,8 @@ pub struct Decl {
/// extractor or a constructor that matches multiple times, or
/// produces multiple values.
pub multi: bool,
/// Whether this term's constructor can fail to match.
pub partial: bool,
pub pos: Pos,
}

View File

@@ -2,10 +2,10 @@
use crate::ir::{ExprInst, InstId, PatternInst, Value};
use crate::log;
use crate::sema::ExternalSig;
use crate::sema::{TermEnv, TermId, Type, TypeEnv, TypeId, Variant};
use crate::sema::{ExternalSig, ReturnKind, TermEnv, TermId, Type, TypeEnv, TypeId, Variant};
use crate::trie::{TrieEdge, TrieNode, TrieSymbol};
use crate::{StableMap, StableSet};
use std::borrow::Cow;
use std::collections::BTreeMap;
use std::fmt::Write;
@@ -111,14 +111,21 @@ impl<'a> Codegen<'a> {
close_paren = if sig.ret_tys.len() != 1 { ")" } else { "" },
);
let ret_ty = match (sig.multi, sig.infallible) {
(false, false) => format!("Option<{}>", ret_tuple),
(false, true) => format!("{}", ret_tuple),
(true, false) => format!("Option<Self::{}_iter>", sig.func_name),
_ => panic!(
"Unsupported multiplicity/infallible combo: {:?}, {}",
sig.multi, sig.infallible
),
if sig.ret_kind == ReturnKind::Iterator {
writeln!(
code,
"{indent}type {name}_iter: ContextIter<Context = Self, Output = {output}>;",
indent = indent,
name = sig.func_name,
output = ret_tuple,
)
.unwrap();
}
let ret_ty = match sig.ret_kind {
ReturnKind::Plain => ret_tuple,
ReturnKind::Option => format!("Option<{}>", ret_tuple),
ReturnKind::Iterator => format!("Self::{}_iter", sig.func_name),
};
writeln!(
@@ -136,17 +143,6 @@ impl<'a> Codegen<'a> {
ret_ty = ret_ty,
)
.unwrap();
if sig.multi {
writeln!(
code,
"{indent}type {name}_iter: ContextIter<Context = Self, Output = {output}>;",
indent = indent,
name = sig.func_name,
output = ret_tuple,
)
.unwrap();
}
}
fn generate_ctx_trait(&self, code: &mut String) {
@@ -357,27 +353,29 @@ impl<'a> Codegen<'a> {
.collect::<Vec<_>>()
.join(", ");
assert_eq!(sig.ret_tys.len(), 1);
let ret = self.type_name(sig.ret_tys[0], false);
let ret = if sig.multi {
format!("impl ContextIter<Context = C, Output = {}>", ret)
} else {
ret
let ret = match sig.ret_kind {
ReturnKind::Iterator => format!("impl ContextIter<Context = C, Output = {}>", ret),
ReturnKind::Option => format!("Option<{}>", ret),
ReturnKind::Plain => ret,
};
let term_name = &self.typeenv.syms[termdata.name.index()];
writeln!(
code,
"\n// Generated as internal constructor for term {}.",
self.typeenv.syms[termdata.name.index()],
term_name,
)
.unwrap();
writeln!(
code,
"pub fn {}<C: Context>(ctx: &mut C, {}) -> Option<{}> {{",
"pub fn {}<C: Context>(ctx: &mut C, {}) -> {} {{",
sig.func_name, args, ret,
)
.unwrap();
if sig.multi {
if sig.ret_kind == ReturnKind::Iterator {
writeln!(code, "let mut returns = ConstructorVec::new();").unwrap();
}
@@ -388,18 +386,23 @@ impl<'a> Codegen<'a> {
trie,
" ",
&mut body_ctx,
sig.multi,
sig.ret_kind,
);
if !returned {
if sig.multi {
writeln!(
code,
" return Some(ContextIterWrapper::from(returns.into_iter()));"
)
.unwrap();
} else {
writeln!(code, " return None;").unwrap();
}
let ret_expr = match sig.ret_kind {
ReturnKind::Plain => Cow::from(format!(
"unreachable!(\"no rule matched for term {{}} at {{}}; should it be partial?\", {:?}, {:?})",
term_name,
termdata
.decl_pos
.pretty_print_line(&self.typeenv.filenames[..])
)),
ReturnKind::Option => Cow::from("None"),
ReturnKind::Iterator => {
Cow::from("ContextIterWrapper::from(returns.into_iter())")
}
};
write!(code, " return {};", ret_expr).unwrap();
}
writeln!(code, "}}").unwrap();
@@ -542,7 +545,7 @@ impl<'a> Codegen<'a> {
} else {
writeln!(
code,
"{}let mut it = {}(ctx, {})?;",
"{}let mut iter = {}(ctx, {});",
indent,
sig.full_name,
input_exprs.join(", "),
@@ -550,7 +553,7 @@ impl<'a> Codegen<'a> {
.unwrap();
writeln!(
code,
"{}while let Some({}) = it.next(ctx) {{",
"{}while let Some({}) = iter.next(ctx) {{",
indent, outputname,
)
.unwrap();
@@ -717,49 +720,20 @@ impl<'a> Codegen<'a> {
args = input_values.join(", ")
);
match (infallible, multi) {
(_, true) => {
writeln!(
code,
"{indent}if let Some(mut iter) = {etor_call} {{",
indent = indent,
etor_call = etor_call,
)
.unwrap();
writeln!(
code,
"{indent} while let Some({bind_pattern}) = iter.next(ctx) {{",
indent = indent,
bind_pattern = bind_pattern,
)
.unwrap();
(false, 2)
}
(false, false) => {
writeln!(
code,
"{indent}if let Some({bind_pattern}) = {etor_call} {{",
indent = indent,
bind_pattern = bind_pattern,
etor_call = etor_call,
)
.unwrap();
(false, 1)
}
(true, false) => {
writeln!(
code,
"{indent}let {bind_pattern} = {etor_call};",
indent = indent,
bind_pattern = bind_pattern,
etor_call = etor_call,
)
.unwrap();
(true, 0)
}
if multi {
writeln!(code, "{indent}let mut iter = {etor_call};").unwrap();
writeln!(
code,
"{indent}while let Some({bind_pattern}) = iter.next(ctx) {{",
)
.unwrap();
(false, 1)
} else if infallible {
writeln!(code, "{indent}let {bind_pattern} = {etor_call};").unwrap();
(true, 0)
} else {
writeln!(code, "{indent}if let Some({bind_pattern}) = {etor_call} {{").unwrap();
(false, 1)
}
}
&PatternInst::Expr {
@@ -833,7 +807,7 @@ impl<'a> Codegen<'a> {
trie: &TrieNode,
indent: &str,
ctx: &mut BodyContext,
is_multi: bool,
ret_kind: ReturnKind,
) -> bool {
log!("generate_body:\n{}", trie.pretty());
let mut returned = false;
@@ -865,17 +839,18 @@ impl<'a> Codegen<'a> {
}
assert_eq!(returns.len(), 1);
if is_multi {
writeln!(code, "{}returns.push({});", indent, returns[0].1).unwrap();
} else {
writeln!(code, "{}return Some({});", indent, returns[0].1).unwrap();
}
let (before, after) = match ret_kind {
ReturnKind::Plain => ("return ", ""),
ReturnKind::Option => ("return Some(", ")"),
ReturnKind::Iterator => ("returns.push(", ")"),
};
writeln!(code, "{}{}{}{};", indent, before, returns[0].1, after).unwrap();
for _ in 0..scopes {
writeln!(code, "{}}}", orig_indent).unwrap();
}
returned = !is_multi;
returned = ret_kind != ReturnKind::Iterator;
}
&TrieNode::Decision { ref edges } => {
@@ -930,7 +905,7 @@ impl<'a> Codegen<'a> {
&edges[i..last],
indent,
ctx,
is_multi,
ret_kind,
);
i = last;
continue;
@@ -950,7 +925,7 @@ impl<'a> Codegen<'a> {
node,
indent,
ctx,
is_multi,
ret_kind,
);
}
&TrieSymbol::Match { ref op } => {
@@ -967,7 +942,7 @@ impl<'a> Codegen<'a> {
node,
&subindent[..],
ctx,
is_multi,
ret_kind,
);
for _ in 0..new_scopes {
writeln!(code, "{}}}", indent).unwrap();
@@ -993,7 +968,7 @@ impl<'a> Codegen<'a> {
edges: &[TrieEdge],
indent: &str,
ctx: &mut BodyContext,
is_multi: bool,
ret_kind: ReturnKind,
) {
let (input, input_ty) = match &edges[0].symbol {
&TrieSymbol::Match {
@@ -1058,7 +1033,7 @@ impl<'a> Codegen<'a> {
)
.unwrap();
let subindent = format!("{} ", indent);
self.generate_body(code, depth + 1, node, &subindent, ctx, is_multi);
self.generate_body(code, depth + 1, node, &subindent, ctx, ret_kind);
writeln!(code, "{} }}", indent).unwrap();
}

View File

@@ -112,7 +112,7 @@ fn check_overlaps(terms: Vec<(TermId, trie_again::RuleSet)>, env: &TermEnv) -> E
let mut errs = Errors::default();
for (tid, ruleset) in terms {
let is_multi_ctor = match &env.terms[tid.index()].kind {
&TermKind::Decl { multi, .. } => multi,
TermKind::Decl { flags, .. } => flags.multi,
_ => false,
};
if is_multi_ctor {

View File

@@ -307,6 +307,12 @@ impl<'a> Parser<'a> {
} else {
false
};
let partial = if self.is_sym_str("partial") {
self.symbol()?;
true
} else {
false
};
let term = self.parse_ident()?;
@@ -325,6 +331,7 @@ impl<'a> Parser<'a> {
ret_ty,
pure,
multi,
partial,
pos,
})
}

View File

@@ -223,6 +223,17 @@ pub struct Term {
pub kind: TermKind,
}
/// Flags from a term's declaration with `(decl ...)`.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct TermFlags {
/// Whether the term is marked as `pure`.
pub pure: bool,
/// Whether the term is marked as `multi`.
pub multi: bool,
/// Whether the term is marked as `partial`.
pub partial: bool,
}
/// The kind of a term.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TermKind {
@@ -234,10 +245,8 @@ pub enum TermKind {
},
/// A term declared via a `(decl ...)` form.
Decl {
/// Whether the term is marked as `pure`.
pure: bool,
/// Whether the term is marked as `multi`.
multi: bool,
/// Flags from the term's declaration.
flags: TermFlags,
/// The kind of this term's constructor, if any.
constructor_kind: Option<ConstructorKind>,
/// The kind of this term's extractor, if any.
@@ -279,6 +288,17 @@ pub enum ExtractorKind {
},
}
/// How many values a function can return.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ReturnKind {
/// Exactly one return value.
Plain,
/// Zero or one return values.
Option,
/// Zero or more return values.
Iterator,
}
/// An external function signature.
#[derive(Clone, Debug)]
pub struct ExternalSig {
@@ -290,11 +310,8 @@ pub struct ExternalSig {
pub param_tys: Vec<TypeId>,
/// The types of this function signature's results.
pub ret_tys: Vec<TypeId>,
/// Whether this signature is infallible or not.
pub infallible: bool,
/// "Multiplicity": does the function return multiple values (via
/// an iterator)?
pub multi: bool,
/// How many values can this function return?
pub ret_kind: ReturnKind,
}
impl Term {
@@ -372,20 +389,28 @@ impl Term {
pub fn extractor_sig(&self, tyenv: &TypeEnv) -> Option<ExternalSig> {
match &self.kind {
TermKind::Decl {
multi,
flags,
extractor_kind:
Some(ExtractorKind::ExternalExtractor {
name, infallible, ..
}),
..
} => Some(ExternalSig {
func_name: tyenv.syms[name.index()].clone(),
full_name: format!("C::{}", tyenv.syms[name.index()]),
param_tys: vec![self.ret_ty],
ret_tys: self.arg_tys.clone(),
infallible: *infallible && !*multi,
multi: *multi,
}),
} => {
let ret_kind = if flags.multi {
ReturnKind::Iterator
} else if *infallible {
ReturnKind::Plain
} else {
ReturnKind::Option
};
Some(ExternalSig {
func_name: tyenv.syms[name.index()].clone(),
full_name: format!("C::{}", tyenv.syms[name.index()]),
param_tys: vec![self.ret_ty],
ret_tys: self.arg_tys.clone(),
ret_kind,
})
}
_ => None,
}
}
@@ -394,35 +419,33 @@ impl Term {
pub fn constructor_sig(&self, tyenv: &TypeEnv) -> Option<ExternalSig> {
match &self.kind {
TermKind::Decl {
constructor_kind: Some(ConstructorKind::ExternalConstructor { name }),
multi,
pure,
..
} => Some(ExternalSig {
func_name: tyenv.syms[name.index()].clone(),
full_name: format!("C::{}", tyenv.syms[name.index()]),
param_tys: self.arg_tys.clone(),
ret_tys: vec![self.ret_ty],
infallible: !pure && !*multi,
multi: *multi,
}),
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor { .. }),
multi,
constructor_kind: Some(kind),
flags,
..
} => {
let name = format!("constructor_{}", tyenv.syms[self.name.index()]);
let (func_name, full_name) = match kind {
ConstructorKind::InternalConstructor => {
let name = format!("constructor_{}", tyenv.syms[self.name.index()]);
(name.clone(), name)
}
ConstructorKind::ExternalConstructor { name } => (
tyenv.syms[name.index()].clone(),
format!("C::{}", tyenv.syms[name.index()]),
),
};
let ret_kind = if flags.multi {
ReturnKind::Iterator
} else if flags.partial {
ReturnKind::Option
} else {
ReturnKind::Plain
};
Some(ExternalSig {
func_name: name.clone(),
full_name: name,
func_name,
full_name,
param_tys: self.arg_tys.clone(),
ret_tys: vec![self.ret_ty],
// Internal constructors are always fallible, even
// if not pure, because ISLE allows partial
// matching at the toplevel (an entry point can
// fail to rewrite).
infallible: false,
multi: *multi,
ret_kind,
})
}
_ => None,
@@ -625,7 +648,7 @@ impl Pattern {
panic!("Should have been expanded away")
}
TermKind::Decl {
multi,
flags,
extractor_kind: Some(ExtractorKind::ExternalExtractor { infallible, .. }),
..
} => {
@@ -638,8 +661,8 @@ impl Pattern {
termdata.ret_ty,
output_tys,
term,
*infallible && !*multi,
*multi,
*infallible && !flags.multi,
flags.multi,
)
}
};
@@ -737,30 +760,16 @@ impl Expr {
visitor.add_create_variant(arg_values_tys, ty, *variant)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::InternalConstructor),
multi,
constructor_kind: Some(_),
flags,
..
} => {
visitor.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ false,
*multi,
)
}
TermKind::Decl {
constructor_kind: Some(ConstructorKind::ExternalConstructor { .. }),
pure,
multi,
..
} => {
visitor.add_construct(
arg_values_tys,
ty,
term,
/* infallible = */ !pure,
*multi,
/* infallible = */ !flags.partial,
flags.multi,
)
}
TermKind::Decl {
@@ -1167,6 +1176,13 @@ impl TermEnv {
);
}
if decl.multi && decl.partial {
tyenv.report_error(
decl.pos,
format!("Term '{}' can't be both multi and partial", decl.term.0),
);
}
let arg_tys = decl
.arg_tys
.iter()
@@ -1196,6 +1212,11 @@ impl TermEnv {
let tid = TermId(self.terms.len());
self.term_map.insert(name, tid);
let flags = TermFlags {
pure: decl.pure,
multi: decl.multi,
partial: decl.partial,
};
self.terms.push(Term {
id: tid,
decl_pos: decl.pos,
@@ -1203,10 +1224,9 @@ impl TermEnv {
arg_tys,
ret_ty,
kind: TermKind::Decl {
flags,
constructor_kind: None,
extractor_kind: None,
pure: decl.pure,
multi: decl.multi,
},
});
}
@@ -1364,12 +1384,12 @@ impl TermEnv {
continue;
}
TermKind::Decl {
multi,
flags,
extractor_kind,
..
} => match extractor_kind {
None => {
if *multi {
if flags.multi {
tyenv.report_error(
ext.pos,
"A term declared with `multi` cannot have an internal extractor.".to_string());
@@ -1658,8 +1678,8 @@ impl TermEnv {
let termdata = &self.terms[root_term.index()];
let pure = match &termdata.kind {
&TermKind::Decl { pure, .. } => pure,
let flags = match &termdata.kind {
TermKind::Decl { flags, .. } => flags,
_ => {
tyenv.report_error(
pos,
@@ -1676,14 +1696,17 @@ impl TermEnv {
let iflets = rule
.iflets
.iter()
.filter_map(|iflet| self.translate_iflet(tyenv, iflet, &mut bindings))
.filter_map(|iflet| {
self.translate_iflet(tyenv, iflet, &mut bindings, flags)
})
.collect();
let rhs = unwrap_or_continue!(self.translate_expr(
tyenv,
&rule.expr,
Some(termdata.ret_ty),
&mut bindings,
pure,
flags,
/* on_lhs */ false,
));
let rid = RuleId(self.rules.len());
@@ -2058,7 +2081,8 @@ impl TermEnv {
expr: &ast::Expr,
ty: Option<TypeId>,
bindings: &mut Bindings,
pure: bool,
root_flags: &TermFlags,
on_lhs: bool,
) -> Option<Expr> {
log!("translate_expr: {:?}", expr);
match expr {
@@ -2101,7 +2125,14 @@ impl TermEnv {
if let Some(expanded_expr) =
self.maybe_implicit_convert_expr(tyenv, expr, ret_ty, ty.unwrap())
{
return self.translate_expr(tyenv, &expanded_expr, ty, bindings, pure);
return self.translate_expr(
tyenv,
&expanded_expr,
ty,
bindings,
root_flags,
on_lhs,
);
}
tyenv.report_error(
@@ -2116,9 +2147,11 @@ impl TermEnv {
ret_ty
};
// Check that the term's constructor is pure.
if pure {
if let TermKind::Decl { pure: false, .. } = &termdata.kind {
if let TermKind::Decl { flags, .. } = &termdata.kind {
// On the left-hand side of a rule or in a pure term, only pure terms may be
// used.
let pure_required = on_lhs || root_flags.pure;
if pure_required && !flags.pure {
tyenv.report_error(
pos,
format!(
@@ -2127,6 +2160,36 @@ impl TermEnv {
),
);
}
// Multi-terms may only be used inside other multi-terms.
if !root_flags.multi && flags.multi {
tyenv.report_error(
pos,
format!(
"Used multi-term '{}' but this rule is not in a multi-term",
sym.0
),
);
}
// Partial terms may always be used on the left-hand side of a rule. On the
// right-hand side they may only be used inside other partial terms.
let partial_allowed = on_lhs || root_flags.partial;
if !partial_allowed && flags.partial {
tyenv.report_error(
pos,
format!(
"Rule can't use partial constructor '{}' on RHS; \
try moving it to if-let{}",
sym.0,
if root_flags.multi {
""
} else {
" or make this rule's term partial too"
}
),
);
}
}
termdata.check_args_count(args, tyenv, pos, sym);
@@ -2136,7 +2199,7 @@ impl TermEnv {
.iter()
.zip(termdata.arg_tys.iter())
.filter_map(|(arg, &arg_ty)| {
self.translate_expr(tyenv, arg, Some(arg_ty), bindings, pure)
self.translate_expr(tyenv, arg, Some(arg_ty), bindings, root_flags, on_lhs)
})
.collect();
@@ -2159,7 +2222,14 @@ impl TermEnv {
if let Some(expanded_expr) =
self.maybe_implicit_convert_expr(tyenv, expr, bv.ty, ty.unwrap())
{
return self.translate_expr(tyenv, &expanded_expr, ty, bindings, pure);
return self.translate_expr(
tyenv,
&expanded_expr,
ty,
bindings,
root_flags,
on_lhs,
);
}
tyenv.report_error(
@@ -2251,7 +2321,8 @@ impl TermEnv {
&def.val,
Some(tid),
bindings,
pure
root_flags,
on_lhs,
)));
// Bind the var with the given type.
@@ -2260,7 +2331,8 @@ impl TermEnv {
}
// Evaluate the body, expecting the type of the overall let-expr.
let body = Box::new(self.translate_expr(tyenv, body, ty, bindings, pure)?);
let body =
Box::new(self.translate_expr(tyenv, body, ty, bindings, root_flags, on_lhs)?);
let body_ty = body.ty();
// Pop the bindings.
@@ -2280,10 +2352,18 @@ impl TermEnv {
tyenv: &mut TypeEnv,
iflet: &ast::IfLet,
bindings: &mut Bindings,
root_flags: &TermFlags,
) -> Option<IfLet> {
// Translate the expr first. Ensure it's pure.
let rhs =
self.translate_expr(tyenv, &iflet.expr, None, bindings, /* pure = */ true)?;
// Translate the expr first. The `if-let` and `if` forms are part of the left-hand side of
// the rule.
let rhs = self.translate_expr(
tyenv,
&iflet.expr,
None,
bindings,
root_flags,
/* on_lhs */ true,
)?;
let ty = rhs.ty();
let (lhs, _lhs_ty) = self.translate_pattern(tyenv, &iflet.pattern, Some(ty), bindings)?;