Port widening ops to ISLE (AArch64) (#4751)

Ported the existing implementations of the following opcodes for AArch64
to ISLE, and implemented support for 64-bit vectors (per the docs):
- `SwidenLow`
- `SwidenHigh`
- `UwidenLow`
- `UwidenHigh`

Also ported `WideningPairwiseDotProductS` as-is.

Copyright (c) 2022 Arm Limited
This commit is contained in:
Damian Heaton
2022-08-23 17:42:11 +01:00
committed by GitHub
parent da1fb305a3
commit 3b68d76905
10 changed files with 250 additions and 161 deletions

View File

@@ -528,7 +528,8 @@
(t VecExtendOp)
(rd WritableReg)
(rn Reg)
(high_half bool))
(high_half bool)
(lane_size ScalarSize))
;; Move vector element to another vector element.
(VecMovElement
@@ -1080,18 +1081,10 @@
;; Type of vector element extensions.
(type VecExtendOp
(enum
;; Signed extension of 8-bit elements
(Sxtl8)
;; Signed extension of 16-bit elements
(Sxtl16)
;; Signed extension of 32-bit elements
(Sxtl32)
;; Unsigned extension of 8-bit elements
(Uxtl8)
;; Unsigned extension of 16-bit elements
(Uxtl16)
;; Unsigned extension of 32-bit elements
(Uxtl32)
;; Signed extension
(Sxtl)
;; Unsigned extension
(Uxtl)
))
;; A vector ALU operation.
@@ -1844,6 +1837,12 @@
(_ Unit (emit (MInst.MovFromVecSigned dst rn idx size scalar_size))))
dst))
(decl fpu_move_from_vec (Reg u8 VectorSize) Reg)
(rule (fpu_move_from_vec rn idx size)
(let ((dst WritableReg (temp_writable_reg $I8X16))
(_ Unit (emit (MInst.FpuMoveFromVec dst rn idx size))))
dst))
;; Helper for emitting `MInst.Extend` instructions.
(decl extend (Reg bool u8 u8) Reg)
(rule (extend rn signed from_bits to_bits)
@@ -1858,6 +1857,13 @@
(_ Unit (emit (MInst.FpuExtend dst src size))))
dst))
;; Helper for emitting `MInst.VecExtend` instructions.
(decl vec_extend (VecExtendOp Reg bool ScalarSize) Reg)
(rule (vec_extend op src high_half size)
(let ((dst WritableReg (temp_writable_reg $I8X16))
(_ Unit (emit (MInst.VecExtend op dst src high_half size))))
dst))
;; Helper for emitting `MInst.LoadAcquire` instructions.
(decl load_acquire (Type Reg) Reg)
(rule (load_acquire ty addr)

View File

@@ -653,6 +653,16 @@ impl ScalarSize {
ScalarSize::Size128 => panic!("can't widen 128-bits"),
}
}
pub fn narrow(&self) -> ScalarSize {
match self {
ScalarSize::Size8 => panic!("can't narrow 8-bits"),
ScalarSize::Size16 => ScalarSize::Size8,
ScalarSize::Size32 => ScalarSize::Size16,
ScalarSize::Size64 => ScalarSize::Size32,
ScalarSize::Size128 => ScalarSize::Size64,
}
}
}
/// Type used to communicate the size of a vector operand.

View File

