parseinstrs: Simplify operand kind parsing

This commit is contained in:
Alexis Engelke
2021-01-22 10:55:11 +01:00
parent bd611902b0
commit 62018556a1
2 changed files with 22 additions and 50 deletions

View File

@@ -4,6 +4,7 @@ import argparse
from collections import OrderedDict, defaultdict, namedtuple, Counter
from enum import Enum
from itertools import product
import re
import struct
from typing import NamedTuple, FrozenSet, List, Tuple, Union, Optional, ByteString
@@ -76,14 +77,15 @@ ENCODINGS = {
"MVR": InstrFlags(modrm=1, modrm_idx=0^3, modreg_idx=2^3, vexreg_idx=1^3),
}
# An operand-kind string is a run of upper-case letters (the kind) optionally
# followed by a decimal number (the size suffix), e.g. "GP32" or "XMM".
OPKIND_REGEX = re.compile(r"^([A-Z]+)([0-9]+)?$")
# Size used for a kind when the string carries no numeric suffix. The values
# -1 and -2 equal OpKind.SZ_OP and OpKind.SZ_VEC respectively. Kinds without
# an entry here fall back to 0 in OpKind.parse.
OPKIND_DEFAULTS = {"GP": -1, "IMM": -1, "SEG": -1, "MEM": -1, "XMM": -2, "MMX": 8, "FPU": 10}
# Kind names accepted by OpKind.parse; anything else is rejected there.
# NOTE(review): "MEM" is listed twice. This is harmless for the membership
# test in OpKind.parse, but confirm no other user depends on the positions
# before removing the duplicate.
OPKIND_KINDS = ("IMM", "MEM", "GP", "MMX", "XMM", "SEG", "FPU", "MEM", "MASK", "CR", "DR", "TMM", "BND")
class OpKind(NamedTuple):
size: int
kind: str
SZ_OP = -1
SZ_VEC = -2
K_MEM = "mem"
K_IMM = "imm"
def abssize(self, opsz=None, vecsz=None):
res = opsz if self.size == self.SZ_OP else \
@@ -91,48 +93,17 @@ class OpKind(NamedTuple):
if res is None:
raise Exception("unspecified operand size")
return res
# Lookup table from operand-kind name to its OpKind(size, kind) descriptor.
# Fixed sizes equal the numeric suffix divided by 8 (e.g. IMM8 -> 1,
# MEM512 -> 64), so they are presumably byte counts; OpKind.SZ_OP and
# OpKind.SZ_VEC are placeholder sizes rather than fixed ones.
# NOTE(review): the earlier comment here described a three-field layout
# (sizeidx, fixedsz, regtype) that does not match the two-argument
# OpKind(...) entries below.
OPKINDS = {
    # name: OpKind(size, kind)
"IMM": OpKind(OpKind.SZ_OP, OpKind.K_IMM),
"IMM8": OpKind(1, OpKind.K_IMM),
"IMM16": OpKind(2, OpKind.K_IMM),
"IMM32": OpKind(4, OpKind.K_IMM),
"IMM64": OpKind(8, OpKind.K_IMM),
"GP": OpKind(OpKind.SZ_OP, "GP"),
"GP8": OpKind(1, "GP"),
"GP16": OpKind(2, "GP"),
"GP32": OpKind(4, "GP"),
"GP64": OpKind(8, "GP"),
"MMX": OpKind(8, "MMX"),
"XMM": OpKind(OpKind.SZ_VEC, "XMM"),
"XMM8": OpKind(1, "XMM"),
"XMM16": OpKind(2, "XMM"),
"XMM32": OpKind(4, "XMM"),
"XMM64": OpKind(8, "XMM"),
"XMM128": OpKind(16, "XMM"),
"XMM256": OpKind(32, "XMM"),
"SEG": OpKind(OpKind.SZ_OP, "SEG"),
"SEG16": OpKind(2, "SEG"),
"FPU": OpKind(10, "FPU"),
"MEM": OpKind(OpKind.SZ_OP, OpKind.K_MEM),
"MEMV": OpKind(OpKind.SZ_VEC, OpKind.K_MEM),
"MEMZ": OpKind(0, OpKind.K_MEM),
"MEM8": OpKind(1, OpKind.K_MEM),
"MEM16": OpKind(2, OpKind.K_MEM),
"MEM32": OpKind(4, OpKind.K_MEM),
"MEM64": OpKind(8, OpKind.K_MEM),
"MEM128": OpKind(16, OpKind.K_MEM),
"MEM256": OpKind(32, OpKind.K_MEM),
"MEM512": OpKind(64, OpKind.K_MEM),
"MASK8": OpKind(1, "MASK"),
"MASK16": OpKind(2, "MASK"),
"MASK32": OpKind(4, "MASK"),
"MASK64": OpKind(8, "MASK"),
"BND": OpKind(0, "BND"),
"CR": OpKind(0, "CR"),
"DR": OpKind(0, "DR"),
}
@classmethod
def parse(cls, op):
    """Build an operand kind from its textual name, e.g. "GP32" or "XMM".

    The name must match OPKIND_REGEX: upper-case letters (the kind)
    optionally followed by a decimal size suffix. Two aliases are rewritten
    before matching: "MEMZ" is treated as "MEM0" and "MEMV" as "XMM" (so
    "MEMV" yields kind "XMM" with that kind's default size).

    The resulting size is the suffix integer-divided by 8 when a suffix is
    present; otherwise it is the kind's entry in OPKIND_DEFAULTS, or 0 when
    the kind has no entry there.

    Raises:
        Exception: if the name does not match OPKIND_REGEX, or if the kind
            is not listed in OPKIND_KINDS.
    """
    # Rewrite the spellings that do not follow the KIND[BITS] pattern.
    op = {"MEMZ": "MEM0", "MEMV": "XMM"}.get(op, op)
    parsed = OPKIND_REGEX.match(op)
    if parsed is None:
        raise Exception(f"invalid opkind str: {op}")
    kind = parsed.group(1)
    suffix = parsed.group(2)
    if suffix:
        size = int(suffix) // 8
    else:
        size = OPKIND_DEFAULTS.get(kind, 0)
    # The kind is validated only after the size has been derived, so an
    # unknown kind is reported with the (possibly rewritten) full name.
    if kind not in OPKIND_KINDS:
        raise Exception(f"invalid opkind kind: {op}")
    return cls(size, kind)
