reccmp: Sanitize performance (and more) (#654)

This commit is contained in:
MS 2024-03-10 14:49:45 -04:00 committed by GitHub
parent c28c2aeb52
commit 5eb74c06fd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 172 additions and 124 deletions

View file

@ -0,0 +1,27 @@
# Duplicates removed, according to the mnemonics capstone uses.
# e.g. je and jz are the same instruction. capstone uses je.
# See: /arch/X86/X86GenAsmWriter.inc in the capstone repo.
JUMP_MNEMONICS = {
"ja",
"jae",
"jb",
"jbe",
"jcxz", # unused?
"je",
"jecxz",
"jg",
"jge",
"jl",
"jle",
"jmp",
"jne",
"jno",
"jnp",
"jns",
"jo",
"jp",
"js",
}
# Guaranteed to be a single operand.
SINGLE_OPERAND_INSTS = {"push", "call", *JUMP_MNEMONICS}

View file

@ -12,10 +12,15 @@
from collections import namedtuple from collections import namedtuple
from isledecomp.bin import InvalidVirtualAddressError from isledecomp.bin import InvalidVirtualAddressError
from capstone import Cs, CS_ARCH_X86, CS_MODE_32 from capstone import Cs, CS_ARCH_X86, CS_MODE_32
from .const import JUMP_MNEMONICS, SINGLE_OPERAND_INSTS
disassembler = Cs(CS_ARCH_X86, CS_MODE_32) disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
ptr_replace_regex = re.compile(r"(?P<data_size>\w+) ptr \[(?P<addr>0x[0-9a-fA-F]+)\]") ptr_replace_regex = re.compile(r"\[(0x[0-9a-f]+)\]")
# For matching an immediate value on its own.
# Preceded by start-of-string (first operand) or comma-space (second operand)
immediate_replace_regex = re.compile(r"(?:^|, )(0x[0-9a-f]+)")
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str") DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
@ -30,10 +35,6 @@ def from_hex(string: str) -> Optional[int]:
return None return None
def get_float_size(size_str: str) -> int:
return 8 if size_str == "qword" else 4
class ParseAsm: class ParseAsm:
def __init__( def __init__(
self, self,
@ -94,14 +95,41 @@ def replace(self, addr: int) -> str:
self.replacements[addr] = placeholder self.replacements[addr] = placeholder
return placeholder return placeholder
def hex_replace_always(self, match: re.Match) -> str:
"""If a pointer value was matched, always insert a placeholder"""
value = int(match.group(1), 16)
return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_relocated(self, match: re.Match) -> str:
"""For replacing immediate value operands. We only want to
use the placeholder if we are certain that this is a valid address.
We can check the relocation table to find out."""
value = int(match.group(1), 16)
if self.is_relocated(value):
return match.group(0).replace(match.group(1), self.replace(value))
return match.group(0)
def hex_replace_float(self, match: re.Match) -> str:
"""Special case for replacements on float instructions.
If the pointer is a float constant, read it from the binary."""
value = int(match.group(1), 16)
# If we can find a variable name for this pointer, use it.
placeholder = self.lookup(value)
# Read what's under the pointer and show the decimal value.
if placeholder is None:
float_size = 8 if "qword" in match.string else 4
placeholder = self.float_replace(value, float_size)
# If we can't read the float, use a regular placeholder.
if placeholder is None:
placeholder = self.replace(value)
return match.group(0).replace(match.group(1), placeholder)
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]: def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
if len(inst.op_str) == 0:
# Nothing to sanitize
return (inst.mnemonic, "")
if "0x" not in inst.op_str:
return (inst.mnemonic, inst.op_str)
# For jumps or calls, if the entire op_str is a hex number, the value # For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset. # is a relative offset.
# Otherwise (i.e. it looks like `dword ptr [address]`) it is an # Otherwise (i.e. it looks like `dword ptr [address]`) it is an
@ -109,12 +137,21 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# Providing the starting address of the function to capstone.disasm has # Providing the starting address of the function to capstone.disasm has
# automatically resolved relative offsets to an absolute address. # automatically resolved relative offsets to an absolute address.
# We will have to undo this for some of the jumps or they will not match. # We will have to undo this for some of the jumps or they will not match.
op_str_address = from_hex(inst.op_str)
if op_str_address is not None: if (
inst.mnemonic in SINGLE_OPERAND_INSTS
and (op_str_address := from_hex(inst.op_str)) is not None
):
if inst.mnemonic == "call": if inst.mnemonic == "call":
return (inst.mnemonic, self.replace(op_str_address)) return (inst.mnemonic, self.replace(op_str_address))
if inst.mnemonic == "push":
if self.is_relocated(op_str_address):
return (inst.mnemonic, self.replace(op_str_address))
# To avoid falling into jump handling
return (inst.mnemonic, inst.op_str)
if inst.mnemonic == "jmp": if inst.mnemonic == "jmp":
# The unwind section contains JMPs to other functions. # The unwind section contains JMPs to other functions.
# If we have a name for this address, use it. If not, # If we have a name for this address, use it. If not,
@ -124,70 +161,19 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
if potential_name is not None: if potential_name is not None:
return (inst.mnemonic, potential_name) return (inst.mnemonic, potential_name)
if inst.mnemonic.startswith("j"): # Else: this is any jump
# i.e. if this is any jump # Show the jump offset rather than the absolute address
# Show the jump offset rather than the absolute address jump_displacement = op_str_address - (inst.address + inst.size)
jump_displacement = op_str_address - (inst.address + inst.size) return (inst.mnemonic, hex(jump_displacement))
return (inst.mnemonic, hex(jump_displacement))
def filter_out_ptr(match):
"""Helper for re.sub, see below"""
offset = from_hex(match.group("addr"))
if offset is not None:
# We assume this is always an address to replace
placeholder = self.replace(offset)
return f'{match.group("data_size")} ptr [{placeholder}]'
# Strict regex should ensure we can read the hex number.
# But just in case: return the string with no changes
return match.group(0)
def float_ptr_replace(match):
offset = from_hex(match.group("addr"))
if offset is not None:
# If we can find a variable name for this pointer, use it.
placeholder = self.lookup(offset)
# Read what's under the pointer and show the decimal value.
if placeholder is None:
placeholder = self.float_replace(
offset, get_float_size(match.group("data_size"))
)
# If we can't read the float, use a regular placeholder.
if placeholder is None:
placeholder = self.replace(offset)
return f'{match.group("data_size")} ptr [{placeholder}]'
# Strict regex should ensure we can read the hex number.
# But just in case: return the string with no changes
return match.group(0)
if inst.mnemonic.startswith("f"): if inst.mnemonic.startswith("f"):
# If floating point instruction # If floating point instruction
op_str = ptr_replace_regex.sub(float_ptr_replace, inst.op_str) op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
else: else:
op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str) op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
def replace_immediate(chunk: str) -> str: op_str = immediate_replace_regex.sub(self.hex_replace_relocated, op_str)
if (inttest := from_hex(chunk)) is not None: return (inst.mnemonic, op_str)
# If this value is a virtual address, it is referenced absolutely,
# which means it must be in the relocation table.
if self.is_relocated(inttest):
return self.replace(inttest)
return chunk
# Performance hack:
# Skip this step if there is nothing left to consider replacing.
if "0x" in op_str:
# Replace immediate values with name or placeholder (where appropriate)
op_str = ", ".join(map(replace_immediate, op_str.split(", ")))
return inst.mnemonic, op_str
def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]: def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]:
asm = [] asm = []
@ -196,7 +182,22 @@ def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]:
# Use heuristics to disregard some differences that aren't representative # Use heuristics to disregard some differences that aren't representative
# of the accuracy of a function (e.g. global offsets) # of the accuracy of a function (e.g. global offsets)
inst = DisasmLiteInst(*raw_inst) inst = DisasmLiteInst(*raw_inst)
result = self.sanitize(inst)
# If there is no pointer or immediate value in the op_str,
# there is nothing to sanitize.
# This leaves us with cases where a small immediate value or
# small displacement (this.member or vtable calls) appears.
# If we assume that instructions we want to sanitize need to be 5
# bytes -- 1 for the opcode and 4 for the address -- exclude cases
# where the hex value could not be an address.
# The exception is jumps which are as small as 2 bytes
# but are still useful to sanitize.
if "0x" in inst.op_str and (
inst.mnemonic in JUMP_MNEMONICS or inst.size > 4
):
result = self.sanitize(inst)
else:
result = (inst.mnemonic, inst.op_str)
# mnemonic + " " + op_str # mnemonic + " " + op_str
asm.append((hex(inst.address), " ".join(result))) asm.append((hex(inst.address), " ".join(result)))

