encode2: Add new encoder API, one func per instr.

This is an *experimental* (read: unstable) API which exposes encoding
functionality as one function per instruction. This makes the encoding
process itself significantly faster, at the cost of a much larger binary
size (~1 MiB of code, no data) and much higher compilation time.
This commit is contained in:
Alexis Engelke
2022-02-20 17:21:04 +01:00
parent 003c7ca750
commit 6b8c2968c1
7 changed files with 646 additions and 6 deletions

View File

@@ -687,10 +687,293 @@ def encode_table(entries, args):
mnem_tab = "".join(f"FE_MNEMONIC({m},{i})\n" for i, m in enumerate(mnem_list))
return mnem_tab, descs
def encode2_table(entries, args):
    """Generate the C source for the one-function-per-instruction encoder API.

    Works in two phases:

    1. Every entry is expanded into all concrete combinations of operand
       size, vector size, legacy prefix and operand-type string it supports.
       Each combination yields a function name; opcodes that map to the same
       ``(name, opsize, operand types)`` key are collected as alternative
       encodings ("variants") of one function.
    2. For every key, one C function ``fe64_<name>`` is emitted. Its body
       tries the variants in priority order; a variant that cannot encode
       the given operands jumps to the label of the next one, and the final
       fall-through label returns 0.

    Args:
        entries: iterable of ``(weak, opcode, desc)`` triples. ``weak`` is not
            used here. ``opcode`` must provide ``prefix``, ``escape``, ``opc``,
            ``opcext``, ``modreg``, ``vex``, ``vexl``, ``rexw`` and
            ``_replace``; ``desc`` must provide ``flags``, ``mnemonic``,
            ``operands``, ``encoding``, ``optype_str()`` and ``imm_size()``.
        args: not used by this generator.

    Returns:
        Tuple ``(enc_decls, enc_code)``: the C declarations (prototypes and
        wrapper macros) and the C function definitions, each as one string.

    Raises:
        Exception: if the variants collected under one function name do not
            all have the same operand kinds.
    """
    # Preferred order of encodings among variants with the same immediate
    # size; encodings not listed here sort last.
    PRIO = ["O", "OA", "AO", "AM", "MA", "IA", "OI"]
    # Operand kinds whose C register type name differs from the kind name.
    OPKIND_LUT = {"FPU": "ST", "SEG": "SREG", "MMX": "MM"}

    mnemonics = defaultdict(list)
    for weak, opcode, desc in entries:
        if "I64" in desc.flags or desc.mnemonic[:9] == "RESERVED_":
            continue

        # Determine which operand/vector sizes this entry can be encoded for.
        # A size of 0 means "size is not part of the function name".
        opsizes = {8} if "SZ8" in desc.flags else {16, 32, 64}
        hasvex, vecsizes = False, {128}
        if opcode.vex:
            hasvex, vecsizes = True, {128, 256}
        if opcode.prefix in ("66", "F2", "F3") and "U66" not in desc.flags:
            opsizes -= {16}
        if opcode.vexl == "IG":
            vecsizes = {0}
        elif opcode.vexl:
            vecsizes -= {128 if opcode.vexl == "1" else 256}
        if opcode.rexw == "IG":
            opsizes = {0}
        elif opcode.rexw:
            opsizes -= {32 if opcode.rexw == "1" else 64}
        if "I66" in desc.flags:
            opsizes -= {16}
        if "D64" in desc.flags:
            opsizes -= {32}
        if "SZ8" not in desc.flags and "INSTR_WIDTH" not in desc.flags and all(op.size != OpKind.SZ_OP for op in desc.operands):
            opsizes = {0}
        if "VSIB" not in desc.flags and all(op.size != OpKind.SZ_VEC for op in desc.operands):
            vecsizes = {0} # for VEX-encoded general-purpose instructions.
        if "ENC_NOSZ" in desc.flags:
            opsizes, vecsizes = {0}, {0}

        # Where to put the operand size in the mnemonic
        separate_opsize = "ENC_SEPSZ" in desc.flags
        prepend_opsize = max(opsizes) > 0 and not separate_opsize
        prepend_vecsize = hasvex and max(vecsizes) > 0 and not separate_opsize

        if "F64" in desc.flags:
            opsizes = {64}
            prepend_opsize = False

        # Expand the "M" placeholder of the operand-type string into every
        # ModRM kind permitted by the opcode (both "r" and "m" by default).
        modrm_type = opcode.modreg[1] if opcode.modreg else "rm"
        optypes_base = desc.optype_str()
        optypes = {optypes_base.replace("M", t) for t in modrm_type}

        # (function name prefix, opcode prefix) pairs to generate.
        prefixes = [("", "")]
        if "LOCK" in desc.flags:
            prefixes.append(("LOCK_", "LOCK"))
        if "ENC_REP" in desc.flags:
            prefixes.append(("REP_", "F3"))
        if "ENC_REPCC" in desc.flags:
            prefixes.append(("REPNZ_", "F2"))
            prefixes.append(("REPZ_", "F3"))

        for opsize, vecsize, prefix, ots in product(opsizes, vecsizes, prefixes, optypes):
            # LOCK variants are only generated when the first operand is memory.
            if prefix[1] == "LOCK" and ots[0] != "m":
                continue

            spec_opcode = opcode
            if prefix[1]:
                spec_opcode = spec_opcode._replace(prefix=prefix[1])
            if opsize == 64 and "D64" not in desc.flags and "F64" not in desc.flags:
                spec_opcode = spec_opcode._replace(rexw="1")
            if vecsize == 256:
                spec_opcode = spec_opcode._replace(vexl="1")

            # Construct mnemonic name
            mnem_name = {"MOVABS": "MOV", "XCHG_NOP": "XCHG"}.get(desc.mnemonic, desc.mnemonic)
            name = prefix[0] + mnem_name
            # The bool used as slice start drops the leading "_" (index 1)
            # unless the name already ends in a digit, where the separator
            # is kept.
            if prepend_opsize and not ("D64" in desc.flags and opsize == 64):
                name += f"_{opsize}"[name[-1] not in "0123456789":]
            if prepend_vecsize:
                name += f"_{vecsize}"[name[-1] not in "0123456789":]
            for ot, op in zip(ots, desc.operands):
                name += ot.replace("o", "")
                if separate_opsize:
                    name += f"{op.abssize(opsize//8, vecsize//8)*8}"
            mnemonics[name, opsize, ots].append((spec_opcode, desc))

    enc_decls, enc_code = "", ""
    for (mnem, opsize, ots), variants in mnemonics.items():
        # Keep one variant per (immediate size, encoding priority); "S"
        # encodings are never merged because each gets its own index.
        dedup = OrderedDict()
        for i, (opcode, desc) in enumerate(variants):
            enc_prio = PRIO.index(desc.encoding) if desc.encoding in PRIO else len(PRIO)
            unique = 0 if desc.encoding != "S" else i
            key = desc.imm_size(opsize//8), enc_prio, unique
            if key not in dedup:
                dedup[key] = opcode, desc
        variants = [dedup[k] for k in sorted(dedup)]
        max_imm_size = max(k[0] for k in dedup)

        # Indices of one-byte GP register operands; these get the GPLH
        # register type and a wrapper macro.
        supports_high_regs = []
        if variants[0][1].mnemonic in ("MOVSX", "MOVZX") or opsize == 8:
            # Should be the same for all variants
            desc = variants[0][1]
            for i, (ot, op) in enumerate(zip(ots, desc.operands)):
                if ot == "r" and op.kind == "GP" and op.abssize(opsize//8) == 1:
                    supports_high_regs.append(i)
        supports_vsib = "VSIB" in variants[0][1].flags
        if len({tuple(op.kind for op in v[1].operands) for v in variants}) > 1:
            raise Exception(f"ambiguous operand kinds for {mnem}")

        reg_tys = [OPKIND_LUT.get(op.kind, op.kind) for op in variants[0][1].operands]

        fnname = f"fe64_{mnem}{'_impl' if supports_high_regs else ''}"
        op_tys = [{
            "i": f"int{max_imm_size*8 if max_imm_size != 3 else 32}_t",
            "a": "uintptr_t",
            "r": f"FeReg{reg_ty if i not in supports_high_regs else 'GPLH'}",
            "m": "FeMem" if not supports_vsib else "FeMemV",
            "o": "const void*",
        }[ot] for i, (ot, reg_ty) in enumerate(zip(ots, reg_tys))]
        fn_opargs = "".join(f", {ty} op{i}" for i, ty in enumerate(op_tys))
        fn_sig = f"unsigned {fnname}(uint8_t* buf, int flags{fn_opargs})"
        enc_decls += f"{fn_sig};\n"
        if supports_high_regs:
            # Public name is a macro that wraps the affected operands in
            # FE_MAKE_GPLH before calling the _impl function.
            enc_decls += f"#define fe64_{mnem}(buf, flags"
            enc_decls += "".join(f", op{i}" for i in range(len(op_tys)))
            enc_decls += f") {fnname}(buf, flags"
            enc_decls += "".join(f", FE_MAKE_GPLH(op{i})" if i in supports_high_regs else f", op{i}" for i in range(len(op_tys)))
            enc_decls += ")\n"

        code = f"{fn_sig} {{\n"
        code += " unsigned idx = 0, rex = 0, memoff;\n"
        if max_imm_size or "a" in ots:
            code += " int64_t imm; unsigned imm_size;\n"
        code += " (void) flags; (void) memoff;\n"

        # neednext tracks whether the variant emitted last contains a
        # "goto next<i>"; if it does not, later variants are unreachable.
        neednext = True
        for i, (opcode, desc) in enumerate(variants):
            if not neednext:
                break
            if i > 0:
                code += f"\nnext{i-1}:\n"
            neednext = False

            imm_size = desc.imm_size(opsize//8)
            flags = ENCODINGS[desc.encoding]
            # Select usable encoding.
            if desc.encoding == "S":
                # Segment encoding is weird.
                code += f" if (op_reg_idx(op0)!={(opcode.opc>>3)&0x7:#x}) goto next{i};\n"
                neednext = True
            if desc.mnemonic == "XCHG_NOP" and opsize == 32:
                # XCHG eax, eax must not be encoded as 90 -- that'd be NOP.
                code += f" if (op_reg_idx(op0)==0&&op_reg_idx(op1)==0) goto next{i};\n"
                neednext = True
            if flags.zeroreg_idx:
                code += f" if (op_reg_idx(op{flags.zeroreg_idx^3})!={flags.zeroreg_val}) goto next{i};\n"
                neednext = True
            if flags.imm_control:
                if flags.imm_control != 3:
                    code += f" imm = (int64_t) op{flags.imm_idx^3};\n"
                else:
                    code += f" imm = op_reg_idx(op{flags.imm_idx^3}) << 4;\n"
                code += f" imm_size = {imm_size};\n"
                if flags.imm_control == 1:
                    code += f" if (imm != 1) goto next{i};\n"
                    neednext = True
                if flags.imm_control == 2:
                    code += " imm_size = flags & FE_ADDR32 ? 4 : 8;\n"
                    code += " if (imm_size == 4) imm = (int32_t) imm;\n"
                if imm_size < max_imm_size and 2 <= flags.imm_control < 6:
                    code += f" if (!op_imm_n(imm, imm_size)) goto next{i};\n"
                    neednext = True
                if flags.imm_control == 6:
                    # idx is subtracted below.
                    code += " imm -= (int64_t) buf + imm_size;\n"
                    if i != len(variants) - 1: # only Jcc+JMP
                        code += f" if (flags&FE_JMPL) goto next{i};\n"
                        # assume one-byte opcode without escape/prefixes
                        code += f" if (!op_imm_n(imm-1, imm_size)) goto next{i};\n"
                        neednext = True

            # Bit masks OR-ed into the C variable "rex"; VEX-encoded
            # instructions use a different bit layout than legacy ones.
            if opcode.vex:
                rexw, rexr, rexx, rexb = 0x8000, 0x80, 0x40, 0x20
            else:
                rexw, rexr, rexx, rexb = 0x48, 0x44, 0x42, 0x41

            if not opcode.vex:
                for hi in supports_high_regs:
                    code += f" if (op_reg_idx(op{hi}) >= 4 && op_reg_idx(op{hi}) <= 15) rex = 0x40;\n"
            if opcode.rexw == "1":
                code += f" rex |= {rexw:#x};\n"
            if flags.modrm_idx:
                ismem = ots[flags.modrm_idx^3] == "m"
                if ismem:
                    code += f" if (op_mem_base(op{flags.modrm_idx^3})&8) rex |= {rexb:#x};\n"
                    code += f" if (op_mem_idx(op{flags.modrm_idx^3})&8) rex |= {rexx:#x};\n"
                else:
                    if desc.operands[flags.modrm_idx^3].kind in ("GP", "XMM"):
                        code += f" if (op_reg_idx(op{flags.modrm_idx^3})&8) rex |= {rexb:#x};\n"
                if flags.modreg_idx:
                    if desc.operands[flags.modreg_idx^3].kind in ("GP", "XMM", "CR", "DR"):
                        code += f" if (op_reg_idx(op{flags.modreg_idx^3})&8) rex |= {rexr:#x};\n"
            elif flags.modreg_idx: # O encoding
                if desc.operands[flags.modreg_idx^3].kind in ("GP", "XMM"):
                    code += f" if (op_reg_idx(op{flags.modreg_idx^3})&8) rex |= {rexb:#x};\n"
            for hi in supports_high_regs:
                code += f" if (rex && op_reg_gph(op{hi})) return 0;\n"

            if "m" in ots or "USEG" in desc.flags:
                code += " if (UNLIKELY(flags & FE_SEG_MASK)) buf[idx++] = enc_seg(flags);\n"
            if "m" in ots or "U67" in desc.flags:
                code += " if (UNLIKELY(flags & FE_ADDR32)) buf[idx++] = 0x67;\n"
            if opcode.vex:
                ppl = ["NP", "66", "F3", "F2"].index(opcode.prefix)
                ppl |= 4 if opcode.vexl == "1" else 0
                # The short 0xc5 prefix is emitted only when none of the
                # bits in 0x8060 are set in rex; otherwise the 0xc4 form.
                mayvex2 = opcode.rexw != "1" and opcode.escape == 1
                if mayvex2:
                    code += " if (!(rex&0x8060)) {\n"
                    code += " buf[idx++] = 0xc5;\n"
                    code += " rex ^= 0x80;\n"
                    code += " } else {\n"
                code += " buf[idx++] = 0xc4;\n"
                code += f" buf[idx++] = {0xe0+opcode.escape:#x}^rex;\n"
                code += " rex >>= 8;\n"
                if mayvex2:
                    code += " }\n"
                vexop = 0
                if flags.vexreg_idx:
                    vexop = f"op_reg_idx(op{flags.vexreg_idx^3})"
                code += f" buf[idx++] = {ppl}|rex|(({vexop}^15)<<3);\n"
            else:
                if opsize == 16 or opcode.prefix == "66":
                    code += " buf[idx++] = 0x66;\n"
                if opcode.prefix in ("F2", "F3"):
                    code += f" buf[idx++] = 0x{opcode.prefix};\n"
                if opcode.prefix == "LOCK":
                    code += " buf[idx++] = 0xF0;\n"
                code += " if (rex) buf[idx++] = rex;\n"
                if opcode.escape:
                    code += " buf[idx++] = 0x0F;\n"
                    if opcode.escape == 2:
                        code += " buf[idx++] = 0x38;\n"
                    elif opcode.escape == 3:
                        code += " buf[idx++] = 0x3A;\n"
            code += f" buf[idx++] = {opcode.opc:#x};\n"
            if opcode.opcext:
                code += f" buf[idx++] = {opcode.opcext:#x};\n"

            if flags.modrm:
                modrm = f"op{flags.modrm_idx^3}"
                if flags.modreg_idx:
                    modreg = f"op_reg_idx(op{flags.modreg_idx^3})"
                else:
                    modreg = int(opcode.modreg[0]) if opcode.modreg else 0
                # NOTE(review): ismem is only assigned in the modrm_idx
                # branch above; this relies on flags.modrm implying a
                # non-zero flags.modrm_idx -- confirm against ENCODINGS.
                if ismem:
                    imm_size_expr = "imm_size" if flags.imm_control >= 2 else 0
                    memfn = "enc_mem_vsib" if "VSIB" in desc.flags else "enc_mem"
                    code += f" memoff = {memfn}(buf, idx, {modrm}, {modreg}, {imm_size_expr}, 0);\n"
                    code += " if (!memoff) return 0;\n idx += memoff;\n"
                else:
                    modrm = f"op_reg_idx({modrm})"
                    code += f" buf[idx++] = 0xC0|(({modreg}&7)<<3)|({modrm}&7);\n"
            elif flags.modrm_idx:
                # No ModRM byte: the register number goes into the low
                # three bits of the byte written last.
                code += f" buf[idx-1] |= op_reg_idx(op{flags.modrm_idx^3}) & 7;\n"
            if flags.imm_control >= 2:
                if flags.imm_control == 6:
                    code += " imm -= idx;\n"
                code += " if (enc_imm(buf+idx, imm, imm_size)) return 0;\n"
                code += " idx += imm_size;\n"
            code += " return idx;\n"

        if neednext:
            # Target of the last variant's "goto": nothing matched.
            code += f"next{len(variants)-1}: return 0;\n"
        code += "}\n"
        enc_code += code

    return enc_decls, enc_code
if __name__ == "__main__":
generators = {
"decode": decode_table,
"encode": encode_table,
"encode2": encode2_table,
}
parser = argparse.ArgumentParser()