class InstrDesc(NamedTuple):
mnemonic: str
@@ -150,7 +121,7 @@ class InstrDesc(NamedTuple):
# Build an instruction description from one whitespace-separated line.
# Field usage as visible here: desc[1:5] are operand-kind names (with "-"
# entries skipped), desc[5] and desc[0] are passed as the first two
# constructor arguments (the first field of this class is `mnemonic`), and
# every remaining field is collected into a frozenset as the last argument.
@classmethod
def parse(cls, desc):
desc = desc.split()
# NOTE(review): this table-based lookup is immediately overwritten by the
# assignment on the next line, which parses each name with OpKind.parse;
# it appears to be the superseded form of the same statement.
operands = tuple(OPKINDS[op] for op in desc[1:5] if op != "-")
operands = tuple(OpKind.parse(op) for op in desc[1:5] if op != "-")
return cls(desc[5], desc[0], operands, frozenset(desc[6:]))
def encode(self, ign66, modrm):
@@ -189,7 +160,7 @@ class InstrDesc(NamedTuple):
extraflags["ign66"] = 1
if flags.imm_control >= 4:
imm_op = next(op for op in self.operands if op.kind == OpKind.K_IMM)
imm_op = self.operands[flags.imm_idx^3]
if ("IMM_8" in self.flags or imm_op.size == 1 or
(imm_op.size == OpKind.SZ_OP and "SIZE_8" in self.flags)):
extraflags["imm_control"] = flags.imm_control | 1
@@ -229,7 +200,6 @@ class TrieEntry(NamedTuple):
def instr(cls, descidx):
return cls(EntryKind.INSTR, (), descidx)
import re
opcode_regex = re.compile(
r"^(?:(?P<prefixes>(?P<vex>VEX\.)?(?P<legacy>NP|66|F2|F3|NFx)\." +
r"(?:W(?P<rexw>[01]|IG)\.)?(?:L(?P<vexl>[01]|IG)\.)?))?" +
@@ -583,6 +553,8 @@ def encode_table(entries):
for ot, op in zip(ots, desc.operands):
if ot == "m":
tys.append(0xf)
elif ot in "io":
tys.append(0)
elif op.kind == "GP":
if (desc.mnemonic == "MOVSX" or desc.mnemonic == "MOVZX" or
opsize == 8):
@@ -591,9 +563,9 @@ def encode_table(entries):
tys.append(1)
else:
tys.append({
"imm": 0, "SEG": 3, "FPU": 4, "MMX": 5, "XMM": 6,
"SEG": 3, "FPU": 4, "MMX": 5, "XMM": 6,
"BND": 8, "CR": 9, "DR": 10,
}.get(op.kind, -1))
}[op.kind])
tys_i = sum(ty << (4*i) for i, ty in enumerate(tys))
opc_s = hex(opc_i) + opc_flags + prefix[1]