@@ -2382,16 +2382,19 @@ impl MachInstEmit for Inst {
rd,
rn,
high_half,
lane_size,
} => {
let rd = allocs.next_writable(rd);
let rn = allocs.next(rn);
let (u, immh) = match t {
VecExtendOp::Sxtl8 => (0b0, 0b001),
VecExtendOp::Sxtl16 => (0b0, 0b010),
VecExtendOp::Sxtl32 => (0b0, 0b100),
VecExtendOp::Uxtl8 => (0b1, 0b001),
VecExtendOp::Uxtl16 => (0b1, 0b010),
VecExtendOp::Uxtl32 => (0b1, 0b100),
let immh = match lane_size {
ScalarSize::Size16 => 0b001,
ScalarSize::Size32 => 0b010,
ScalarSize::Size64 => 0b100,
_ => panic!("Unexpected VecExtend to lane size of {:?}", lane_size),
};
let u = match t {
VecExtendOp::Sxtl => 0b0,
VecExtendOp::Uxtl => 0b1,
};
sink.put4(
0b000_011110_0000_000_101001_00000_00000

View File

@@ -2581,60 +2581,66 @@ fn test_aarch64_binemit() {
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Sxtl8,
t: VecExtendOp::Sxtl,
rd: writable_vreg(4),
rn: vreg(27),
high_half: false,
lane_size: ScalarSize::Size16,
},
"64A7080F",
"sxtl v4.8h, v27.8b",
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Sxtl16,
t: VecExtendOp::Sxtl,
rd: writable_vreg(17),
rn: vreg(19),
high_half: true,
lane_size: ScalarSize::Size32,
},
"71A6104F",
"sxtl2 v17.4s, v19.8h",
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Sxtl32,
t: VecExtendOp::Sxtl,
rd: writable_vreg(30),
rn: vreg(6),
high_half: false,
lane_size: ScalarSize::Size64,
},
"DEA4200F",
"sxtl v30.2d, v6.2s",
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Uxtl8,
t: VecExtendOp::Uxtl,
rd: writable_vreg(3),
rn: vreg(29),
high_half: true,
lane_size: ScalarSize::Size16,
},
"A3A7086F",
"uxtl2 v3.8h, v29.16b",
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Uxtl16,
t: VecExtendOp::Uxtl,
rd: writable_vreg(15),
rn: vreg(12),
high_half: false,
lane_size: ScalarSize::Size32,
},
"8FA5102F",
"uxtl v15.4s, v12.4h",
));
insns.push((
Inst::VecExtend {
t: VecExtendOp::Uxtl32,
t: VecExtendOp::Uxtl,
rd: writable_vreg(28),
rn: vreg(2),
high_half: true,
lane_size: ScalarSize::Size64,
},
"5CA4206F",
"uxtl2 v28.2d, v2.4s",

View File

@@ -2041,47 +2041,19 @@ impl Inst {
rd,
rn,
high_half,
lane_size,
} => {
let (op, dest, src) = match (t, high_half) {
(VecExtendOp::Sxtl8, false) => {
("sxtl", VectorSize::Size16x8, VectorSize::Size8x8)
}
(VecExtendOp::Sxtl8, true) => {
("sxtl2", VectorSize::Size16x8, VectorSize::Size8x16)
}
(VecExtendOp::Sxtl16, false) => {
("sxtl", VectorSize::Size32x4, VectorSize::Size16x4)
}
(VecExtendOp::Sxtl16, true) => {
("sxtl2", VectorSize::Size32x4, VectorSize::Size16x8)
}
(VecExtendOp::Sxtl32, false) => {
("sxtl", VectorSize::Size64x2, VectorSize::Size32x2)
}
(VecExtendOp::Sxtl32, true) => {
("sxtl2", VectorSize::Size64x2, VectorSize::Size32x4)
}
(VecExtendOp::Uxtl8, false) => {
("uxtl", VectorSize::Size16x8, VectorSize::Size8x8)
}
(VecExtendOp::Uxtl8, true) => {
("uxtl2", VectorSize::Size16x8, VectorSize::Size8x16)
}
(VecExtendOp::Uxtl16, false) => {
("uxtl", VectorSize::Size32x4, VectorSize::Size16x4)
}
(VecExtendOp::Uxtl16, true) => {
("uxtl2", VectorSize::Size32x4, VectorSize::Size16x8)
}
(VecExtendOp::Uxtl32, false) => {
("uxtl", VectorSize::Size64x2, VectorSize::Size32x2)
}
(VecExtendOp::Uxtl32, true) => {
("uxtl2", VectorSize::Size64x2, VectorSize::Size32x4)
}
let vec64 = VectorSize::from_lane_size(lane_size.narrow(), false);
let vec128 = VectorSize::from_lane_size(lane_size.narrow(), true);
let rd_size = VectorSize::from_lane_size(lane_size, true);
let (op, rn_size) = match (t, high_half) {
(VecExtendOp::Sxtl, false) => ("sxtl", vec64),
(VecExtendOp::Sxtl, true) => ("sxtl2", vec128),
(VecExtendOp::Uxtl, false) => ("uxtl", vec64),
(VecExtendOp::Uxtl, true) => ("uxtl2", vec128),
};
let rd = pretty_print_vreg_vector(rd.to_reg(), dest, allocs);
let rn = pretty_print_vreg_vector(rn, src, allocs);
let rd = pretty_print_vreg_vector(rd.to_reg(), rd_size, allocs);
let rn = pretty_print_vreg_vector(rn, rn_size, allocs);
format!("{} {}, {}", op, rd, rn)
}
&Inst::VecMovElement {

View File

@@ -1817,6 +1817,48 @@
(result Reg (uqxtn2 low_half y (lane_size ty))))
result))
;;;; Rules for `swiden_low` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (swiden_low x)))
(vec_extend (VecExtendOp.Sxtl) x $false (lane_size ty)))
;;;; Rules for `swiden_high` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type (ty_vec128 ty) (swiden_high x)))
(vec_extend (VecExtendOp.Sxtl) x $true (lane_size ty)))
(rule (lower (has_type ty (swiden_high x)))
(if (ty_vec64 ty))
(let ((tmp Reg (fpu_move_from_vec x 1 (VectorSize.Size32x2))))
(vec_extend (VecExtendOp.Sxtl) tmp $false (lane_size ty))))
;;;; Rules for `uwiden_low` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (uwiden_low x)))
(vec_extend (VecExtendOp.Uxtl) x $false (lane_size ty)))
;;;; Rules for `uwiden_high` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type (ty_vec128 ty) (uwiden_high x)))
(vec_extend (VecExtendOp.Uxtl) x $true (lane_size ty)))
(rule (lower (has_type ty (uwiden_high x)))
(if (ty_vec64 ty))
(let ((tmp Reg (fpu_move_from_vec x 1 (VectorSize.Size32x2))))
(vec_extend (VecExtendOp.Uxtl) tmp $false (lane_size ty))))
;;;; Rules for `widening_pairwise_dot_product_s` ;;;;;;;;;;;;;;;;;;;;;;
;; The args have type I16X8.
;; "dst = i32x4.dot_i16x8_s(x, y)"
;; => smull tmp, x, y
;; smull2 dst, x, y
;; addp dst, tmp, dst
(rule (lower (has_type $I32X4 (widening_pairwise_dot_product_s x y)))
(let ((tmp Reg (vec_rrr_long (VecRRRLongOp.Smull16) x y $false))
(dst Reg (vec_rrr_long (VecRRRLongOp.Smull16) x y $true)))
(vec_rrr (VecALUOp.Addp) tmp dst (VectorSize.Size32x4))))
;;;; Rules for `Fence` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (fence))

