From 5eb74c06fd7a62b11d9205cdf50ced90b6f51d0a Mon Sep 17 00:00:00 2001
From: MS <disinvite@users.noreply.github.com>
Date: Sun, 10 Mar 2024 14:49:45 -0400
Subject: [PATCH] reccmp: Sanitize performance (and more) (#654)

---
 .../isledecomp/compare/asm/const.py           |  27 ++++
 .../isledecomp/compare/asm/parse.py           | 149 +++++++++---------
 tools/isledecomp/isledecomp/compare/core.py   |   6 +-
 .../isledecomp/isledecomp/cvdump/analysis.py  |  15 +-
 tools/isledecomp/isledecomp/cvdump/parser.py  |  32 ++--
 tools/isledecomp/isledecomp/cvdump/runner.py  |  13 +-
 tools/isledecomp/isledecomp/cvdump/types.py   |  54 ++++---
 7 files changed, 172 insertions(+), 124 deletions(-)
 create mode 100644 tools/isledecomp/isledecomp/compare/asm/const.py

diff --git a/tools/isledecomp/isledecomp/compare/asm/const.py b/tools/isledecomp/isledecomp/compare/asm/const.py
new file mode 100644
index 00000000..7c715c0c
--- /dev/null
+++ b/tools/isledecomp/isledecomp/compare/asm/const.py
@@ -0,0 +1,27 @@
+# Jump mnemonics exactly as capstone prints them, with aliases removed.
+# e.g. je and jz are the same instruction. capstone uses je.
+# See: /arch/X86/X86GenAsmWriter.inc in the capstone repo.
+JUMP_MNEMONICS = {
+    "ja",
+    "jae",
+    "jb",
+    "jbe",
+    "jcxz",  # unused?
+    "je",
+    "jecxz",
+    "jg",
+    "jge",
+    "jl",
+    "jle",
+    "jmp",
+    "jne",
+    "jno",
+    "jnp",
+    "jns",
+    "jo",
+    "jp",
+    "js",
+}
+
+# Mnemonics guaranteed to take exactly one operand.
+SINGLE_OPERAND_INSTS = {"push", "call", *JUMP_MNEMONICS}
diff --git a/tools/isledecomp/isledecomp/compare/asm/parse.py b/tools/isledecomp/isledecomp/compare/asm/parse.py
index 35451fc6..c57f2df9 100644
--- a/tools/isledecomp/isledecomp/compare/asm/parse.py
+++ b/tools/isledecomp/isledecomp/compare/asm/parse.py
@@ -12,10 +12,15 @@ from typing import Callable, List, Optional, Tuple
 from collections import namedtuple
 from isledecomp.bin import InvalidVirtualAddressError
 from capstone import Cs, CS_ARCH_X86, CS_MODE_32
+from .const import JUMP_MNEMONICS, SINGLE_OPERAND_INSTS
 
 disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
 
-ptr_replace_regex = re.compile(r"(?P<data_size>\w+) ptr \[(?P<addr>0x[0-9a-fA-F]+)\]")
+ptr_replace_regex = re.compile(r"\[(0x[0-9a-f]+)\]")
+
+# For matching an immediate value on its own.
+# Preceded by start-of-string (first operand) or comma-space (second operand)
+immediate_replace_regex = re.compile(r"(?:^|, )(0x[0-9a-f]+)")
 
 DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
 
@@ -30,10 +35,6 @@ def from_hex(string: str) -> Optional[int]:
     return None
 
 
-def get_float_size(size_str: str) -> int:
-    return 8 if size_str == "qword" else 4
-
-
 class ParseAsm:
     def __init__(
         self,
@@ -94,14 +95,41 @@ class ParseAsm:
         self.replacements[addr] = placeholder
         return placeholder
 
+    def hex_replace_always(self, match: re.Match) -> str:
+        """re.sub callback: always swap the matched address for a placeholder."""
+        whole, hex_text = match.group(0, 1)
+        return whole.replace(hex_text, self.replace(int(hex_text, 16)))
+
+    def hex_replace_relocated(self, match: re.Match) -> str:
+        """re.sub callback for immediate operands. The value is swapped for a
+        placeholder only when the relocation table confirms it is an address;
+        any other immediate is returned exactly as it was matched."""
+        whole, hex_text = match.group(0, 1)
+        address = int(hex_text, 16)
+        if not self.is_relocated(address):
+            return whole
+        return whole.replace(hex_text, self.replace(address))
+
+    def hex_replace_float(self, match: re.Match) -> str:
+        """re.sub callback for pointer operands of floating point instructions.
+        Falls back through three sources: name, float constant, placeholder."""
+        whole, hex_text = match.group(0, 1)
+        address = int(hex_text, 16)
+
+        # 1. A known variable name for this address wins outright.
+        substitute = self.lookup(address)
+        if substitute is None:
+            # 2. Otherwise show the constant stored there. A qword operand
+            # is read as 8 bytes; anything else is read as 4 bytes.
+            width = 8 if "qword" in match.string else 4
+            substitute = self.float_replace(address, width)
+
+        if substitute is None:
+            # 3. Could not read a float: use a regular placeholder.
+            substitute = self.replace(address)
+        return whole.replace(hex_text, substitute)
+
     def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
-        if len(inst.op_str) == 0:
-            # Nothing to sanitize
-            return (inst.mnemonic, "")
-
-        if "0x" not in inst.op_str:
-            return (inst.mnemonic, inst.op_str)
-
         # For jumps or calls, if the entire op_str is a hex number, the value
         # is a relative offset.
         # Otherwise (i.e. it looks like `dword ptr [address]`) it is an
@@ -109,12 +137,21 @@ class ParseAsm:
         # Providing the starting address of the function to capstone.disasm has
         # automatically resolved relative offsets to an absolute address.
         # We will have to undo this for some of the jumps or they will not match.
-        op_str_address = from_hex(inst.op_str)
 
-        if op_str_address is not None:
+        if (
+            inst.mnemonic in SINGLE_OPERAND_INSTS
+            and (op_str_address := from_hex(inst.op_str)) is not None
+        ):
             if inst.mnemonic == "call":
                 return (inst.mnemonic, self.replace(op_str_address))
 
+            if inst.mnemonic == "push":
+                if self.is_relocated(op_str_address):
+                    return (inst.mnemonic, self.replace(op_str_address))
+
+                # To avoid falling into jump handling
+                return (inst.mnemonic, inst.op_str)
+
             if inst.mnemonic == "jmp":
                 # The unwind section contains JMPs to other functions.
                 # If we have a name for this address, use it. If not,
@@ -124,70 +161,19 @@ class ParseAsm:
                 if potential_name is not None:
                     return (inst.mnemonic, potential_name)
 
-            if inst.mnemonic.startswith("j"):
-                # i.e. if this is any jump
-                # Show the jump offset rather than the absolute address
-                jump_displacement = op_str_address - (inst.address + inst.size)
-                return (inst.mnemonic, hex(jump_displacement))
-
-        def filter_out_ptr(match):
-            """Helper for re.sub, see below"""
-            offset = from_hex(match.group("addr"))
-
-            if offset is not None:
-                # We assume this is always an address to replace
-                placeholder = self.replace(offset)
-                return f'{match.group("data_size")} ptr [{placeholder}]'
-
-            # Strict regex should ensure we can read the hex number.
-            # But just in case: return the string with no changes
-            return match.group(0)
-
-        def float_ptr_replace(match):
-            offset = from_hex(match.group("addr"))
-
-            if offset is not None:
-                # If we can find a variable name for this pointer, use it.
-                placeholder = self.lookup(offset)
-
-                # Read what's under the pointer and show the decimal value.
-                if placeholder is None:
-                    placeholder = self.float_replace(
-                        offset, get_float_size(match.group("data_size"))
-                    )
-
-                # If we can't read the float, use a regular placeholder.
-                if placeholder is None:
-                    placeholder = self.replace(offset)
-
-                return f'{match.group("data_size")} ptr [{placeholder}]'
-
-            # Strict regex should ensure we can read the hex number.
-            # But just in case: return the string with no changes
-            return match.group(0)
+            # Else: this is any jump
+            # Show the jump offset rather than the absolute address
+            jump_displacement = op_str_address - (inst.address + inst.size)
+            return (inst.mnemonic, hex(jump_displacement))
 
         if inst.mnemonic.startswith("f"):
             # If floating point instruction
-            op_str = ptr_replace_regex.sub(float_ptr_replace, inst.op_str)
+            op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
         else:
-            op_str = ptr_replace_regex.sub(filter_out_ptr, inst.op_str)
+            op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
 
-        def replace_immediate(chunk: str) -> str:
-            if (inttest := from_hex(chunk)) is not None:
-                # If this value is a virtual address, it is referenced absolutely,
-                # which means it must be in the relocation table.
-                if self.is_relocated(inttest):
-                    return self.replace(inttest)
-
-            return chunk
-
-        # Performance hack:
-        # Skip this step if there is nothing left to consider replacing.
-        if "0x" in op_str:
-            # Replace immediate values with name or placeholder (where appropriate)
-            op_str = ", ".join(map(replace_immediate, op_str.split(", ")))
-
-        return inst.mnemonic, op_str
+        op_str = immediate_replace_regex.sub(self.hex_replace_relocated, op_str)
+        return (inst.mnemonic, op_str)
 
     def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[str]:
         asm = []
@@ -196,7 +182,22 @@ class ParseAsm:
             # Use heuristics to disregard some differences that aren't representative
             # of the accuracy of a function (e.g. global offsets)
             inst = DisasmLiteInst(*raw_inst)
-            result = self.sanitize(inst)
+
+            # If there is no pointer or immediate value in the op_str,
+            # there is nothing to sanitize.
+            # This leaves us with cases where a small immediate value or
+            # small displacement (this.member or vtable calls) appears.
+            # If we assume that instructions we want to sanitize need to be 5
+            # bytes -- 1 for the opcode and 4 for the address -- exclude cases
+            # where the hex value could not be an address.
+            # The exception is jumps which are as small as 2 bytes
+            # but are still useful to sanitize.
+            if "0x" in inst.op_str and (
+                inst.mnemonic in JUMP_MNEMONICS or inst.size > 4
+            ):
+                result = self.sanitize(inst)
+            else:
+                result = (inst.mnemonic, inst.op_str)
 
             # mnemonic + " " + op_str
             asm.append((hex(inst.address), " ".join(result)))
diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py
index 53ba3fa0..ca48cb63 100644
--- a/tools/isledecomp/isledecomp/compare/core.py
+++ b/tools/isledecomp/isledecomp/compare/core.py
@@ -158,9 +158,9 @@ class Compare:
                 addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
             )
 
-        for lineref in cv.lines:
-            addr = self.recomp_bin.get_abs_addr(lineref.section, lineref.offset)
-            self._lines_db.add_line(lineref.filename, lineref.line_no, addr)
+        for (section, offset), (filename, line_no) in res.verified_lines.items():
+            addr = self.recomp_bin.get_abs_addr(section, offset)
+            self._lines_db.add_line(filename, line_no, addr)
 
         # The _entry symbol is referenced in the PE header so we get this match for free.
         self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)
