#!/usr/bin/python3

import re
import struct
import sys
from binascii import unhexlify
from collections import OrderedDict, defaultdict
from copy import copy
from enum import Enum, IntEnum
from itertools import accumulate


def bitstruct(name, fields):
    """Create a class that packs named bit fields into one integer.

    ``fields`` is a sequence of ``"field_name:bit_width"`` strings.  The
    first field occupies the least significant bits; each following field
    is placed directly above the previous one.

    Instances accept the field values as keyword arguments (missing fields
    default to 0) and expose them as plain attributes.  ``_encode()``
    returns the packed integer, masking every value to its declared width.
    The class attribute ``_encode_size`` holds the total number of bits.
    """
    names, sizes = zip(*(field.split(":") for field in fields))
    sizes = tuple(map(int, sizes))
    # offsets[i] is the bit position of field i; offsets[-1] is the total.
    offsets = (0,) + tuple(accumulate(sizes))

    class _BitStruct:
        def __init__(self, **kwargs):
            for field_name in names:
                setattr(self, field_name, kwargs.get(field_name, 0))

        def _encode(self):
            return sum(
                (getattr(self, field_name) & ((1 << size) - 1)) << offset
                for field_name, size, offset in zip(names, sizes, offsets)
            )

    _BitStruct.__name__ = name
    _BitStruct._encode_size = offsets[-1]
    return _BitStruct


InstrFlags = bitstruct("InstrFlags", [
    "modrm_idx:2",
    "modreg_idx:2",
    "vexreg_idx:2",
    "zeroreg_idx:2",
    "operand_sizes:8",
    "imm_idx:2",
    "imm_size:2",
    "imm_control:3",
    "imm_byte:1",
    "gp_size_8:1",
    "gp_size_def64:1",
    "gp_instr_width:1",
    "gp_fixed_operand_size:3",
])
assert InstrFlags._encode_size <= 32

# Operand-encoding templates, keyed by the encoding name used in the first
# column of an instruction description.  Every *_idx field stores the
# operand index XOR 3 (presumably so that an all-zero field can stand for
# "operand not present" -- confirm against the decoder).
ENCODINGS = {
    "NP": InstrFlags(),
    "M": InstrFlags(modrm_idx=0^3),
    "M1": InstrFlags(modrm_idx=0^3, imm_idx=1^3, imm_control=1),
    "MI": InstrFlags(modrm_idx=0^3, imm_idx=1^3, imm_control=3),
    "MR": InstrFlags(modrm_idx=0^3, modreg_idx=1^3),
    "RM": InstrFlags(modrm_idx=1^3, modreg_idx=0^3),
    "RMA": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, zeroreg_idx=2^3),
    "MRI": InstrFlags(modrm_idx=0^3, modreg_idx=1^3, imm_idx=2^3,
                      imm_control=3),
    "RMI": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, imm_idx=2^3,
                      imm_control=3),
    "I": InstrFlags(imm_idx=0^3, imm_control=3),
    "IA": InstrFlags(zeroreg_idx=0^3, imm_idx=1^3, imm_control=3),
    "O": InstrFlags(modreg_idx=0^3),
    "OI": InstrFlags(modreg_idx=0^3, imm_idx=1^3, imm_control=3),
    "OA": InstrFlags(modreg_idx=0^3, zeroreg_idx=1^3),
    "AO": InstrFlags(modreg_idx=1^3, zeroreg_idx=0^3),
    "D": InstrFlags(imm_idx=0^3, imm_control=4),
    "FD": InstrFlags(zeroreg_idx=0^3, imm_idx=1^3, imm_control=2),
    "TD": InstrFlags(zeroreg_idx=1^3, imm_idx=0^3, imm_control=2),
    "RVM": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3),
    "RVMI": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3,
                       imm_idx=3^3, imm_control=3, imm_byte=1),
    "RVMR": InstrFlags(modrm_idx=2^3, modreg_idx=0^3, vexreg_idx=1^3,
                       imm_idx=3^3, imm_control=5, imm_byte=1),
    "RMV": InstrFlags(modrm_idx=1^3, modreg_idx=0^3, vexreg_idx=2^3),
    "VM": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3),
    "VMI": InstrFlags(modrm_idx=1^3, vexreg_idx=0^3, imm_idx=2^3,
                      imm_control=3, imm_byte=1),
    "MVR": InstrFlags(modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3),
}

# Operand kind -> (enc_size, fixed_size).  enc_size is the 2-bit value stored
# per operand in InstrFlags.operand_sizes; when it is 1, fixed_size is
# recorded in InstrFlags.gp_fixed_operand_size by parse_desc().
OPKIND_LOOKUP = {
    "-": (0, 0),
    "IMM": (2, 0),
    "IMM8": (1, 0),
    "IMM16": (1, 1),
    "IMM32": (1, 2),
    "GP": (2, 0),
    "GP8": (1, 0),
    "GP16": (1, 1),
    "GP32": (1, 2),
    "GP64": (1, 3),
    "XMM": (3, 0),
    "XMM8": (1, 0),
    "XMM16": (1, 1),
    "XMM32": (1, 2),
    "XMM64": (1, 3),
    "XMM128": (1, 4),
    "XMM256": (1, 5),
    "SREG": (0, 0),
    "FPU": (0, 0),
}


