diff --git a/cranelift/codegen/meta/src/cdsl/typevar.rs b/cranelift/codegen/meta/src/cdsl/typevar.rs index d144eab48f..41d94d67f4 100644 --- a/cranelift/codegen/meta/src/cdsl/typevar.rs +++ b/cranelift/codegen/meta/src/cdsl/typevar.rs @@ -1,4 +1,7 @@ -use std::collections::BTreeSet; +use std::cell::RefCell; +use std::collections::{BTreeSet, HashSet}; +use std::fmt; +use std::hash; use std::iter::FromIterator; use std::ops; use std::rc::Rc; @@ -301,7 +304,7 @@ macro_rules! num_set { }; } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash)] pub struct TypeSet { pub lanes: NumSet, pub ints: NumSet, @@ -331,7 +334,7 @@ impl TypeSet { } /// Return the number of concrete types represented by this typeset. - fn size(&self) -> usize { + pub fn size(&self) -> usize { self.lanes.len() * (self.ints.len() + self.floats.len() + self.bools.len() + self.bitvecs.len()) + self.specials.len() @@ -486,6 +489,175 @@ impl TypeSet { assert_eq!(types.len(), 1); return types.remove(0); } + + /// Return the inverse image of self across the derived function func. + fn preimage(&self, func: DerivedFunc) -> TypeSet { + if self.size() == 0 { + // The inverse of the empty set is itself. + return self.clone(); + } + + match func { + DerivedFunc::LaneOf => { + let mut copy = self.clone(); + copy.bitvecs = NumSet::new(); + copy.lanes = + NumSet::from_iter((0..MAX_LANES.trailing_zeros() + 1).map(|i| u16::pow(2, i))); + copy + } + DerivedFunc::AsBool => { + let mut copy = self.clone(); + copy.bitvecs = NumSet::new(); + if self.bools.contains(&1) { + copy.ints = NumSet::from_iter(vec![8, 16, 32, 64]); + copy.floats = NumSet::from_iter(vec![32, 64]); + } else { + copy.ints = &self.bools - &NumSet::from_iter(vec![1]); + copy.floats = &self.bools & &NumSet::from_iter(vec![32, 64]); + // If b1 is not in our typeset, than lanes=1 cannot be in the pre-image, as + // as_bool() of scalars is always b1. + copy.lanes = &self.lanes - &NumSet::from_iter(vec![1]); + } + copy + } + DerivedFunc::HalfWidth => self.double_width(), + DerivedFunc::DoubleWidth => self.half_width(), + DerivedFunc::HalfVector => self.double_vector(), + DerivedFunc::DoubleVector => self.half_vector(), + DerivedFunc::ToBitVec => { + let all_lanes = range_to_set(Some(1..MAX_LANES)); + let all_ints = range_to_set(Some(8..MAX_BITS)); + let all_floats = range_to_set(Some(32..64)); + let all_bools = range_to_set(Some(1..MAX_BITS)); + + let mut lanes = range_to_set(Some(1..MAX_LANES)); + let mut ints = range_to_set(Some(8..MAX_BITS)); + let mut floats = range_to_set(Some(32..64)); + let mut bools = range_to_set(Some(1..MAX_BITS)); + + for &l in &all_lanes { + for &i in &all_ints { + if self.bitvecs.contains(&(i * l)) { + lanes.insert(l); + ints.insert(i); + } + } + for &f in &all_floats { + if self.bitvecs.contains(&(f * l)) { + lanes.insert(l); + floats.insert(f); + } + } + for &b in &all_bools { + if self.bitvecs.contains(&(b * l)) { + lanes.insert(l); + bools.insert(b); + } + } + } + + let bitvecs = NumSet::new(); + let specials = Vec::new(); + TypeSet::new(lanes, ints, floats, bools, bitvecs, specials) + } + } + } + + pub fn inplace_intersect_with(&mut self, other: &TypeSet) { + self.lanes = &self.lanes & &other.lanes; + self.ints = &self.ints & &other.ints; + self.floats = &self.floats & &other.floats; + self.bools = &self.bools & &other.bools; + self.bitvecs = &self.bitvecs & &other.bitvecs; + + let mut new_specials = Vec::new(); + for spec in &self.specials { + if let Some(spec) = other.specials.iter().find(|&other_spec| other_spec == spec) { + new_specials.push(*spec); + } + } + self.specials = new_specials; + } + + pub fn is_subset(&self, other: &TypeSet) -> bool { + self.lanes.is_subset(&other.lanes) + && self.ints.is_subset(&other.ints) + && self.floats.is_subset(&other.floats) + && self.bools.is_subset(&other.bools) + && self.bitvecs.is_subset(&other.bitvecs) + && { + let specials: HashSet = HashSet::from_iter(self.specials.clone()); + let other_specials = HashSet::from_iter(other.specials.clone()); + specials.is_subset(&other_specials) + } + } + + pub fn is_wider_or_equal(&self, other: &TypeSet) -> bool { + set_wider_or_equal(&self.ints, &other.ints) + && set_wider_or_equal(&self.floats, &other.floats) + && set_wider_or_equal(&self.bools, &other.bools) + } + + pub fn is_narrower(&self, other: &TypeSet) -> bool { + set_narrower(&self.ints, &other.ints) + && set_narrower(&self.floats, &other.floats) + && set_narrower(&self.bools, &other.bools) + } +} + +fn set_wider_or_equal(s1: &NumSet, s2: &NumSet) -> bool { + s1.len() > 0 && s2.len() > 0 && s1.iter().min() >= s2.iter().max() +} + +fn set_narrower(s1: &NumSet, s2: &NumSet) -> bool { + s1.len() > 0 && s2.len() > 0 && s1.iter().min() < s2.iter().max() +} + +impl fmt::Debug for TypeSet { + fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { + write!(fmt, "TypeSet(")?; + + let mut subsets = Vec::new(); + if !self.lanes.is_empty() { + subsets.push(format!( + "lanes={{{}}}", + Vec::from_iter(self.lanes.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.ints.is_empty() { + subsets.push(format!( + "ints={{{}}}", + Vec::from_iter(self.ints.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.floats.is_empty() { + subsets.push(format!( + "floats={{{}}}", + Vec::from_iter(self.floats.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.bools.is_empty() { + subsets.push(format!( + "bools={{{}}}", + Vec::from_iter(self.bools.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.bitvecs.is_empty() { + subsets.push(format!( + "bitvecs={{{}}}", + Vec::from_iter(self.bitvecs.iter().map(|x| x.to_string())).join(", ") + )); + } + if !self.specials.is_empty() { + subsets.push(format!( + "specials={{{}}}", + Vec::from_iter(self.specials.iter().map(|x| x.to_string())).join(", ") + )); + } + + write!(fmt, "{})", subsets.join(", "))?; + Ok(()) + } } pub struct TypeSetBuilder { @@ -805,6 +977,136 @@ fn test_forward_images() { ); } +#[test] +fn test_backward_images() { + let empty_set = TypeSetBuilder::new().finish(); + + // LaneOf. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..1) + .ints(8..8) + .floats(32..32) + .finish() + .preimage(DerivedFunc::LaneOf), + TypeSetBuilder::new() + .simd_lanes(Interval::All) + .ints(8..8) + .floats(32..32) + .finish() + ); + assert_eq!(empty_set.preimage(DerivedFunc::LaneOf), empty_set); + + // AsBool. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..4) + .bools(1..64) + .finish() + .preimage(DerivedFunc::AsBool), + TypeSetBuilder::new() + .simd_lanes(1..4) + .ints(Interval::All) + .bools(Interval::All) + .floats(Interval::All) + .finish() + ); + + // Double vector. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..1) + .ints(8..8) + .finish() + .preimage(DerivedFunc::DoubleVector) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..16) + .floats(32..32) + .finish() + .preimage(DerivedFunc::DoubleVector), + TypeSetBuilder::new() + .simd_lanes(1..8) + .ints(8..16) + .floats(32..32) + .finish(), + ); + + // Half vector. + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(256..256) + .ints(8..8) + .finish() + .preimage(DerivedFunc::HalfVector) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(64..128) + .bools(1..32) + .finish() + .preimage(DerivedFunc::HalfVector), + TypeSetBuilder::new() + .simd_lanes(128..256) + .bools(1..32) + .finish(), + ); + + // Half width. + assert_eq!( + TypeSetBuilder::new() + .ints(64..64) + .floats(64..64) + .bools(64..64) + .finish() + .preimage(DerivedFunc::HalfWidth) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(64..256) + .bools(1..64) + .finish() + .preimage(DerivedFunc::HalfWidth), + TypeSetBuilder::new() + .simd_lanes(64..256) + .bools(16..64) + .finish(), + ); + + // Double width. + assert_eq!( + TypeSetBuilder::new() + .ints(8..8) + .floats(32..32) + .bools(1..8) + .finish() + .preimage(DerivedFunc::DoubleWidth) + .size(), + 0 + ); + assert_eq!( + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..16) + .floats(32..64) + .finish() + .preimage(DerivedFunc::DoubleWidth), + TypeSetBuilder::new() + .simd_lanes(1..16) + .ints(8..8) + .floats(32..32) + .finish() + ); +} + #[test] #[should_panic] fn test_typeset_singleton_panic_nonsingleton_types() {