diff --git a/tools/isledecomp/isledecomp/cvdump/analysis.py b/tools/isledecomp/isledecomp/cvdump/analysis.py
index d3f8bd27..4ef654c5 100644
--- a/tools/isledecomp/isledecomp/cvdump/analysis.py
+++ b/tools/isledecomp/isledecomp/cvdump/analysis.py
@@ -1,5 +1,5 @@
 """For collating the results from parsing cvdump.exe into a more directly useful format."""
-from typing import List, Optional
+from typing import Dict, List, Tuple, Optional
 from isledecomp.types import SymbolType
 from .parser import CvdumpParser
 from .demangler import demangle_string_const, demangle_vtable
@@ -81,6 +81,7 @@ class CvdumpAnalysis:
     These can then be analyzed by a downstream tool."""
 
     nodes = List[CvdumpNode]
+    verified_lines: Dict[Tuple[int, int], Tuple[str, int]]
 
     def __init__(self, parser: CvdumpParser):
         """Read in as much information as we have from the parser.
@@ -126,13 +127,21 @@ class CvdumpAnalysis:
                 # No big deal if we don't have complete type information.
                 pass
 
-        for lin in parser.lines:
-            key = (lin.section, lin.offset)
+        for key in parser.lines:
             # Here we only set if the section:offset already exists
             # because our values include offsets inside of the function.
             if key in node_dict:
                 node_dict[key].node_type = SymbolType.FUNCTION
 
