decode: Add AVX-512 support

This commit is contained in:
Alexis Engelke
2022-10-02 11:57:39 +02:00
parent ec5a430b5c
commit e04aff73dc
8 changed files with 2265 additions and 58 deletions

View File

@@ -14,7 +14,8 @@ INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[
("modreg_idx", 2),
("vexreg_idx", 2), # note: vexreg w/o vex prefix is zeroreg_val
("imm_idx", 2),
("unused1", 2),
("evex_bcst", 1),
("evex_mask", 1),
("zeroreg_val", 1),
("lock", 1),
("imm_control", 3),
@@ -31,7 +32,8 @@ INSTR_FLAGS_FIELDS, INSTR_FLAGS_SIZES = zip(*[
("modreg_ty", 3),
("vexreg_ty", 2),
("imm_ty", 0),
("unused", 3),
("evex_rc", 2),
("unused", 1),
("opsize", 3),
("modrm", 1),
("ign66", 1),
@@ -141,9 +143,8 @@ OPKIND_SIZES = {
"zq": 8, # z-immediate, but always 8-byte operand
}
class OpKind(NamedTuple):
kind: str
regkind: str
sizestr: str
size: int
SZ_OP = -1
SZ_VEC = -2
@@ -163,9 +164,15 @@ class OpKind(NamedTuple):
def immsize(self, opsz):
maxsz = 1 if self.sizestr == "bs" else 4 if self.sizestr[0] == "z" else 8
return min(maxsz, self.abssize(opsz))
@property
def kind(self):
return OPKIND_CANONICALIZE[self.regkind]
@property
def size(self):
return OPKIND_SIZES[self.sizestr]
@classmethod
def parse(cls, op):
return cls(OPKIND_CANONICALIZE[op[0]], op[1:], OPKIND_SIZES[op[1:]])
return cls(op[0], op[1:])
class InstrDesc(NamedTuple):
mnemonic: str
@@ -185,25 +192,28 @@ class InstrDesc(NamedTuple):
("modrm", "MEM"): 0,
("imm", "MEM"): 0, ("imm", "IMM"): 0, ("imm", "XMM"): 0,
}
OPKIND_REGTYS_ENC = {"SEG": 3, "FPU": 4, "MMX": 5, "XMM": 6, "BND": 8,
"CR": 9, "DR": 10}
OPKIND_REGTYS_ENC = {"SEG": 3, "FPU": 4, "MMX": 5, "XMM": 6, "MASK": 7,
"BND": 8, "CR": 9, "DR": 10}
OPKIND_SIZES = {
0: 0, 1: 1, 2: 2, 4: 3, 8: 4, 16: 5, 32: 6, 64: 7, 10: 0,
# OpKind.SZ_OP: -2, OpKind.SZ_VEC: -3, OpKind.SZ_HALFVEC: -4,
}
@classmethod
def parse(cls, desc):
desc = desc.split()
mnem_comp = desc[5].split("+", 1)
desc[5] = mnem_comp[0]
if len(mnem_comp) > 1 and "w" in mnem_comp[1]:
desc.append("INSTR_WIDTH")
if len(mnem_comp) > 1 and "a" in mnem_comp[1]:
desc.append("U67")
if len(mnem_comp) > 1 and "s" in mnem_comp[1]:
desc.append("USEG")
mnem, _, compactDesc = desc[5].partition("+")
flags = frozenset(desc[6:] + [{
"w": "INSTR_WIDTH",
"a": "U67",
"s": "USEG",
"k": "MASK",
"b": "BCST",
"e": "SAE",
"r": "ER",
}[c] for c in compactDesc])
operands = tuple(OpKind.parse(op) for op in desc[1:5] if op != "-")
return cls(desc[5], desc[0], operands, frozenset(desc[6:]))
return cls(mnem, desc[0], operands, flags)
def imm_size(self, opsz):
flags = ENCODINGS[self.encoding]
@@ -297,6 +307,10 @@ class InstrDesc(NamedTuple):
# Miscellaneous Flags
if "VSIB" in self.flags: extraflags["vsib"] = 1
if "BCST" in self.flags: extraflags["evex_bcst"] = 1
if "MASK" in self.flags: extraflags["evex_mask"] = 1
if "SAE" in self.flags: extraflags["evex_rc"] = 1
if "ER" in self.flags: extraflags["evex_rc"] = 3
if modrm: extraflags["modrm"] = 1
if "U66" not in self.flags and (ign66 or "I66" in self.flags):
@@ -322,8 +336,8 @@ class EntryKind(Enum):
return self == EntryKind.INSTR or self == EntryKind.WEAKINSTR
opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
r"(?:W(?P<rexw>[01]|IG)\.)?(?:L(?P<vexl>[01]|IG)\.)?))?" +
r"^(?:(?P<prefixes>(?P<vex>E?VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
r"(?:W(?P<rexw>[01]|IG)\.)?(?:L(?P<vexl>0|1|12|2|IG)\.)?))?" +
r"(?P<escape>0f38|0f3a|0f|)" +
r"(?P<opcode>[0-9a-f]{2})" +
r"(?:/(?P<modreg>[0-7]|[rm]|[0-7][rm])|(?P<opcext>[c-f][0-9a-f]))?(?P<extended>\+)?$")
@@ -335,8 +349,8 @@ class Opcode(NamedTuple):
extended: bool # Extend opc or opcext, if present
modreg: Union[None, Tuple[Union[None, int], str]] # (modreg, "r"/"m"/"rm"), None
opcext: Union[None, int] # 0xc0-0xff, or 0
vex: bool
vexl: Union[str, None] # 0, 1, IG, None = used, both
vex: int # 0 = legacy, 1 = VEX, 2 = EVEX
vexl: Union[str, None] # 0, 1, 12, 2, IG, None = used, both
rexw: Union[str, None] # 0, 1, IG, None = used, both
@classmethod
@@ -360,11 +374,71 @@ class Opcode(NamedTuple):
extended=match.group("extended") is not None,
modreg=modreg,
opcext=int(match.group("opcext") or "0", 16) or None,
vex=match.group("vex") is not None,
vex=[None, "VEX.", "EVEX."].index(match.group("vex")),
vexl=match.group("vexl"),
rexw=match.group("rexw"),
)
def verifyOpcodeDesc(opcode, desc):
flags = ENCODINGS[desc.encoding]
if opcode.escape == 2 and flags.imm_control != 0:
raise Exception(f"0f38 has no immediate operand {opcode}, {desc}")
if opcode.escape == 3 and desc.imm_size(4) != 1:
raise Exception(f"0f3a must have immediate byte {opcode}, {desc}")
if opcode.vexl == "IG" and desc.dynsizes() - {OpKind.SZ_OP}:
raise Exception(f"(E)VEX.LIG with dynamic vector size {opcode}, {desc}")
if "VSIB" in desc.flags and (not opcode.modreg or opcode.modreg[1] != "m"):
raise Exception(f"VSIB for non-memory opcode {opcode}, {desc}")
if opcode.vex == 2 and flags.vexreg_idx:
# Checking this here allows to omit check for V' in decoder.
if desc.operands[flags.vexreg_idx ^ 3].kind != "XMM":
raise Exception(f"EVEX.vvvv must refer to XMM {opcode}, {desc}")
if opcode.vex == 2 and flags.modreg_idx and flags.modreg_idx ^ 3 != 0:
# EVEX.z=0 is only checked for mask operands in ModReg
if desc.operands[flags.modreg_idx ^ 3].kind == "MASK":
raise Exception(f"ModRM.reg mask not first operand {opcode}, {desc}")
# Verify tuple type
if opcode.vex == 2 and (not opcode.modreg or "m" in opcode.modreg[1]):
tts = [s for s in desc.flags if s.startswith("TUPLE")]
if len(tts) != 1:
raise Exception(f"missing tuple type in {opcode}, {desc}")
if flags.modrm_idx == 3 ^ 3:
raise Exception(f"missing memory operand {opcode}, {desc}")
# From Intel SDM
bcst, evexw, vszs = {
"TUPLE_FULL_32": (True, "0", ( 16, 32, 64)),
"TUPLE_FULL_64": (True, "1", ( 16, 32, 64)),
"TUPLE_HALF_32": (True, "0", ( 8, 16, 32)),
"TUPLE_HALF_64": (True, "1", ( 8, 16, 32)),
"TUPLE_FULL_MEM": (False, None, ( 16, 32, 64)),
"TUPLE_HALF_MEM": (False, None, ( 8, 16, 32)),
"TUPLE_QUARTER_MEM": (False, None, ( 4, 8, 16)),
"TUPLE_EIGHTH_MEM": (False, None, ( 2, 4, 8)),
"TUPLE1_SCALAR_8": (False, None, ( 1, 1, 1)),
"TUPLE1_SCALAR_16": (False, None, ( 2, 2, 2)),
"TUPLE1_SCALAR_32": (False, "0", ( 4, 4, 4)),
"TUPLE1_SCALAR_64": (False, "1", ( 8, 8, 8)),
"TUPLE1_SCALAR_OPSZ": (False, None, ( 0, 0, 0)),
"TUPLE1_FIXED_32": (False, None, ( 4, 4, 4)),
"TUPLE1_FIXED_64": (False, None, ( 8, 8, 8)),
"TUPLE2_32": (False, "0", ( 8, 8, 8)),
"TUPLE2_64": (False, "1", (None, 16, 16)),
"TUPLE4_32": (False, "0", (None, 16, 16)),
"TUPLE4_64": (False, "1", (None, None, 32)),
"TUPLE8_32": (False, "0", (None, None, 32)),
"TUPLE_MEM128": (False, None, ( 16, 16, 16)),
# TODO: Fix MOVDDUP tuple size :(
"TUPLE_MOVDDUP": (False, None, ( 16, 32, 64)),
}[tts[0]]
if "BCST" in desc.flags and not bcst:
raise Exception(f"broadcast on incompatible type {opcode}, {desc}")
if evexw and opcode.rexw != evexw:
raise Exception(f"incompatible EVEX.W {opcode}, {desc}")
for l, tupsz in enumerate(vszs):
opsz = desc.operands[flags.modrm_idx ^ 3].abssize(0, 16 << l)
if tupsz is not None and opsz != tupsz:
raise Exception(f"memory size {opsz} != {tupsz} {opcode}, {desc}")
class Trie:
KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.TABLE256,
EntryKind.TABLE_PREFIX, EntryKind.TABLE16,
@@ -375,7 +449,7 @@ class Trie:
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE16: 16,
EntryKind.TABLE8E: 8,
EntryKind.TABLE_VEX: 4,
EntryKind.TABLE_VEX: 8,
}
def __init__(self, root_count):
@@ -412,9 +486,12 @@ class Trie:
mod = {"m": [0], "r": [1<<3], "rm": [0, 1<<3]}[opc.modreg[1]]
reg = [opc.modreg[0]] if opc.modreg[0] is not None else list(range(8))
t16 = [x + y for x in mod for y in reg]
if opc.vexl in ("0", "1") or opc.rexw in ("0", "1"):
if (opc.rexw or "IG") != "IG" or (opc.vexl or "IG") != "IG":
rexw = {"0": [0], "1": [1<<0], "IG": [0, 1<<0]}[opc.rexw or "IG"]
vexl = {"0": [0], "1": [1<<1], "IG": [0, 1<<1]}[opc.vexl or "IG"]
if opc.vex < 2:
vexl = {"0": [0], "1": [1<<1], "IG": [0, 1<<1]}[opc.vexl or "IG"]
else:
vexl = {"0": [0], "12": [1<<1, 2<<1], "2": [2<<1], "IG": [0, 1<<1, 2<<1, 3<<1]}[opc.vexl or "IG"]
tvex = list(map(sum, product(rexw, vexl)))
# Order must match KIND_ORDER.
return troot, t256, tprefix, t16, t8e, tvex
@@ -566,9 +643,11 @@ def decode_table(entries, args):
decode_mnems_lines = [f"FD_MNEMONIC({m},{i})\n" for i, m in enumerate(mnems)]
mnemonics_intel = [m.replace("SSE_", "").replace("MMX_", "")
.replace("EVX_", "V")
.replace("MOVABS", "MOV").replace("RESERVED_", "")
.replace("JMPF", "JMP FAR").replace("CALLF", "CALL FAR")
.replace("_S2G", "").replace("_G2S", "")
.replace("_X2G", "").replace("_G2X", "")
.replace("_CR", "").replace("_DR", "")
.replace("REP_", "REP ").replace("CMPXCHGD", "CMPXCHG")
.replace("JCXZ", "JCXZ JECXZJRCXZ")
@@ -608,6 +687,8 @@ def encode_mnems(entries):
for weak, opcode, desc in entries:
if "I64" in desc.flags or desc.mnemonic[:9] == "RESERVED_":
continue
if opcode.vex == 2: # EVEX not implemented
continue
opsizes, vecsizes = {0}, {0}
prepend_opsize, prepend_vecsize = False, False
@@ -631,7 +712,7 @@ def encode_mnems(entries):
opsizes = {64}
prepend_opsize = False
elif opcode.vex and opcode.vexl != "IG": # vectors; don't care for SSE
vecsizes = {128, 256}
vecsizes = {128, 256} # TODO-EVEX
if opcode.vexl:
vecsizes -= {128 if opcode.vexl == "1" else 256}
prepend_vecsize = not separate_opsize
@@ -718,8 +799,10 @@ def encode_table(entries, args):
opc_i |= 0x400000 if opcode.rexw == "1" else 0
if opcode.prefix == "LOCK":
opc_i |= 0x800000
elif opcode.vex:
elif opcode.vex == 1:
opc_i |= 0x1000000 + 0x800000 * int(opcode.vexl or 0)
elif opcode.vex == 2: # TODO-EVEX
opc_i |= 0x2000000 + 0x800000 * int(opcode.vexl or 0)
opc_i |= 0x8000000 if "VSIB" in desc.flags else 0
if alt >= 0x100:
raise Exception("encode alternate bits exhausted")
@@ -831,7 +914,7 @@ def encode2_table(entries, args):
code += f" if (!op_imm_n(imm-1, imm_size)) goto next{i};\n"
neednext = True
if opcode.vex:
if opcode.vex: # TODO-EVEX
rexw, rexr, rexx, rexb = 0x8000, 0x80, 0x40, 0x20
else:
rexw, rexr, rexx, rexb = 0x48, 0x44, 0x42, 0x41
@@ -864,7 +947,7 @@ def encode2_table(entries, args):
if "m" in ots or "U67" in desc.flags:
code += " if (UNLIKELY(flags & FE_ADDR32)) buf[idx++] = 0x67;\n"
if opcode.vex:
if opcode.vex: # TODO-EVEX
ppl = ["NP", "66", "F3", "F2"].index(opcode.prefix)
ppl |= 4 if opcode.vexl == "1" else 0
mayvex2 = opcode.rexw != "1" and opcode.escape == 1
@@ -957,6 +1040,7 @@ if __name__ == "__main__":
line, weak = (line, False) if line[0] != "*" else (line[1:], True)
opcode_string, desc_string = tuple(line.split(maxsplit=1))
opcode, desc = Opcode.parse(opcode_string), InstrDesc.parse(desc_string)
verifyOpcodeDesc(opcode, desc)
if "UNDOC" not in desc.flags or args.with_undoc:
entries.append((weak, opcode, desc))