def parse_desc(desc, ignore_flag):
    """Parse one whitespace-separated instruction description.

    Column 0 is a key of ENCODINGS, columns 1-4 are keys of OPKIND_LOOKUP
    (one per operand), column 5 is returned unchanged as the first element
    of the result, and any further columns are flags.

    Returns None when ``ignore_flag`` is among the flags; otherwise a tuple
    ``(desc[5], encoded_flags)`` where ``encoded_flags`` is the integer
    produced by ``InstrFlags._encode()``.

    Raises KeyError for an unknown encoding or operand kind and IndexError
    when the description has fewer than six columns.
    """
    desc = desc.split()
    misc_flags = desc[6:]
    if ignore_flag in misc_flags:
        return None

    fixed_opsz = set()
    opsizes = 0
    for i, opkind in enumerate(desc[1:5]):
        enc_size, fixed_size = OPKIND_LOOKUP[opkind]
        if enc_size == 1:
            fixed_opsz.add(fixed_size)
        # Two bits per operand, operand 0 in the lowest bits.
        opsizes |= enc_size << (2 * i)

    # Copy the template so the shared ENCODINGS entry is never modified.
    flags = copy(ENCODINGS[desc[0]])
    flags.operand_sizes = opsizes
    if fixed_opsz:
        # Only one fixed size can be stored; if the operands name several
        # different ones, an arbitrary element of the set is used.
        flags.gp_fixed_operand_size = next(iter(fixed_opsz))

    # Miscellaneous Flags
    if "DEF64" in misc_flags:
        flags.gp_size_def64 = 1
    if "SIZE_8" in misc_flags:
        flags.gp_size_8 = 1
    if "INSTR_WIDTH" in misc_flags:
        flags.gp_instr_width = 1
    if "IMM_8" in misc_flags:
        flags.imm_byte = 1

    return desc[5], flags._encode()


class EntryKind(Enum):
    """Kinds of entries that can appear in the decode table."""

    NONE = 0
    INSTR = 1
    TABLE256 = 2
    TABLE8 = 3
    TABLE72 = 4
    TABLE_PREFIX = 5

    @property
    def table_length(self):
        """Number of slots of a table of this kind (0 for INSTR).

        EntryKind.NONE has no entry in the mapping, so reading this
        property on it raises KeyError.
        """
        return {
            EntryKind.INSTR: 0,
            EntryKind.TABLE256: 256,
            EntryKind.TABLE8: 8,
            EntryKind.TABLE72: 72,
            EntryKind.TABLE_PREFIX: 16,
        }[self]


# The group names are the ones parse_opcode() reads via match.group(...).
opcode_regex = re.compile(
    r"^(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3)\."
    r"(?P<rexw>W[01]\.)?(?P<vexl>L[01]\.)?)?"
    r"(?P<opcode>(?:[0-9a-f]{2})+)"
    r"(?P<modrm>//?[0-7]|//[c-f][0-9a-f])?"
    r"(?P<extended>\+)?$"
)


def parse_opcode(opcode_string):
    """
    Parse opcode string into list of type-index tuples.

    Each element of the returned list is one path through the decode
    tables, given as a tuple of ``(EntryKind, index)`` pairs:

    * every opcode byte contributes a ``TABLE256`` step;
    * ``/N`` adds a ``TABLE8`` step with index N;
    * ``//N`` (N in 0-7) or ``//XX`` (XX in c0-ff) adds a ``TABLE72`` step,
      where c0-ff is mapped to 8-71;
    * a prefix specification adds a final ``TABLE_PREFIX`` step; without an
      explicit W0/W1 both variants are returned;
    * a trailing ``+`` expands the last step into eight consecutive indices.

    Raises ValueError when the string does not match the opcode syntax.
    """
    match = opcode_regex.match(opcode_string)
    if match is None:
        raise ValueError("invalid opcode: '%s'" % opcode_string)

    extended = match.group("extended") is not None

    opcode = [(EntryKind.TABLE256, x)
              for x in unhexlify(match.group("opcode"))]

    opcext = match.group("modrm")
    if opcext:
        if opcext[1] == "/":
            # "//..." form: index into the 72-entry table.
            opcext = int(opcext[2:], 16)
            assert (0 <= opcext <= 7) or (0xc0 <= opcext <= 0xff)
            if opcext >= 0xc0:
                # Place c0..ff directly after the eight 0..7 slots.
                opcext -= 0xb8
            opcode.append((EntryKind.TABLE72, opcext))
        else:
            opcode.append((EntryKind.TABLE8, int(opcext[1:], 16)))

    if match.group("prefixes"):
        assert not extended

        # Prefix table index: bits 0-1 legacy prefix, bit 2 W, bit 3 VEX.
        legacy = {"NP": 0, "66": 1, "F3": 2, "F2": 3}[match.group("legacy")]
        entry = legacy | ((1 << 3) if match.group("vex") else 0)

        if match.group("vexl"):
            print("ignored mandatory VEX.L prefix for:", opcode_string)

        rexw = match.group("rexw")
        if not rexw:
            # W not specified: register the entry for both W=0 and W=1.
            return [
                tuple(opcode) + ((EntryKind.TABLE_PREFIX, entry),),
                tuple(opcode) + ((EntryKind.TABLE_PREFIX, entry | (1 << 2)),),
            ]

        entry |= (1 << 2) if "W1" in rexw else 0
        return [tuple(opcode) + ((EntryKind.TABLE_PREFIX, entry),)]

    if not extended:
        return [tuple(opcode)]

    last_type, last_index = opcode[-1]
    assert last_type in (EntryKind.TABLE256, EntryKind.TABLE72)
    assert last_index & 7 == 0

    common_prefix = tuple(opcode[:-1])
    return [common_prefix + ((last_type, last_index + i),) for i in range(8)]


# NOTE(review): the `Table` class that follows parse_opcode() in the original
# file is cut off in the middle of a statement in the reviewed excerpt and
# was therefore not reconstructed here.
b"\0" * (offsets[name] - len(data)) assert len(data) == offsets[name] if kind == EntryKind.INSTR: mnemonicIdx = mnemonics_lut[value[0]] data += struct.pack("