+        # The LINES section contains every code line in the file, naturally.
+        # There isn't an obvious separation between functions, so we have to
+        # read everything. However, any function that would be in LINES
+        # has to be somewhere else in the PDB (probably PUBLICS).
+        # Isolate the lines that we actually care about for matching.
+        self.verified_lines = {
+            key: value for (key, value) in parser.lines.items() if key in node_dict
+        }
+
         for sym in parser.symbols:
             key = (sym.section, sym.offset)
             if key not in node_dict:
diff --git a/tools/isledecomp/isledecomp/cvdump/parser.py b/tools/isledecomp/isledecomp/cvdump/parser.py
index 27554eda..1b1eb3fd 100644
--- a/tools/isledecomp/isledecomp/cvdump/parser.py
+++ b/tools/isledecomp/isledecomp/cvdump/parser.py
@@ -4,7 +4,7 @@ from collections import namedtuple
 from .types import CvdumpTypesParser
 
 # e.g. `*** PUBLICS`
-_section_change_regex = re.compile(r"^\*\*\* (?P<section>[A-Z/ ]+)$")
+_section_change_regex = re.compile(r"\*\*\* (?P<section>[A-Z/ ]{2,})")
 
 # e.g. `     27 00034EC0     28 00034EE2     29 00034EE7     30 00034EF4`
 _line_addr_pairs_findall = re.compile(r"\s+(?P<line_no>\d+) (?P<addr>[A-F0-9]{8})")