View file

@ -158,9 +158,9 @@ def _load_cvdump(self):
addr, sym.node_type, sym.name(), sym.decorated_name, sym.size() addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
) )
for lineref in cv.lines: for (section, offset), (filename, line_no) in res.verified_lines.items():
addr = self.recomp_bin.get_abs_addr(lineref.section, lineref.offset) addr = self.recomp_bin.get_abs_addr(section, offset)
self._lines_db.add_line(lineref.filename, lineref.line_no, addr) self._lines_db.add_line(filename, line_no, addr)
# The _entry symbol is referenced in the PE header so we get this match for free. # The _entry symbol is referenced in the PE header so we get this match for free.
self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry) self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)

View file

@ -1,5 +1,5 @@
"""For collating the results from parsing cvdump.exe into a more directly useful format.""" """For collating the results from parsing cvdump.exe into a more directly useful format."""
from typing import List, Optional from typing import Dict, List, Tuple, Optional
from isledecomp.types import SymbolType from isledecomp.types import SymbolType
from .parser import CvdumpParser from .parser import CvdumpParser
from .demangler import demangle_string_const, demangle_vtable from .demangler import demangle_string_const, demangle_vtable
@ -81,6 +81,7 @@ class CvdumpAnalysis:
These can then be analyzed by a downstream tool.""" These can then be analyzed by a downstream tool."""
nodes = List[CvdumpNode] nodes = List[CvdumpNode]
verified_lines = Dict[Tuple[str, str], Tuple[str, str]]
def __init__(self, parser: CvdumpParser): def __init__(self, parser: CvdumpParser):
"""Read in as much information as we have from the parser. """Read in as much information as we have from the parser.
@ -126,13 +127,21 @@ def __init__(self, parser: CvdumpParser):
# No big deal if we don't have complete type information. # No big deal if we don't have complete type information.
pass pass
for lin in parser.lines: for key, _ in parser.lines.items():
key = (lin.section, lin.offset)
# Here we only set if the section:offset already exists # Here we only set if the section:offset already exists
# because our values include offsets inside of the function. # because our values include offsets inside of the function.
if key in node_dict: if key in node_dict:
node_dict[key].node_type = SymbolType.FUNCTION node_dict[key].node_type = SymbolType.FUNCTION
# The LINES section contains every code line in the file, naturally.
# There isn't an obvious separation between functions, so we have to
# read everything. However, any function that would be in LINES
# has to be somewhere else in the PDB (probably PUBLICS).
# Isolate the lines that we actually care about for matching.
self.verified_lines = {
key: value for (key, value) in parser.lines.items() if key in node_dict
}
for sym in parser.symbols: for sym in parser.symbols:
key = (sym.section, sym.offset) key = (sym.section, sym.offset)
if key not in node_dict: if key not in node_dict:

View file

@ -4,7 +4,7 @@
from .types import CvdumpTypesParser from .types import CvdumpTypesParser
# e.g. `*** PUBLICS` # e.g. `*** PUBLICS`
_section_change_regex = re.compile(r"^\*\*\* (?P<section>[A-Z/ ]+)$") _section_change_regex = re.compile(r"\*\*\* (?P<section>[A-Z/ ]{2,})")
# e.g. ` 27 00034EC0 28 00034EE2 29 00034EE7 30 00034EF4` # e.g. ` 27 00034EC0 28 00034EE2 29 00034EE7 30 00034EF4`
_line_addr_pairs_findall = re.compile(r"\s+(?P<line_no>\d+) (?P<addr>[A-F0-9]{8})") _line_addr_pairs_findall = re.compile(r"\s+(?P<line_no>\d+) (?P<addr>[A-F0-9]{8})")
@ -70,7 +70,7 @@ def __init__(self) -> None:
self._section: str = "" self._section: str = ""
self._lines_function: Tuple[str, int] = ("", 0) self._lines_function: Tuple[str, int] = ("", 0)
self.lines = [] self.lines = {}
self.publics = [] self.publics = []
self.symbols = [] self.symbols = []
self.sizerefs = [] self.sizerefs = []
@ -95,14 +95,8 @@ def _lines_section(self, line: str):
# Match any pairs as we find them # Match any pairs as we find them
for line_no, offset in _line_addr_pairs_findall.findall(line): for line_no, offset in _line_addr_pairs_findall.findall(line):
self.lines.append( key = (self._lines_function[1], int(offset, 16))
LinesEntry( self.lines[key] = (self._lines_function[0], int(line_no))
filename=self._lines_function[0],
line_no=int(line_no),
section=self._lines_function[1],
offset=int(offset, 16),
)
)
def _publics_section(self, line: str): def _publics_section(self, line: str):
"""Match each line from PUBLICS and pull out the symbol information. """Match each line from PUBLICS and pull out the symbol information.
@ -175,23 +169,22 @@ def _modules_section(self, line: str):
) )
def read_line(self, line: str): def read_line(self, line: str):
# Blank lines are there to help the reader; they have no context significance
if line.strip() == "":
return
if (match := _section_change_regex.match(line)) is not None: if (match := _section_change_regex.match(line)) is not None:
self._section = match.group(1) self._section = match.group(1)
return return
if self._section == "LINES": if self._section == "TYPES":
self.types.read_line(line)
elif self._section == "SYMBOLS":
self._symbols_section(line)
elif self._section == "LINES":
self._lines_section(line) self._lines_section(line)
elif self._section == "PUBLICS": elif self._section == "PUBLICS":
self._publics_section(line) self._publics_section(line)
elif self._section == "SYMBOLS":
self._symbols_section(line)
elif self._section == "SECTION CONTRIBUTIONS": elif self._section == "SECTION CONTRIBUTIONS":
self._section_contributions(line) self._section_contributions(line)
@ -201,9 +194,6 @@ def read_line(self, line: str):
elif self._section == "MODULES": elif self._section == "MODULES":
self._modules_section(line) self._modules_section(line)
elif self._section == "TYPES":
self.types.read_line(line)
def read_lines(self, lines: Iterable[str]): def read_lines(self, lines: Iterable[str]):
for line in lines: for line in lines:
self.read_line(line) self.read_line(line)

