egraph support: rewrite to work in terms of CLIF data structures. (#5382)

* egraph support: rewrite to work in terms of CLIF data structures.

This work rewrites the "egraph"-based optimization framework in
Cranelift to operate on aegraphs (acyclic egraphs) represented in the
CLIF itself rather than as a separate data structure to which and from
which we translate the CLIF.

The basic idea is to add a new kind of value, a "union", that is like an
alias but refers to two other values rather than one.  This allows us to
represent an eclass of enodes (values) as a tree. The union node allows
for a value to have *multiple representations*: either constituent value
could be used, and (in well-formed CLIF produced by correct
optimization rules) they must be equivalent.

Like the old egraph infrastructure, we take advantage of acyclicity and
eager rule application to do optimization in a single pass. Like before,
we integrate GVN (during the optimization pass) and LICM (during
elaboration).

Unlike the old egraph infrastructure, everything stays in the
DataFlowGraph. "Pure" enodes are represented as instructions that have
values attached, but that are not placed into the function layout. When
entering "egraph" form, we remove them from the layout while optimizing.
When leaving "egraph" form, during elaboration, we can place an
instruction back into the layout the first time we elaborate the enode;
if we elaborate it more than once, we clone the instruction.

The implementation performs two passes overall:

- One, a forward pass in RPO (to see defs before uses), that (i) removes
  "pure" instructions from the layout and (ii) optimizes as it goes. As
  before, we eagerly optimize, so we form the entire union of optimized
  forms of a value before we see any uses of that value. This lets us
  rewrite uses to use the most "up-to-date" form of the value and
  canonicalize and optimize that form.

  The eager rewriting and acyclic representation make each other work
  (we could not eagerly rewrite if there were cycles; and acyclicity
  does not miss optimization opportunities only because the first time
  we introduce a value, we immediately produce its "best" form). This
  design choice is also what allows us to avoid the "parent pointers"
  and fixpoint loop of traditional egraphs.

  This forward optimization pass keeps a scoped hashmap to "intern"
  nodes (thus performing GVN), and also interleaves on a per-instruction
  level with alias analysis. The interleaving with alias analysis allows
  alias analysis to see the most optimized form of each address (so it
  can see equivalences), and allows the next value to see any
  equivalences (reuses of loads or stored values) that alias analysis
  uncovers.

- Two, a forward pass in domtree preorder, that "elaborates" pure enodes
  back into the layout, possibly in multiple places if needed. This
  tracks the loop nest and hoists nodes as needed, performing LICM as it
  goes. Note that by doing this in forward order, we avoid the
  "fixpoint" that traditional LICM needs: we hoist a def before its
  uses, so when we place a node, we place it in the right place the
  first time rather than moving later.

This PR replaces the old (a)egraph implementation. It removes both the
cranelift-egraph crate and the logic in cranelift-codegen that uses it.

On `spidermonkey.wasm` running a simple recursive Fibonacci
microbenchmark, this work shows 5.5% compile-time reduction and 7.7%
runtime improvement (speedup).

Most of this implementation was done in (very productive) pair
programming sessions with Jamey Sharp, thus:

Co-authored-by: Jamey Sharp <jsharp@fastly.com>

* Review feedback.

* Review feedback.

* Review feedback.

* Bugfix: cprop rule: `(x + k1) - k2` becomes `x - (k2 - k1)`, not `x - (k1 - k2)`.

Co-authored-by: Jamey Sharp <jsharp@fastly.com>
This commit is contained in:
Chris Fallin
2022-12-06 14:58:57 -08:00
committed by GitHub
parent 08d44e3746
commit f980defe17
42 changed files with 1890 additions and 3884 deletions

View File

@@ -0,0 +1,97 @@
//! Cost functions for egraph representation.
use crate::ir::Opcode;
/// A cost of computing some value in the program.
///
/// Costs are measured in an arbitrary union that we represent in a
/// `u32`. The ordering is meant to be meaningful, but the value of a
/// single unit is arbitrary (and "not to scale"). We use a collection
/// of heuristics to try to make this approximation at least usable.
///
/// We start by defining costs for each opcode (see `pure_op_cost`
/// below). The cost of computing some value, initially, is the cost
/// of its opcode, plus the cost of computing its inputs.
///
/// We then adjust the cost according to loop nests: for each
/// loop-nest level, we multiply by 1024. Because we only have 32
/// bits, we limit this scaling to a loop-level of two (i.e., multiply
/// by 2^20 ~= 1M).
///
/// Arithmetic on costs is always saturating: we don't want to wrap
/// around and return to a tiny cost when adding the costs of two very
/// expensive operations. It is better to approximate and lose some
/// precision than to lose the ordering by wrapping.
///
/// Finally, we reserve the highest value, `u32::MAX`, as a sentinel
/// that means "infinite". This is separate from the finite costs and
/// not reachable by doing arithmetic on them (even when overflowing)
/// -- we saturate just *below* infinity. (This is done by the
/// `finite()` method.) An infinite cost is used to represent a value
/// that cannot be computed, or otherwise serve as a sentinel when
/// performing search for the lowest-cost representation of a value.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Cost(u32);
impl Cost {
pub(crate) fn at_level(&self, loop_level: usize) -> Cost {
let loop_level = std::cmp::min(2, loop_level);
let multiplier = 1u32 << ((10 * loop_level) as u32);
Cost(self.0.saturating_mul(multiplier)).finite()
}
pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Cost(u32::MAX)
}
pub(crate) fn zero() -> Cost {
Cost(0)
}
/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite(self) -> Cost {
Cost(std::cmp::min(u32::MAX - 1, self.0))
}
}
impl std::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
}
}
impl std::ops::Add<Cost> for Cost {
type Output = Cost;
fn add(self, other: Cost) -> Cost {
Cost(self.0.saturating_add(other.0)).finite()
}
}
/// Return the cost of a *pure* opcode. Caller is responsible for
/// checking that the opcode came from an instruction that satisfies
/// `inst_predicates::is_pure_for_egraph()`.
pub(crate) fn pure_op_cost(op: Opcode) -> Cost {
match op {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost(0),
// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost(1)
}
// "Simple" arithmetic.
Opcode::Iadd
| Opcode::Isub
| Opcode::Band
| Opcode::BandNot
| Opcode::Bor
| Opcode::BorNot
| Opcode::Bxor
| Opcode::BxorNot
| Opcode::Bnot => Cost(2),
// Everything else (pure.)
_ => Cost(3),
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,366 +0,0 @@
//! Node definition for EGraph representation.
use super::PackedMemoryState;
use crate::ir::{Block, DataFlowGraph, InstructionImms, Opcode, RelSourceLoc, Type};
use crate::loop_analysis::LoopLevel;
use cranelift_egraph::{CtxEq, CtxHash, Id, Language, UnionFind};
use cranelift_entity::{EntityList, ListPool};
use std::hash::{Hash, Hasher};
#[derive(Debug)]
pub enum Node {
/// A blockparam. Effectively an input/root; does not refer to
/// predecessors' branch arguments, because this would create
/// cycles.
Param {
/// CLIF block this param comes from.
block: Block,
/// Index of blockparam within block.
index: u32,
/// Type of the value.
ty: Type,
/// The loop level of this Param.
loop_level: LoopLevel,
},
/// A CLIF instruction that is pure (has no side-effects). Not
/// tied to any location; we will compute a set of locations at
/// which to compute this node during lowering back out of the
/// egraph.
Pure {
/// The instruction data, without SSA values.
op: InstructionImms,
/// eclass arguments to the operator.
args: EntityList<Id>,
/// Type of result, if one.
ty: Type,
/// Number of results.
arity: u16,
},
/// A CLIF instruction that has side-effects or is otherwise not
/// representable by `Pure`.
Inst {
/// The instruction data, without SSA values.
op: InstructionImms,
/// eclass arguments to the operator.
args: EntityList<Id>,
/// Type of result, if one.
ty: Type,
/// Number of results.
arity: u16,
/// The source location to preserve.
srcloc: RelSourceLoc,
/// The loop level of this Inst.
loop_level: LoopLevel,
},
/// A projection of one result of an `Inst` or `Pure`.
Result {
/// `Inst` or `Pure` node.
value: Id,
/// Index of the result we want.
result: usize,
/// Type of the value.
ty: Type,
},
/// A load instruction. Nominally a side-effecting `Inst` (and
/// included in the list of side-effecting roots so it will always
/// be elaborated), but represented as a distinct kind of node so
/// that we can leverage deduplication to do
/// redundant-load-elimination for free (and make store-to-load
/// forwarding much easier).
Load {
// -- identity depends on:
/// The original load operation. Must have one argument, the
/// address.
op: InstructionImms,
/// The type of the load result.
ty: Type,
/// Address argument. Actual address has an offset, which is
/// included in `op` (and thus already considered as part of
/// the key).
addr: Id,
/// The abstract memory state that this load accesses.
mem_state: PackedMemoryState,
// -- not included in dedup key:
/// Source location, for traps. Not included in Eq/Hash.
srcloc: RelSourceLoc,
},
}
impl Node {
pub(crate) fn is_non_pure(&self) -> bool {
match self {
Node::Inst { .. } | Node::Load { .. } => true,
_ => false,
}
}
}
/// Shared pools for type and id lists in nodes.
pub struct NodeCtx {
/// Arena for arg eclass-ID lists.
pub args: ListPool<Id>,
}
impl NodeCtx {
pub(crate) fn with_capacity_for_dfg(dfg: &DataFlowGraph) -> Self {
let n_args = dfg.value_lists.capacity();
Self {
args: ListPool::with_capacity(n_args),
}
}
}
impl NodeCtx {
fn ids_eq(&self, a: &EntityList<Id>, b: &EntityList<Id>, uf: &mut UnionFind) -> bool {
let a = a.as_slice(&self.args);
let b = b.as_slice(&self.args);
a.len() == b.len() && a.iter().zip(b.iter()).all(|(&a, &b)| uf.equiv_id_mut(a, b))
}
fn hash_ids<H: Hasher>(&self, a: &EntityList<Id>, hash: &mut H, uf: &mut UnionFind) {
let a = a.as_slice(&self.args);
for &id in a {
uf.hash_id_mut(hash, id);
}
}
}
impl CtxEq<Node, Node> for NodeCtx {
fn ctx_eq(&self, a: &Node, b: &Node, uf: &mut UnionFind) -> bool {
match (a, b) {
(
&Node::Param {
block,
index,
ty,
loop_level: _,
},
&Node::Param {
block: other_block,
index: other_index,
ty: other_ty,
loop_level: _,
},
) => block == other_block && index == other_index && ty == other_ty,
(
&Node::Result { value, result, ty },
&Node::Result {
value: other_value,
result: other_result,
ty: other_ty,
},
) => uf.equiv_id_mut(value, other_value) && result == other_result && ty == other_ty,
(
&Node::Pure {
ref op,
ref args,
ty,
arity: _,
},
&Node::Pure {
op: ref other_op,
args: ref other_args,
ty: other_ty,
arity: _,
},
) => *op == *other_op && self.ids_eq(args, other_args, uf) && ty == other_ty,
(
&Node::Inst { ref args, .. },
&Node::Inst {
args: ref other_args,
..
},
) => self.ids_eq(args, other_args, uf),
(
&Node::Load {
ref op,
ty,
addr,
mem_state,
..
},
&Node::Load {
op: ref other_op,
ty: other_ty,
addr: other_addr,
mem_state: other_mem_state,
// Explicitly exclude: `inst` and `srcloc`. We
// want loads to merge if identical in
// opcode/offset, address expression, and last
// store (this does implicit
// redundant-load-elimination.)
//
// Note however that we *do* include `ty` (the
// type) and match on that: we otherwise would
// have no way of disambiguating loads of
// different widths to the same address.
..
},
) => {
op == other_op
&& ty == other_ty
&& uf.equiv_id_mut(addr, other_addr)
&& mem_state == other_mem_state
}
_ => false,
}
}
}
impl CtxHash<Node> for NodeCtx {
fn ctx_hash(&self, value: &Node, uf: &mut UnionFind) -> u64 {
let mut state = crate::fx::FxHasher::default();
std::mem::discriminant(value).hash(&mut state);
match value {
&Node::Param {
block,
index,
ty: _,
loop_level: _,
} => {
block.hash(&mut state);
index.hash(&mut state);
}
&Node::Result {
value,
result,
ty: _,
} => {
uf.hash_id_mut(&mut state, value);
result.hash(&mut state);
}
&Node::Pure {
ref op,
ref args,
ty,
arity: _,
} => {
op.hash(&mut state);
self.hash_ids(args, &mut state, uf);
ty.hash(&mut state);
}
&Node::Inst { ref args, .. } => {
self.hash_ids(args, &mut state, uf);
}
&Node::Load {
ref op,
ty,
addr,
mem_state,
..
} => {
op.hash(&mut state);
ty.hash(&mut state);
uf.hash_id_mut(&mut state, addr);
mem_state.hash(&mut state);
}
}
state.finish()
}
}
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct Cost(u32);
impl Cost {
pub(crate) fn at_level(&self, loop_level: LoopLevel) -> Cost {
let loop_level = std::cmp::min(2, loop_level.level());
let multiplier = 1u32 << ((10 * loop_level) as u32);
Cost(self.0.saturating_mul(multiplier)).finite()
}
pub(crate) fn infinity() -> Cost {
// 2^32 - 1 is, uh, pretty close to infinite... (we use `Cost`
// only for heuristics and always saturate so this suffices!)
Cost(u32::MAX)
}
pub(crate) fn zero() -> Cost {
Cost(0)
}
/// Clamp this cost at a "finite" value. Can be used in
/// conjunction with saturating ops to avoid saturating into
/// `infinity()`.
fn finite(self) -> Cost {
Cost(std::cmp::min(u32::MAX - 1, self.0))
}
}
impl std::default::Default for Cost {
fn default() -> Cost {
Cost::zero()
}
}
impl std::ops::Add<Cost> for Cost {
type Output = Cost;
fn add(self, other: Cost) -> Cost {
Cost(self.0.saturating_add(other.0)).finite()
}
}
pub(crate) fn op_cost(op: &InstructionImms) -> Cost {
match op.opcode() {
// Constants.
Opcode::Iconst | Opcode::F32const | Opcode::F64const => Cost(0),
// Extends/reduces.
Opcode::Uextend | Opcode::Sextend | Opcode::Ireduce | Opcode::Iconcat | Opcode::Isplit => {
Cost(1)
}
// "Simple" arithmetic.
Opcode::Iadd
| Opcode::Isub
| Opcode::Band
| Opcode::BandNot
| Opcode::Bor
| Opcode::BorNot
| Opcode::Bxor
| Opcode::BxorNot
| Opcode::Bnot => Cost(2),
// Everything else.
_ => Cost(3),
}
}
impl Language for NodeCtx {
type Node = Node;
fn children<'a>(&'a self, node: &'a Node) -> &'a [Id] {
match node {
Node::Param { .. } => &[],
Node::Pure { args, .. } | Node::Inst { args, .. } => args.as_slice(&self.args),
Node::Load { addr, .. } => std::slice::from_ref(addr),
Node::Result { value, .. } => std::slice::from_ref(value),
}
}
fn children_mut<'a>(&'a mut self, node: &'a mut Node) -> &'a mut [Id] {
match node {
Node::Param { .. } => &mut [],
Node::Pure { args, .. } | Node::Inst { args, .. } => args.as_mut_slice(&mut self.args),
Node::Load { addr, .. } => std::slice::from_mut(addr),
Node::Result { value, .. } => std::slice::from_mut(value),
}
}
fn needs_dedup(&self, node: &Node) -> bool {
match node {
Node::Pure { .. } | Node::Load { .. } => true,
_ => false,
}
}
}
#[cfg(test)]
mod test {
#[test]
#[cfg(target_pointer_width = "64")]
fn node_size() {
use super::*;
assert_eq!(std::mem::size_of::<InstructionImms>(), 16);
assert_eq!(std::mem::size_of::<Node>(), 32);
}
}

View File

@@ -1,293 +0,0 @@
//! Last-store tracking via alias analysis.
//!
//! We partition memory state into several *disjoint pieces* of
//! "abstract state". There are a finite number of such pieces:
//! currently, we call them "heap", "table", "vmctx", and "other". Any
//! given address in memory belongs to exactly one disjoint piece.
//!
//! One never tracks which piece a concrete address belongs to at
//! runtime; this is a purely static concept. Instead, all
//! memory-accessing instructions (loads and stores) are labeled with
//! one of these four categories in the `MemFlags`. It is forbidden
//! for a load or store to access memory under one category and a
//! later load or store to access the same memory under a different
//! category. This is ensured to be true by construction during
//! frontend translation into CLIF and during legalization.
//!
//! Given that this non-aliasing property is ensured by the producer
//! of CLIF, we can compute a *may-alias* property: one load or store
//! may-alias another load or store if both access the same category
//! of abstract state.
//!
//! The "last store" pass helps to compute this aliasing: we perform a
//! fixpoint analysis to track the last instruction that *might have*
//! written to a given part of abstract state. We also track the block
//! containing this store.
//!
//! We can't say for sure that the "last store" *did* actually write
//! that state, but we know for sure that no instruction *later* than
//! it (up to the current instruction) did. However, we can get a
//! must-alias property from this: if at a given load or store, we
//! look backward to the "last store", *AND* we find that it has
//! exactly the same address expression and value type, then we know
//! that the current instruction's access *must* be to the same memory
//! location.
//!
//! To get this must-alias property, we leverage the node
//! hashconsing. We design the Eq/Hash (node identity relation
//! definition) of the `Node` struct so that all loads with (i) the
//! same "last store", and (ii) the same address expression, and (iii)
//! the same opcode-and-offset, will deduplicate (the first will be
//! computed, and the later ones will use the same value). Furthermore
//! we have an optimization that rewrites a load into the stored value
//! of the last store *if* the last store has the same address
//! expression and constant offset.
//!
//! This gives us two optimizations, "redundant load elimination" and
//! "store-to-load forwarding".
//!
//! In theory we could also do *dead-store elimination*, where if a
//! store overwrites a value earlier written by another store, *and*
//! if no other load/store to the abstract state category occurred,
//! *and* no other trapping instruction occurred (at which point we
//! need an up-to-date memory state because post-trap-termination
//! memory state can be observed), *and* we can prove the original
//! store could not have trapped, then we can eliminate the original
//! store. Because this is so complex, and the conditions for doing it
//! correctly when post-trap state must be correct likely reduce the
//! potential benefit, we don't yet do this.
use crate::flowgraph::ControlFlowGraph;
use crate::fx::{FxHashMap, FxHashSet};
use crate::inst_predicates::has_memory_fence_semantics;
use crate::ir::{Block, Function, Inst, InstructionData, MemFlags, Opcode};
use crate::trace;
use cranelift_entity::{EntityRef, SecondaryMap};
use smallvec::{smallvec, SmallVec};
/// For a given program point, the vector of last-store instruction
/// indices for each disjoint category of abstract state.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
struct LastStores {
heap: MemoryState,
table: MemoryState,
vmctx: MemoryState,
other: MemoryState,
}
/// State of memory seen by a load.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub enum MemoryState {
/// State at function entry: nothing is known (but it is one
/// consistent value, so two loads from "entry" state at the same
/// address will still provide the same result).
#[default]
Entry,
/// State just after a store by the given instruction. The
/// instruction is a store from which we can forward.
Store(Inst),
/// State just before the given instruction. Used for abstract
/// value merges at merge-points when we cannot name a single
/// producing site.
BeforeInst(Inst),
/// State just after the given instruction. Used when the
/// instruction may update the associated state, but is not a
/// store whose value we can cleanly forward. (E.g., perhaps a
/// barrier of some sort.)
AfterInst(Inst),
}
/// Memory state index, packed into a u32.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PackedMemoryState(u32);
impl From<MemoryState> for PackedMemoryState {
fn from(state: MemoryState) -> Self {
match state {
MemoryState::Entry => Self(0),
MemoryState::Store(i) => Self(1 | (i.index() as u32) << 2),
MemoryState::BeforeInst(i) => Self(2 | (i.index() as u32) << 2),
MemoryState::AfterInst(i) => Self(3 | (i.index() as u32) << 2),
}
}
}
impl PackedMemoryState {
/// Does this memory state refer to a specific store instruction?
pub fn as_store(&self) -> Option<Inst> {
if self.0 & 3 == 1 {
Some(Inst::from_bits(self.0 >> 2))
} else {
None
}
}
}
impl LastStores {
fn update(&mut self, func: &Function, inst: Inst) {
let opcode = func.dfg[inst].opcode();
if has_memory_fence_semantics(opcode) {
self.heap = MemoryState::AfterInst(inst);
self.table = MemoryState::AfterInst(inst);
self.vmctx = MemoryState::AfterInst(inst);
self.other = MemoryState::AfterInst(inst);
} else if opcode.can_store() {
if let Some(memflags) = func.dfg[inst].memflags() {
*self.for_flags(memflags) = MemoryState::Store(inst);
} else {
self.heap = MemoryState::AfterInst(inst);
self.table = MemoryState::AfterInst(inst);
self.vmctx = MemoryState::AfterInst(inst);
self.other = MemoryState::AfterInst(inst);
}
}
}
fn for_flags(&mut self, memflags: MemFlags) -> &mut MemoryState {
if memflags.heap() {
&mut self.heap
} else if memflags.table() {
&mut self.table
} else if memflags.vmctx() {
&mut self.vmctx
} else {
&mut self.other
}
}
fn meet_from(&mut self, other: &LastStores, loc: Inst) {
let meet = |a: MemoryState, b: MemoryState| -> MemoryState {
match (a, b) {
(a, b) if a == b => a,
_ => MemoryState::BeforeInst(loc),
}
};
self.heap = meet(self.heap, other.heap);
self.table = meet(self.table, other.table);
self.vmctx = meet(self.vmctx, other.vmctx);
self.other = meet(self.other, other.other);
}
}
/// An alias-analysis pass.
pub struct AliasAnalysis {
/// Last-store instruction (or none) for a given load. Use a hash map
/// instead of a `SecondaryMap` because this is sparse.
load_mem_state: FxHashMap<Inst, PackedMemoryState>,
}
impl AliasAnalysis {
/// Perform an alias analysis pass.
pub fn new(func: &Function, cfg: &ControlFlowGraph) -> AliasAnalysis {
log::trace!("alias analysis: input is:\n{:?}", func);
let block_input = Self::compute_block_input_states(func, cfg);
let load_mem_state = Self::compute_load_last_stores(func, block_input);
AliasAnalysis { load_mem_state }
}
fn compute_block_input_states(
func: &Function,
cfg: &ControlFlowGraph,
) -> SecondaryMap<Block, Option<LastStores>> {
let mut block_input = SecondaryMap::with_capacity(func.dfg.num_blocks());
let mut worklist: SmallVec<[Block; 16]> = smallvec![];
let mut worklist_set = FxHashSet::default();
let entry = func.layout.entry_block().unwrap();
worklist.push(entry);
worklist_set.insert(entry);
block_input[entry] = Some(LastStores::default());
while let Some(block) = worklist.pop() {
worklist_set.remove(&block);
let state = block_input[block].clone().unwrap();
trace!("alias analysis: input to {} is {:?}", block, state);
let state = func
.layout
.block_insts(block)
.fold(state, |mut state, inst| {
state.update(func, inst);
trace!("after {}: state is {:?}", inst, state);
state
});
for succ in cfg.succ_iter(block) {
let succ_first_inst = func.layout.first_inst(succ).unwrap();
let succ_state = &mut block_input[succ];
let old = succ_state.clone();
if let Some(succ_state) = succ_state.as_mut() {
succ_state.meet_from(&state, succ_first_inst);
} else {
*succ_state = Some(state);
};
let updated = *succ_state != old;
if updated && worklist_set.insert(succ) {
worklist.push(succ);
}
}
}
block_input
}
fn compute_load_last_stores(
func: &Function,
block_input: SecondaryMap<Block, Option<LastStores>>,
) -> FxHashMap<Inst, PackedMemoryState> {
let mut load_mem_state = FxHashMap::default();
load_mem_state.reserve(func.dfg.num_insts() / 8);
for block in func.layout.blocks() {
let mut state = block_input[block].clone().unwrap();
for inst in func.layout.block_insts(block) {
trace!(
"alias analysis: scanning at {} with state {:?} ({:?})",
inst,
state,
func.dfg[inst],
);
// N.B.: we match `Load` specifically, and not any
// other kinds of loads (or any opcode such that
// `opcode.can_load()` returns true), because some
// "can load" instructions actually have very
// different semantics (are not just a load of a
// particularly-typed value). For example, atomic
// (load/store, RMW, CAS) instructions "can load" but
// definitely should not participate in store-to-load
// forwarding or redundant-load elimination. Our goal
// here is to provide a `MemoryState` just for plain
// old loads whose semantics we can completely reason
// about.
if let InstructionData::Load {
opcode: Opcode::Load,
flags,
..
} = func.dfg[inst]
{
let mem_state = *state.for_flags(flags);
trace!(
"alias analysis: at {}: load with mem_state {:?}",
inst,
mem_state,
);
load_mem_state.insert(inst, mem_state.into());
}
state.update(func, inst);
}
}
load_mem_state
}
/// Get the state seen by a load, if any.
pub fn get_state_for_load(&self, inst: Inst) -> Option<PackedMemoryState> {
self.load_mem_state.get(&inst).copied()
}
}