@@ -70,7 +70,7 @@ class CvdumpParser:
         self._section: str = ""
         self._lines_function: Tuple[str, int] = ("", 0)
 
-        self.lines = []
+        self.lines = {}
         self.publics = []
         self.symbols = []
         self.sizerefs = []
@@ -95,14 +95,8 @@ class CvdumpParser:
 
         # Match any pairs as we find them
         for line_no, offset in _line_addr_pairs_findall.findall(line):
-            self.lines.append(
-                LinesEntry(
-                    filename=self._lines_function[0],
-                    line_no=int(line_no),
-                    section=self._lines_function[1],
-                    offset=int(offset, 16),
-                )
-            )
+            key = (self._lines_function[1], int(offset, 16))
+            self.lines[key] = (self._lines_function[0], int(line_no))
 
     def _publics_section(self, line: str):
         """Match each line from PUBLICS and pull out the symbol information.
@@ -175,23 +169,22 @@ class CvdumpParser:
             )
 
     def read_line(self, line: str):
-        # Blank lines are there to help the reader; they have no context significance
-        if line.strip() == "":
-            return
-
         if (match := _section_change_regex.match(line)) is not None:
             self._section = match.group(1)
             return
 
-        if self._section == "LINES":
+        if self._section == "TYPES":
+            self.types.read_line(line)
+
+        elif self._section == "SYMBOLS":
+            self._symbols_section(line)
+
+        elif self._section == "LINES":
             self._lines_section(line)
 
         elif self._section == "PUBLICS":
             self._publics_section(line)
 
-        elif self._section == "SYMBOLS":
-            self._symbols_section(line)
-
         elif self._section == "SECTION CONTRIBUTIONS":
             self._section_contributions(line)
 
@@ -201,9 +194,6 @@ class CvdumpParser:
         elif self._section == "MODULES":
             self._modules_section(line)
 
-        elif self._section == "TYPES":
-            self.types.read_line(line)
-
     def read_lines(self, lines: Iterable[str]):
         for line in lines:
             self.read_line(line)
diff --git a/tools/isledecomp/isledecomp/cvdump/runner.py b/tools/isledecomp/isledecomp/cvdump/runner.py
index 33e2d98d..9463acfa 100644
--- a/tools/isledecomp/isledecomp/cvdump/runner.py
+++ b/tools/isledecomp/isledecomp/cvdump/runner.py
@@ -1,3 +1,4 @@
+import io
 from os import name as os_name
 from enum import Enum
 from typing import List
@@ -71,8 +72,17 @@ class Cvdump:
         return ["wine", cvdump_exe, *flags, winepath_unix_to_win(self._pdb)]
 
     def run(self) -> CvdumpParser:
-        p = CvdumpParser()
+        parser = CvdumpParser()
         call = self.cmd_line()
-        lines = subprocess.check_output(call).decode("utf-8").split("\r\n")
-        p.read_lines(lines)
-        return p
+        with subprocess.Popen(call, stdout=subprocess.PIPE) as proc:
+            for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
+                # Blank lines are there to help the reader; they have no context significance
+                if line != "\n":
+                    parser.read_line(line)
+
+        # Leaving the with-block waits for the process, so returncode is set.
+        # check_output() raised on a failed run; keep that contract.
+        if proc.returncode != 0:
+            raise subprocess.CalledProcessError(proc.returncode, call)
+
+        return parser
diff --git a/tools/isledecomp/isledecomp/cvdump/types.py b/tools/isledecomp/isledecomp/cvdump/types.py
index ed5a38b8..87ae4b6e 100644
--- a/tools/isledecomp/isledecomp/cvdump/types.py
+++ b/tools/isledecomp/isledecomp/cvdump/types.py
@@ -56,11 +56,11 @@ def normalize_type_id(key: str) -> str:
     If key begins with "T_" it is a built-in type.
     Else it is a hex string. We prefer lower case letters and
     no leading zeroes. (UDT identifier pads to 8 characters.)"""