View File

@@ -98,3 +98,23 @@
;;; Rules for `extract_vector` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (extract_vector x 0))
(value_reg (fpu_move_128 (put_in_reg x))))
;;;; Rules for `swiden_low` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (swiden_low x)))
(value_reg (vec_extend (VecExtendOp.Sxtl) x $false (lane_size ty))))
;;;; Rules for `swiden_high` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (swiden_high x)))
(value_reg (vec_extend (VecExtendOp.Sxtl) x $true (lane_size ty))))
;;;; Rules for `uwiden_low` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (uwiden_low x)))
(value_reg (vec_extend (VecExtendOp.Uxtl) x $false (lane_size ty))))
;;;; Rules for `uwiden_high` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
(rule (lower (has_type ty (uwiden_high x)))
(value_reg (vec_extend (VecExtendOp.Uxtl) x $true (lane_size ty))))

View File

@@ -159,22 +159,23 @@ pub(crate) fn lower_insn_to_regs(
});
let vec_extend = match op {
Opcode::Sload8x8 => Some(VecExtendOp::Sxtl8),
Opcode::Uload8x8 => Some(VecExtendOp::Uxtl8),
Opcode::Sload16x4 => Some(VecExtendOp::Sxtl16),
Opcode::Uload16x4 => Some(VecExtendOp::Uxtl16),
Opcode::Sload32x2 => Some(VecExtendOp::Sxtl32),
Opcode::Uload32x2 => Some(VecExtendOp::Uxtl32),
Opcode::Sload8x8 => Some((VecExtendOp::Sxtl, ScalarSize::Size16)),
Opcode::Uload8x8 => Some((VecExtendOp::Uxtl, ScalarSize::Size16)),
Opcode::Sload16x4 => Some((VecExtendOp::Sxtl, ScalarSize::Size32)),
Opcode::Uload16x4 => Some((VecExtendOp::Uxtl, ScalarSize::Size32)),
Opcode::Sload32x2 => Some((VecExtendOp::Sxtl, ScalarSize::Size64)),
Opcode::Uload32x2 => Some((VecExtendOp::Uxtl, ScalarSize::Size64)),
_ => None,
};
if let Some(t) = vec_extend {
if let Some((t, lane_size)) = vec_extend {
let rd = dst.only_reg().unwrap();
ctx.emit(Inst::VecExtend {
t,
rd,
rn: rd.to_reg(),
high_half: false,
lane_size,
});
}
@@ -961,46 +962,7 @@ pub(crate) fn lower_insn_to_regs(
Opcode::IaddPairwise => implemented_in_isle(ctx),
Opcode::WideningPairwiseDotProductS => {
let r_y = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
let r_a = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None);
let r_b = put_input_in_reg(ctx, inputs[1], NarrowValueMode::None);
let ty = ty.unwrap();
if ty == I32X4 {
let tmp = ctx.alloc_tmp(I8X16).only_reg().unwrap();
// The args have type I16X8.
// "y = i32x4.dot_i16x8_s(a, b)"
// => smull tmp, a, b
// smull2 y, a, b
// addp y, tmp, y
ctx.emit(Inst::VecRRRLong {
alu_op: VecRRRLongOp::Smull16,
rd: tmp,
rn: r_a,
rm: r_b,
high_half: false,
});
ctx.emit(Inst::VecRRRLong {
alu_op: VecRRRLongOp::Smull16,
rd: r_y,
rn: r_a,
rm: r_b,
high_half: true,
});
ctx.emit(Inst::VecRRR {
alu_op: VecALUOp::Addp,
rd: r_y,
rn: tmp.to_reg(),
rm: r_y.to_reg(),
size: VectorSize::Size32x4,
});
} else {
return Err(CodegenError::Unsupported(format!(
"Opcode::WideningPairwiseDotProductS: unsupported laneage: {:?}",
ty
)));
}
}
Opcode::WideningPairwiseDotProductS => implemented_in_isle(ctx),
Opcode::Fadd | Opcode::Fsub | Opcode::Fmul | Opcode::Fdiv | Opcode::Fmin | Opcode::Fmax => {
implemented_in_isle(ctx)
@@ -1485,42 +1447,7 @@ pub(crate) fn lower_insn_to_regs(
Opcode::Snarrow | Opcode::Unarrow | Opcode::Uunarrow => implemented_in_isle(ctx),
Opcode::SwidenLow | Opcode::SwidenHigh | Opcode::UwidenLow | Opcode::UwidenHigh => {
let rd = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
let rn = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None);
let ty = ty.unwrap();
let ty = if ty.is_dynamic_vector() {
ty.dynamic_to_vector()
.unwrap_or_else(|| panic!("Unsupported dynamic type: {}?", ty))
} else {
ty
};
let (t, high_half) = match (ty, op) {
(I16X8, Opcode::SwidenLow) => (VecExtendOp::Sxtl8, false),
(I16X8, Opcode::SwidenHigh) => (VecExtendOp::Sxtl8, true),
(I16X8, Opcode::UwidenLow) => (VecExtendOp::Uxtl8, false),
(I16X8, Opcode::UwidenHigh) => (VecExtendOp::Uxtl8, true),
(I32X4, Opcode::SwidenLow) => (VecExtendOp::Sxtl16, false),
(I32X4, Opcode::SwidenHigh) => (VecExtendOp::Sxtl16, true),
(I32X4, Opcode::UwidenLow) => (VecExtendOp::Uxtl16, false),
(I32X4, Opcode::UwidenHigh) => (VecExtendOp::Uxtl16, true),
(I64X2, Opcode::SwidenLow) => (VecExtendOp::Sxtl32, false),
(I64X2, Opcode::SwidenHigh) => (VecExtendOp::Sxtl32, true),
(I64X2, Opcode::UwidenLow) => (VecExtendOp::Uxtl32, false),
(I64X2, Opcode::UwidenHigh) => (VecExtendOp::Uxtl32, true),
(ty, _) => {
return Err(CodegenError::Unsupported(format!(
"{}: Unsupported type: {:?}",
op, ty
)));
}
};
ctx.emit(Inst::VecExtend {
t,
rd,
rn,
high_half,
});
implemented_in_isle(ctx)
}
Opcode::TlsValue => match flags.tls_model() {
@@ -1557,10 +1484,11 @@ pub(crate) fn lower_insn_to_regs(
let rn = put_input_in_reg(ctx, inputs[0], NarrowValueMode::None);
ctx.emit(Inst::VecExtend {
t: VecExtendOp::Sxtl32,
t: VecExtendOp::Sxtl,
rd,
rn,
high_half: false,
lane_size: ScalarSize::Size64,
});
ctx.emit(Inst::VecMisc {
op: VecMisc2::Scvtf,

View File

@@ -60,9 +60,9 @@ block0(v0: i8x16, v1: i8x16):
}
; block0:
; sxtl v4.8h, v0.8b
; sxtl2 v6.8h, v1.16b
; addp v0.8h, v4.8h, v6.8h
; sxtl v7.8h, v0.8b
; sxtl2 v16.8h, v1.16b
; addp v0.8h, v7.8h, v16.8h
; ret
function %fn6(i8x16, i8x16) -> i16x8 {
@@ -74,9 +74,9 @@ block0(v0: i8x16, v1: i8x16):
}
; block0:
; uxtl v4.8h, v0.8b
; uxtl2 v6.8h, v1.16b
; addp v0.8h, v4.8h, v6.8h
; uxtl v7.8h, v0.8b
; uxtl2 v16.8h, v1.16b
; addp v0.8h, v7.8h, v16.8h
; ret
function %fn7(i8x16) -> i16x8 {
@@ -88,9 +88,9 @@ block0(v0: i8x16):
}
; block0:
; uxtl v2.8h, v0.8b
; sxtl2 v4.8h, v0.16b
; addp v0.8h, v2.8h, v4.8h
; uxtl v5.8h, v0.8b
; sxtl2 v6.8h, v0.16b
; addp v0.8h, v5.8h, v6.8h
; ret
function %fn8(i8x16) -> i16x8 {
@@ -102,9 +102,9 @@ block0(v0: i8x16):
}
; block0:
; sxtl v2.8h, v0.8b
; uxtl2 v4.8h, v0.16b
; addp v0.8h, v2.8h, v4.8h
; sxtl v5.8h, v0.8b
; uxtl2 v6.8h, v0.16b
; addp v0.8h, v5.8h, v6.8h
; ret
function %fn9(i8x8, i8x8) -> i8x8 {
@@ -157,3 +157,63 @@ block0(v0: i32x4, v1: i32x4):
; addp v0.4s, v0.4s, v1.4s
; ret
function %fn15(i8x8, i8x8) -> i16x4 {
block0(v0: i8x8, v1: i8x8):
v2 = swiden_low v0
v3 = swiden_high v1
v4 = iadd_pairwise v2, v3
return v4
}
; block0:
; sxtl v16.8h, v0.8b
; mov s7, v1.s[1]
; sxtl v17.8h, v7.8b
; addp v0.4h, v16.4h, v17.4h
; ret
function %fn16(i8x8, i8x8) -> i16x4 {
block0(v0: i8x8, v1: i8x8):
v2 = uwiden_low v0
v3 = uwiden_high v1
v4 = iadd_pairwise v2, v3
return v4
}
; block0:
; uxtl v16.8h, v0.8b
; mov s7, v1.s[1]
; uxtl v17.8h, v7.8b
; addp v0.4h, v16.4h, v17.4h
; ret
function %fn17(i8x8) -> i16x4 {
block0(v0: i8x8):
v1 = uwiden_low v0
v2 = swiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; block0:
; uxtl v6.8h, v0.8b
; mov s5, v0.s[1]
; sxtl v7.8h, v5.8b
; addp v0.4h, v6.4h, v7.4h
; ret
function %fn18(i8x8) -> i16x4 {
block0(v0: i8x8):
v1 = swiden_low v0
v2 = uwiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; block0:
; sxtl v6.8h, v0.8b
; mov s5, v0.s[1]
; uxtl v7.8h, v5.8b
; addp v0.4h, v6.4h, v7.4h
; ret

View File

@@ -24,3 +24,45 @@ block0(v0: i32x2, v1: i32x2):
}
; run: %iaddp_i32x2([1 2], [5 6]) == [3 11]
; run: %iaddp_i32x2([4294967290 5], [100 100]) == [4294967295 200]
function %swiden_i8x8(i8x8) -> i16x4 {
block0(v0: i8x8):
v1 = swiden_low v0
v2 = swiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; run: %swiden_i8x8([1 2 3 4 5 6 7 8]) == [3 7 11 15]
; run: %swiden_i8x8([-1 2 -3 4 -5 6 -7 8]) == [1 1 1 1]
; run: %swiden_i8x8([127 1 126 2 125 3 124 4]) == [128 128 128 128]
function %uwiden_i8x8(i8x8) -> i16x4 {
block0(v0: i8x8):
v1 = uwiden_low v0
v2 = uwiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; run: %uwiden_i8x8([17 18 19 20 21 22 23 24]) == [35 39 43 47]
; run: %uwiden_i8x8([2 254 3 253 4 252 5 251]) == [256 256 256 256]
function %swiden_i16x4(i16x4) -> i32x2 {
block0(v0: i16x4):
v1 = swiden_low v0
v2 = swiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; run: %swiden_i16x4([1 2 3 4]) == [3 7]
; run: %swiden_i16x4([-1 2 -3 4]) == [1 1]
; run: %swiden_i16x4([127 1 126 2]) == [128 128]
function %uwiden_i16x4(i16x4) -> i32x2 {
block0(v0: i16x4):
v1 = uwiden_low v0
v2 = uwiden_high v0
v3 = iadd_pairwise v1, v2
return v3
}
; run: %uwiden_i16x4([17 18 19 20]) == [35 39]
; run: %uwiden_i16x4([2 254 3 253]) == [256 256]