decode: Encode prefixes in trie

This allows unescaped opcodes to be handled with a single table lookup.
This commit is contained in:
Alexis Engelke
2023-03-22 09:57:58 +01:00
parent a34cd9d2aa
commit e1084be859
2 changed files with 107 additions and 41 deletions

View File

@@ -38,14 +38,19 @@ typedef enum DecodeMode DecodeMode;
#define ENTRY_TABLE_ROOT 8
#define ENTRY_MASK 7
static unsigned
table_walk(unsigned cur_idx, unsigned entry_idx, unsigned* out_kind) {
static uint16_t
table_lookup(unsigned cur_idx, unsigned entry_idx) {
static _Alignas(16) const uint16_t _decode_table[] = {
#define FD_DECODE_TABLE_DATA
#include <fadec-decode-private.inc>
#undef FD_DECODE_TABLE_DATA
};
unsigned entry = _decode_table[cur_idx + entry_idx];
return _decode_table[cur_idx + entry_idx];
}
// Fetch the table entry at cur_idx + entry_idx and split it into its two
// parts: the low ENTRY_MASK bits are stored to *out_kind, and the remaining
// bits, shifted right by one, are returned as the index for the next step.
static unsigned
table_walk(unsigned cur_idx, unsigned entry_idx, unsigned* out_kind) {
    const unsigned raw = table_lookup(cur_idx, entry_idx);
    const unsigned kind_bits = raw & ENTRY_MASK;
    *out_kind = kind_bits;
    // Subtracting the kind bits clears them, same as masking with ~ENTRY_MASK.
    return (raw - kind_bits) >> 1;
}
@@ -126,37 +131,33 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
uint8_t addr_size = mode == DECODE_64 ? 3 : 2;
unsigned prefix_rex = 0;
uint8_t prefix_rep = 0;
unsigned vexl = 0;
unsigned prefix_evex = 0;
instr->segment = FD_REG_NONE;
// Values must match prefixes in parseinstrs.py.
enum {
PF_SEG1 = 1|8,
PF_SEG2 = 1,
PF_66 = 3,
PF_67 = 4,
PF_LOCK = 5,
PF_REP = 6,
PF_REX = 8,
PF_SEG1 = 0xfff8 - 0xfff8,
PF_SEG2 = 0xfff9 - 0xfff8,
PF_66 = 0xfffa - 0xfff8,
PF_67 = 0xfffb - 0xfff8,
PF_LOCK = 0xfffc - 0xfff8,
PF_REP = 0xfffd - 0xfff8,
PF_REX = 0xfffe - 0xfff8,
};
static const uint8_t pflut[256] = {
[0x26] = PF_SEG1, [0x2e] = PF_SEG1, [0x36] = PF_SEG1, [0x3e] = PF_SEG1,
[0x64] = PF_SEG2, [0x65] = PF_SEG2, [0x66] = PF_66, [0x67] = PF_67,
[0xf0] = PF_LOCK, [0xf2] = PF_REP, [0xf3] = PF_REP,
[0x40] = PF_REX, [0x41] = PF_REX, [0x42] = PF_REX, [0x43] = PF_REX,
[0x44] = PF_REX, [0x45] = PF_REX, [0x46] = PF_REX, [0x47] = PF_REX,
[0x48] = PF_REX, [0x49] = PF_REX, [0x4a] = PF_REX, [0x4b] = PF_REX,
[0x4c] = PF_REX, [0x4d] = PF_REX, [0x4e] = PF_REX, [0x4f] = PF_REX,
};
uint8_t prefixes[12] = {0};
uint8_t lutmask = mode == DECODE_64 ? 0xff : 0x7;
while (LIKELY(off < len)) {
uint8_t prefixes[8] = {0};
unsigned table_entry = 0;
while (true) {
if (UNLIKELY(off >= len))
return FD_ERR_PARTIAL;
uint8_t prefix = buffer[off];
uint8_t lut = pflut[prefix] & lutmask;
if (LIKELY(!lut))
table_entry = table_lookup(table_idx, prefix);
if (LIKELY(table_entry - 0xfff8 >= 8))
break;
prefixes[lut] = prefix;
prefixes[PF_REX] = 0;
prefixes[table_entry - 0xfff8] = prefix;
off++;
}
if (off) {
@@ -168,10 +169,28 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
}
if (UNLIKELY(prefixes[PF_67]))
addr_size--;
if (buffer[off - 1] == prefixes[PF_REX])
prefix_rex = prefixes[PF_REX];
prefix_rex = prefixes[PF_REX];
prefix_rep = prefixes[PF_REP];
}
kind = table_entry & ENTRY_MASK;
table_idx = (table_entry & ~ENTRY_MASK) >> 1;
if (LIKELY(kind != 7)) {
off++;
// Then, walk through ModR/M-encoded opcode extensions.
if (kind == ENTRY_TABLE16 && LIKELY(off < len)) {
unsigned isreg = (buffer[off] & 0xc0) == 0xc0 ? 8 : 0;
table_idx = table_walk(table_idx, ((buffer[off] >> 3) & 7) | isreg, &kind);
if (kind == ENTRY_TABLE8E)
table_idx = table_walk(table_idx, buffer[off] & 7, &kind);
}
if (UNLIKELY(kind != ENTRY_INSTR))
return kind == 0 ? FD_ERR_UD : FD_ERR_PARTIAL;
goto direct;
}
uint8_t prefix_rep = prefixes[PF_REP];
if (UNLIKELY(off >= len))
return FD_ERR_PARTIAL;
@@ -200,8 +219,11 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
// VEX (C4/C5) or EVEX (62)
if (UNLIKELY(off + 1 >= len))
return FD_ERR_PARTIAL;
if (mode == DECODE_32 && (buffer[off + 1] & 0xc0) != 0xc0)
goto skipvex;
if (UNLIKELY(mode == DECODE_32 && buffer[off + 1] < 0xc0)) {
off++;
table_idx = table_walk(table_idx, 0, &kind);
goto direct;
}
// VEX/EVEX + 66/F3/F2/REX will #UD.
// Note: REX is also here only respected if it immediately precedes the
@@ -212,7 +234,7 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
uint8_t byte = buffer[off + 1];
if (vex_prefix == 0xc5) // 2-byte VEX
{
opcode_escape = 1 | 4; // 4 is table index with VEX, 0f escape
opcode_escape = 1;
prefix_rex = byte & 0x80 ? 0 : PREFIX_REXR;
}
else // 3-byte VEX or EVEX
@@ -224,7 +246,7 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
{
if (byte & 0x08) // Bit 3 of opcode_escape must be clear.
return FD_ERR_UD;
opcode_escape = (byte & 0x07) | 8; // 8 is table index with EVEX
opcode_escape = (byte & 0x07);
_Static_assert(PREFIX_REXRR == 0x10, "wrong REXRR value");
if (mode == DECODE_64)
prefix_rex |= (byte & PREFIX_REXRR) ^ PREFIX_REXRR;
@@ -233,7 +255,13 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
{
if (byte & 0x1c) // Bits 4:2 of opcode_escape must be clear.
return FD_ERR_UD;
opcode_escape = (byte & 0x03) | 4; // 4 is table index with VEX
opcode_escape = (byte & 0x03); // only bits 1:0 select the opcode escape
}
if (UNLIKELY(opcode_escape == 0)) {
int prefix_len = vex_prefix == 0x62 ? 4 : 3;
// Pretend to decode the prefix plus one opcode byte.
return off + prefix_len > len ? FD_ERR_PARTIAL : FD_ERR_UD;
}
// Load third byte of VEX prefix
@@ -267,8 +295,6 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
vexl = byte & 0x04 ? 1 : 0;
off += 0xc7 - vex_prefix; // 3 for c4, 2 for c5
}
skipvex:;
}
table_idx = table_walk(table_idx, opcode_escape, &kind);
@@ -304,6 +330,7 @@ fd_decode(const uint8_t* buffer, size_t len_sz, int mode_int, uintptr_t address,
if (UNLIKELY(kind != ENTRY_INSTR))
return kind == 0 ? FD_ERR_UD : FD_ERR_PARTIAL;
direct:;
static _Alignas(16) const struct InstrDesc descs[] = {
#define FD_DECODE_TABLE_DESCS
#include <fadec-decode-private.inc>

View File

@@ -324,6 +324,8 @@ class InstrDesc(NamedTuple):
class EntryKind(Enum):
NONE = 0
PREFIX = 8
ESCAPE = 7
INSTR = 1
WEAKINSTR = 9
TABLE256 = 2
@@ -334,7 +336,7 @@ class EntryKind(Enum):
TABLE_ROOT = -1
@property
def is_table(self):
return self != EntryKind.INSTR and self != EntryKind.WEAKINSTR
return self != EntryKind.INSTR and self != EntryKind.WEAKINSTR and self != EntryKind.PREFIX
opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>E?VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
@@ -392,6 +394,8 @@ def verifyOpcodeDesc(opcode, desc):
raise Exception(f"unescaped opcode has L specifier {opcode}, {desc}")
if opcode.escape == 0 and opcode.rexw is not None:
raise Exception(f"unescaped opcode has W specifier {opcode}, {desc}")
if opcode.escape == 0 and opcode.vex:
raise Exception(f"VEX opcode without escape {opcode}, {desc}")
if opcode.vex and opcode.prefix not in ("NP", "66", "F2", "F3"):
raise Exception(f"VEX/EVEX must have mandatory prefix {opcode}, {desc}")
if opcode.vexl == "IG" and desc.dynsizes() - {OpKind.SZ_OP}:
@@ -456,11 +460,12 @@ def verifyOpcodeDesc(opcode, desc):
raise Exception(f"memory size {opsz} != {tupsz} {opcode}, {desc}")
class Trie:
KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.TABLE256,
KIND_ORDER = (EntryKind.TABLE_ROOT, EntryKind.ESCAPE, EntryKind.TABLE256,
EntryKind.TABLE_PREFIX, EntryKind.TABLE16,
EntryKind.TABLE8E, EntryKind.TABLE_VEX)
TABLE_LENGTH = {
EntryKind.TABLE_ROOT: 16,
EntryKind.TABLE_ROOT: 256,
EntryKind.ESCAPE: 8,
EntryKind.TABLE256: 256,
EntryKind.TABLE_PREFIX: 4,
EntryKind.TABLE16: 16,
@@ -486,8 +491,22 @@ class Trie:
return elem[0], new_num
def _transform_opcode(self, opc):
troot = [opc.escape | opc.vex << 2]
t256 = [opc.opc + i for i in range(8 if opc.extended and not opc.opcext else 1)]
topc = [opc.opc + i for i in range(8 if opc.extended and not opc.opcext else 1)]
if opc.escape == 0 and opc.opc in (0xc4, 0xc5, 0x62):
assert opc.prefix is None
assert opc.opcext is None
assert opc.modreg == (None, "m")
assert opc.rexw is None
assert opc.vexl is None
# We do NOT encode /m, this is handled by prefix code.
# Order must match KIND_ORDER.
return topc, [0], None, None, None, None, None
elif opc.escape == 0:
troot, tescape, topc = topc, None, None
else:
troot = [[0x0f], [0xc4, 0xc5], [0x62]][opc.vex]
tescape = [opc.escape]
tprefix, t16, t8e, tvex = None, None, None, None
if opc.prefix == "NFx":
tprefix = [0, 1]
@@ -510,7 +529,7 @@ class Trie:
vexl = {"0": [0], "12": [1<<1, 2<<1], "2": [2<<1], "IG": [0, 1<<1, 2<<1, 3<<1]}[opc.vexl or "IG"]
tvex = list(map(sum, product(rexw, vexl)))
# Order must match KIND_ORDER.
return troot, t256, tprefix, t16, t8e, tvex
return troot, tescape, topc, tprefix, t16, t8e, tvex
def add_opcode(self, opcode, descidx, root_idx, weak=False):
opcode = self._transform_opcode(opcode)
@@ -546,6 +565,11 @@ class Trie:
elif not weak:
raise Exception(f"redundant non-weak {opcode}")
def add_prefix(self, byte, prefix, root_idx):
    """Mark `byte` as a prefix in the root table selected by `root_idx`.

    The root table is created on demand if the slot is still empty. The
    entry stored for `byte` is the pair (EntryKind.PREFIX, prefix).
    """
    if self.trie[0][root_idx] is None:
        # Allocate the table first, then publish it in the root slot.
        new_root = self._add_table(EntryKind.TABLE_ROOT)
        self.trie[0][root_idx] = (EntryKind.TABLE_ROOT, new_root)
    root_table = self.trie[0][root_idx][1]
    self.trie[root_table][byte] = (EntryKind.PREFIX, prefix)
def deduplicate(self):
synonyms = {}
for kind in self.KIND_ORDER[::-1]:
@@ -627,6 +651,21 @@ def decode_table(entries, args):
modes = args.modes
trie = Trie(root_count=len(modes))
for i, mode in enumerate(modes):
# Magic values must match PF_* enum in decode.c.
trie.add_prefix(0x66, 0xfffa, i)
trie.add_prefix(0x67, 0xfffb, i)
trie.add_prefix(0xf0, 0xfffc, i)
trie.add_prefix(0xf2, 0xfffd, i)
trie.add_prefix(0xf3, 0xfffd, i)
trie.add_prefix(0x64, 0xfff9, i)
trie.add_prefix(0x65, 0xfff9, i)
for seg in (0x26, 0x2e, 0x36, 0x3e):
trie.add_prefix(seg, 0xfff8 + (mode <= 32), i)
if mode > 32:
for rex in range(0x40, 0x50):
trie.add_prefix(rex, 0xfffe, i)
mnems, descs, desc_map = set(), [], {}
for weak, opcode, desc in entries:
ign66 = opcode.prefix in ("NP", "66", "F2", "F3")