-    if key.startswith("T_"):
-        # Remove numeric value for "T_" type. We don't use this.
-        return key[: key.index("(")] if "(" in key else key
+    if key[0] == "0":
+        return f"0x{key[-4:].lower()}"
 
-    return hex(int(key, 16)).lower()
+    # Remove numeric value for "T_" type. We don't use this.
+    return key.partition("(")[0]
 
 
 def scalar_type_pointer(type_name: str) -> bool:
@@ -203,8 +203,18 @@ class CvdumpTypesParser:
     # LF_MODIFIER, type being modified
     MODIFIES_RE = re.compile(r".*modifies type (?P<type>.*)$")
 
+    MODES_OF_INTEREST = {
+        "LF_ARRAY",
+        "LF_CLASS",
+        "LF_ENUM",
+        "LF_FIELDLIST",
+        "LF_MODIFIER",
+        "LF_POINTER",
+        "LF_STRUCTURE",
+    }
+
     def __init__(self) -> None:
-        self.mode = ""
+        self.mode: Optional[str] = None
         self.last_key = ""
         self.keys = {}
 
@@ -370,13 +380,19 @@ class CvdumpTypesParser:
 
     def read_line(self, line: str):
         if (match := self.INDEX_RE.match(line)) is not None:
-            self.last_key = normalize_type_id(match.group("key"))
-            self.mode = match.group("type")
-            self._new_type()
+            type_ = match.group(2)
+            if type_ not in self.MODES_OF_INTEREST:
+                self.mode = None
+                return
 
-            # We don't need to read anything else from here (for now)
-            if self.mode in ("LF_ENUM", "LF_POINTER"):
-                self._set("size", 4)
+            # Don't need to normalize, it's already in the format we want
+            self.last_key = match.group(1)
+            self.mode = type_
+            self._new_type()
+            return
+
+        if self.mode is None:
+            return
 
         if self.mode == "LF_MODIFIER":
             if (match := self.MODIFIES_RE.match(line)) is not None:
@@ -385,14 +401,14 @@ class CvdumpTypesParser:
                 self._set("is_forward_ref", True)
                 self._set("modifies", normalize_type_id(match.group("type")))
 
-        if self.mode == "LF_ARRAY":
+        elif self.mode == "LF_ARRAY":
             if (match := self.ARRAY_ELEMENT_RE.match(line)) is not None:
                 self._set("array_type", normalize_type_id(match.group("type")))
 
-            if (match := self.ARRAY_LENGTH_RE.match(line)) is not None:
+            elif (match := self.ARRAY_LENGTH_RE.match(line)) is not None:
                 self._set("size", int(match.group("length")))
 
-        if self.mode == "LF_FIELDLIST":
+        elif self.mode == "LF_FIELDLIST":
             # If this class has a vtable, create a mock member at offset 0
             if (match := self.VTABLE_RE.match(line)) is not None:
                 # For our purposes, any pointer type will do
@@ -400,20 +416,20 @@ class CvdumpTypesParser:
                 self._set_member_name("vftable")
 
             # Superclass is set here in the fieldlist rather than in LF_CLASS
-            if (match := self.SUPERCLASS_RE.match(line)) is not None:
+            elif (match := self.SUPERCLASS_RE.match(line)) is not None:
                 self._set("super", normalize_type_id(match.group("type")))
 
             # Member offset and type given on the first of two lines.
-            if (match := self.LIST_RE.match(line)) is not None:
+            elif (match := self.LIST_RE.match(line)) is not None:
                 self._add_member(
                     int(match.group("offset")), normalize_type_id(match.group("type"))
                 )
 
             # Name of the member read on the second of two lines.
-            if (match := self.MEMBER_RE.match(line)) is not None:
+            elif (match := self.MEMBER_RE.match(line)) is not None:
                 self._set_member_name(match.group("name"))
 
-        if self.mode in ("LF_STRUCTURE", "LF_CLASS"):
+        elif self.mode in ("LF_STRUCTURE", "LF_CLASS"):
             # Match the reference to the associated LF_FIELDLIST
             if (match := self.CLASS_FIELD_RE.match(line)) is not None:
                 if match.group("field_type") == "0x0000":
@@ -427,7 +443,7 @@ class CvdumpTypesParser:
             # Last line has the vital information.
             # If this is a FORWARD REF, we need to follow the UDT pointer
             # to get the actual class details.
-            if (match := self.CLASS_NAME_RE.match(line)) is not None:
+            elif (match := self.CLASS_NAME_RE.match(line)) is not None:
                 self._set("name", match.group("name"))
                 self._set("udt", normalize_type_id(match.group("udt")))
                 self._set("size", int(match.group("size")))