View file

@ -1,3 +1,4 @@
import io
from os import name as os_name from os import name as os_name
from enum import Enum from enum import Enum
from typing import List from typing import List
@ -71,8 +72,12 @@ def cmd_line(self) -> List[str]:
return ["wine", cvdump_exe, *flags, winepath_unix_to_win(self._pdb)] return ["wine", cvdump_exe, *flags, winepath_unix_to_win(self._pdb)]
def run(self) -> CvdumpParser: def run(self) -> CvdumpParser:
p = CvdumpParser() parser = CvdumpParser()
call = self.cmd_line() call = self.cmd_line()
lines = subprocess.check_output(call).decode("utf-8").split("\r\n") with subprocess.Popen(call, stdout=subprocess.PIPE) as proc:
p.read_lines(lines) for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
return p # Blank lines are there to help the reader; they have no context significance
if line != "\n":
parser.read_line(line)
return parser

View file

@ -56,11 +56,11 @@ def normalize_type_id(key: str) -> str:
If key begins with "T_" it is a built-in type. If key begins with "T_" it is a built-in type.
Else it is a hex string. We prefer lower case letters and Else it is a hex string. We prefer lower case letters and
no leading zeroes. (UDT identifier pads to 8 characters.)""" no leading zeroes. (UDT identifier pads to 8 characters.)"""
if key.startswith("T_"): if key[0] == "0":
# Remove numeric value for "T_" type. We don't use this. return f"0x{key[-4:].lower()}"
return key[: key.index("(")] if "(" in key else key
return hex(int(key, 16)).lower() # Remove numeric value for "T_" type. We don't use this.
return key.partition("(")[0]
def scalar_type_pointer(type_name: str) -> bool: def scalar_type_pointer(type_name: str) -> bool:
@ -203,8 +203,18 @@ class CvdumpTypesParser:
# LF_MODIFIER, type being modified # LF_MODIFIER, type being modified
MODIFIES_RE = re.compile(r".*modifies type (?P<type>.*)$") MODIFIES_RE = re.compile(r".*modifies type (?P<type>.*)$")
MODES_OF_INTEREST = {
"LF_ARRAY",
"LF_CLASS",
"LF_ENUM",
"LF_FIELDLIST",
"LF_MODIFIER",
"LF_POINTER",
"LF_STRUCTURE",
}
def __init__(self) -> None: def __init__(self) -> None:
self.mode = "" self.mode: Optional[str] = None
self.last_key = "" self.last_key = ""
self.keys = {} self.keys = {}
@ -370,13 +380,19 @@ def get_format_string(self, type_key: str) -> str:
def read_line(self, line: str): def read_line(self, line: str):
if (match := self.INDEX_RE.match(line)) is not None: if (match := self.INDEX_RE.match(line)) is not None:
self.last_key = normalize_type_id(match.group("key")) type_ = match.group(2)
self.mode = match.group("type") if type_ not in self.MODES_OF_INTEREST:
self._new_type() self.mode = None
return
# We don't need to read anything else from here (for now) # Don't need to normalize, it's already in the format we want
if self.mode in ("LF_ENUM", "LF_POINTER"): self.last_key = match.group(1)
self._set("size", 4) self.mode = type_
self._new_type()
return
if self.mode is None:
return
if self.mode == "LF_MODIFIER": if self.mode == "LF_MODIFIER":
if (match := self.MODIFIES_RE.match(line)) is not None: if (match := self.MODIFIES_RE.match(line)) is not None:
@ -385,14 +401,14 @@ def read_line(self, line: str):
self._set("is_forward_ref", True) self._set("is_forward_ref", True)
self._set("modifies", normalize_type_id(match.group("type"))) self._set("modifies", normalize_type_id(match.group("type")))
if self.mode == "LF_ARRAY": elif self.mode == "LF_ARRAY":
if (match := self.ARRAY_ELEMENT_RE.match(line)) is not None: if (match := self.ARRAY_ELEMENT_RE.match(line)) is not None:
self._set("array_type", normalize_type_id(match.group("type"))) self._set("array_type", normalize_type_id(match.group("type")))
if (match := self.ARRAY_LENGTH_RE.match(line)) is not None: elif (match := self.ARRAY_LENGTH_RE.match(line)) is not None:
self._set("size", int(match.group("length"))) self._set("size", int(match.group("length")))
if self.mode == "LF_FIELDLIST": elif self.mode == "LF_FIELDLIST":
# If this class has a vtable, create a mock member at offset 0 # If this class has a vtable, create a mock member at offset 0
if (match := self.VTABLE_RE.match(line)) is not None: if (match := self.VTABLE_RE.match(line)) is not None:
# For our purposes, any pointer type will do # For our purposes, any pointer type will do
@ -400,20 +416,20 @@ def read_line(self, line: str):
self._set_member_name("vftable") self._set_member_name("vftable")
# Superclass is set here in the fieldlist rather than in LF_CLASS # Superclass is set here in the fieldlist rather than in LF_CLASS
if (match := self.SUPERCLASS_RE.match(line)) is not None: elif (match := self.SUPERCLASS_RE.match(line)) is not None:
self._set("super", normalize_type_id(match.group("type"))) self._set("super", normalize_type_id(match.group("type")))
# Member offset and type given on the first of two lines. # Member offset and type given on the first of two lines.
if (match := self.LIST_RE.match(line)) is not None: elif (match := self.LIST_RE.match(line)) is not None:
self._add_member( self._add_member(
int(match.group("offset")), normalize_type_id(match.group("type")) int(match.group("offset")), normalize_type_id(match.group("type"))
) )
# Name of the member read on the second of two lines. # Name of the member read on the second of two lines.
if (match := self.MEMBER_RE.match(line)) is not None: elif (match := self.MEMBER_RE.match(line)) is not None:
self._set_member_name(match.group("name")) self._set_member_name(match.group("name"))
if self.mode in ("LF_STRUCTURE", "LF_CLASS"): else: # LF_CLASS or LF_STRUCTURE
# Match the reference to the associated LF_FIELDLIST # Match the reference to the associated LF_FIELDLIST
if (match := self.CLASS_FIELD_RE.match(line)) is not None: if (match := self.CLASS_FIELD_RE.match(line)) is not None:
if match.group("field_type") == "0x0000": if match.group("field_type") == "0x0000":
@ -427,7 +443,7 @@ def read_line(self, line: str):
# Last line has the vital information. # Last line has the vital information.
# If this is a FORWARD REF, we need to follow the UDT pointer # If this is a FORWARD REF, we need to follow the UDT pointer
# to get the actual class details. # to get the actual class details.
if (match := self.CLASS_NAME_RE.match(line)) is not None: elif (match := self.CLASS_NAME_RE.match(line)) is not None:
self._set("name", match.group("name")) self._set("name", match.group("name"))
self._set("udt", normalize_type_id(match.group("udt"))) self._set("udt", normalize_type_id(match.group("udt")))
self._set("size", int(match.group("size"))) self._set("size", int(match.group("size")))