diff --git a/LEGO1/library_msvc.h b/LEGO1/library_msvc.h index c46154a8..20ae06ef 100644 --- a/LEGO1/library_msvc.h +++ b/LEGO1/library_msvc.h @@ -40,14 +40,6 @@ // LIBRARY: LEGO1 0x1008b640 // _rand -// entry -// LIBRARY: ISLE 0x4082e0 -// _WinMainCRTStartup - -// entry -// LIBRARY: LEGO1 0x1008c860 -// __DllMainCRTStartup@12 - // LIBRARY: ISLE 0x409110 // __mtinit diff --git a/tools/isledecomp/isledecomp/__init__.py b/tools/isledecomp/isledecomp/__init__.py index a95f1869..59c55869 100644 --- a/tools/isledecomp/isledecomp/__init__.py +++ b/tools/isledecomp/isledecomp/__init__.py @@ -1,5 +1,4 @@ from .bin import * from .dir import * from .parser import * -from .syminfo import * from .utils import * diff --git a/tools/isledecomp/isledecomp/bin.py b/tools/isledecomp/isledecomp/bin.py index 16f70f7a..26dd00f8 100644 --- a/tools/isledecomp/isledecomp/bin.py +++ b/tools/isledecomp/isledecomp/bin.py @@ -1,3 +1,4 @@ +import logging import struct from typing import List, Optional from dataclasses import dataclass @@ -51,12 +52,17 @@ class ImageSectionHeader: number_of_line_numbers: int characteristics: int + @property + def extent(self): + """Get the highest possible offset of this section""" + return max(self.size_of_raw_data, self.virtual_size) + def match_name(self, name: str) -> bool: return self.name == struct.pack("8s", name.encode("ascii")) def contains_vaddr(self, vaddr: int) -> bool: ofs = vaddr - self.virtual_address - return 0 <= ofs < max(self.size_of_raw_data, self.virtual_size) + return 0 <= ofs < self.extent def addr_is_uninitialized(self, vaddr: int) -> bool: """We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in @@ -71,25 +77,29 @@ def addr_is_uninitialized(self, vaddr: int) -> bool: ) +logger = logging.getLogger(__name__) + + class Bin: """Parses a PE format EXE and allows reading data from a virtual address. Reference: https://learn.microsoft.com/en-us/windows/win32/debug/pe-format""" # pylint: disable=too-many-instance-attributes - def __init__(self, filename: str, logger=None) -> None: - self.logger = logger - self._debuglog(f'Parsing headers of "{filename}"... ') + def __init__(self, filename: str, find_str: bool = False) -> None: + logger.debug('Parsing headers of "%s"... ', filename) self.filename = filename self.file = None self.imagebase = None self.entry = None self.sections: List[ImageSectionHeader] = [] self.last_section = None + self.find_str = find_str + self._potential_strings = {} self._relocated_addrs = set() def __enter__(self): - self._debuglog(f"Bin {self.filename} Enter") + logger.debug("Bin %s Enter", self.filename) self.file = open(self.filename, "rb") (mz_str,) = struct.unpack("2s", self.file.read(2)) @@ -123,28 +133,71 @@ def __enter__(self): self._populate_relocations() + # This is a (semi) expensive lookup that is not necesssary in every case. + # We can find strings in the original if we have coverage using STRING markers. + # For the recomp, we can find strings using the PDB. + if self.find_str: + self._prepare_string_search() + text_section = self._get_section_by_name(".text") self.last_section = text_section - self._debuglog("... Parsing finished") + logger.debug("... 
Parsing finished") return self def __exit__(self, exc_type, exc_value, exc_traceback): - self._debuglog(f"Bin {self.filename} Exit") + logger.debug("Bin %s Exit", self.filename) if self.file: self.file.close() - def _debuglog(self, msg): - """Write to the logger, if present""" - if self.logger is not None: - self.logger.debug(msg) - def get_relocated_addresses(self) -> List[int]: return sorted(self._relocated_addrs) + def find_string(self, target: str) -> Optional[int]: + # Pad with null terminator to make sure we don't + # match on a subset of the full string + if not target.endswith(b"\x00"): + target += b"\x00" + + c = target[0] + if c not in self._potential_strings: + return None + + for addr in self._potential_strings[c]: + if target == self.read(addr, len(target)): + return addr + + return None + def is_relocated_addr(self, vaddr) -> bool: return vaddr in self._relocated_addrs + def _prepare_string_search(self): + """We are intersted in deduplicated string constants found in the + .rdata and .data sections. For each relocated address in these sections, + read the first byte and save the address if that byte is an ASCII character. + When we search for an arbitrary string later, we can narrow down the list + of potential locations by a lot.""" + + def is_ascii(b): + return b" " <= b < b"\x7f" + + sect_data = self._get_section_by_name(".data") + sect_rdata = self._get_section_by_name(".rdata") + potentials = filter( + lambda a: sect_data.contains_vaddr(a) or sect_rdata.contains_vaddr(a), + self.get_relocated_addresses(), + ) + + for addr in potentials: + c = self.read(addr, 1) + if c is not None and is_ascii(c): + k = ord(c) + if k not in self._potential_strings: + self._potential_strings[k] = set() + + self._potential_strings[k].add(addr) + def _populate_relocations(self): """The relocation table in .reloc gives each virtual address where the next four bytes are, itself, another virtual address. During loading, these values will be @@ -212,6 +265,9 @@ def _get_section_by_name(self, name: str): return section + def get_section_extent_by_index(self, index: int) -> int: + return self.sections[index - 1].extent + def get_section_offset_by_index(self, index: int) -> int: """The symbols output from cvdump gives addresses in this format: AAAA.BBBBBBBB where A is the index (1-based) into the section table and B is the local offset. 
@@ -242,6 +298,15 @@ def get_raw_addr(self, vaddr: int) -> int: + self.last_section.pointer_to_raw_data ) + def is_valid_section(self, section: int) -> bool: + """The PDB will refer to sections that are not listed in the headers + and so should ignore these references.""" + try: + _ = self.sections[section - 1] + return True + except IndexError: + return False + def is_valid_vaddr(self, vaddr: int) -> bool: """Does this virtual address point to anything in the exe?""" section = next( diff --git a/tools/isledecomp/isledecomp/compare/__init__.py b/tools/isledecomp/isledecomp/compare/__init__.py new file mode 100644 index 00000000..f8d18500 --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/__init__.py @@ -0,0 +1 @@ +from .core import Compare diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py new file mode 100644 index 00000000..07860203 --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -0,0 +1,149 @@ +import os +import logging +from typing import List, Optional +from isledecomp.cvdump.demangler import demangle_string_const +from isledecomp.cvdump import Cvdump, CvdumpAnalysis +from isledecomp.parser import DecompCodebase +from isledecomp.dir import walk_source_dir +from isledecomp.types import SymbolType +from .db import CompareDb, MatchInfo +from .lines import LinesDb + + +logger = logging.getLogger(__name__) + + +class Compare: + # pylint: disable=too-many-instance-attributes + def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir): + self.orig_bin = orig_bin + self.recomp_bin = recomp_bin + self.pdb_file = pdb_file + self.code_dir = code_dir + + self._lines_db = LinesDb(code_dir) + self._db = CompareDb() + + self._load_cvdump() + self._load_markers() + self._find_original_strings() + + def _load_cvdump(self): + logger.info("Parsing %s ...", self.pdb_file) + cv = ( + Cvdump(self.pdb_file) + .lines() + .globals() + .publics() + .symbols() + .section_contributions() + .run() + ) + res = CvdumpAnalysis(cv) + + for sym in res.nodes: + # The PDB might contain sections that do not line up with the + # actual binary. The symbol "__except_list" is one example. + # In these cases, just skip this symbol and move on because + # we can't do much with it. + if not self.recomp_bin.is_valid_section(sym.section): + continue + + addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset) + + # If this symbol is the final one in its section, we were not able to + # estimate its size because we didn't have the total size of that section. + # We can get this estimate now and assume that the final symbol occupies + # the remainder of the section. + if sym.estimated_size is None: + sym.estimated_size = ( + self.recomp_bin.get_section_extent_by_index(sym.section) + - sym.offset + ) + + if sym.node_type == SymbolType.STRING: + string_info = demangle_string_const(sym.decorated_name) + # TODO: skip unicode for now. will need to handle these differently. + if string_info.is_utf16: + continue + + raw = self.recomp_bin.read(addr, sym.size()) + try: + sym.friendly_name = raw.decode("latin1") + except UnicodeDecodeError: + pass + + self._db.set_recomp_symbol(addr, sym.node_type, sym.name(), sym.size()) + + for lineref in cv.lines: + addr = self.recomp_bin.get_abs_addr(lineref.section, lineref.offset) + self._lines_db.add_line(lineref.filename, lineref.line_no, addr) + + # The _entry symbol is referenced in the PE header so we get this match for free. 
+ self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry) + + def _load_markers(self): + # Guess at module name from PDB file name + # reccmp checks the original binary filename; we could use this too + (module, _) = os.path.splitext(os.path.basename(self.pdb_file)) + + codefiles = list(walk_source_dir(self.code_dir)) + codebase = DecompCodebase(codefiles, module) + + # Match lineref functions first because this is a guaranteed match. + # If we have two functions that share the same name, and one is + # a lineref, we can match the nameref correctly because the lineref + # was already removed from consideration. + for fun in codebase.iter_line_functions(): + recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number) + if recomp_addr is not None: + self._db.set_function_pair(fun.offset, recomp_addr) + if fun.should_skip(): + self._db.skip_compare(fun.offset) + + for fun in codebase.iter_name_functions(): + self._db.match_function(fun.offset, fun.name) + if fun.should_skip(): + self._db.skip_compare(fun.offset) + + for var in codebase.iter_variables(): + self._db.match_variable(var.offset, var.name) + + for tbl in codebase.iter_vtables(): + self._db.match_vtable(tbl.offset, tbl.name) + + def _find_original_strings(self): + """Go to the original binary and look for the specified string constants + to find a match. This is a (relatively) expensive operation so we only + look at strings that we have not already matched via a STRING annotation.""" + + for string in self._db.get_unmatched_strings(): + addr = self.orig_bin.find_string(string.encode("latin1")) + if addr is None: + escaped = repr(string) + logger.debug("Failed to find this string in the original: %s", escaped) + continue + + self._db.match_string(addr, string) + + def get_one_function(self, addr: int) -> Optional[MatchInfo]: + """i.e. 
verbose mode for reccmp""" + return self._db.get_one_function(addr) + + def get_functions(self) -> List[MatchInfo]: + return self._db.get_matches(SymbolType.FUNCTION) + + def compare_functions(self): + pass + + def compare_variables(self): + pass + + def compare_pointers(self): + pass + + def compare_strings(self): + pass + + def compare_vtables(self): + pass diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py new file mode 100644 index 00000000..850f25fd --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -0,0 +1,149 @@ +"""Wrapper for database (here an in-memory sqlite database) that collects the +addresses/symbols that we want to compare between the original and recompiled binaries.""" +import sqlite3 +import logging +from collections import namedtuple +from typing import List, Optional +from isledecomp.types import SymbolType + +_SETUP_SQL = """ + DROP TABLE IF EXISTS `symbols`; + CREATE TABLE `symbols` ( + compare_type int, + orig_addr int, + recomp_addr int, + name text, + size int, + should_skip int default(FALSE) + ); + CREATE INDEX `symbols_re` ON `symbols` (recomp_addr); + CREATE INDEX `symbols_na` ON `symbols` (compare_type, name); +""" + + +MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name") + + +def matchinfo_factory(_, row): + return MatchInfo(*row) + + +logger = logging.getLogger(__name__) + + +class CompareDb: + def __init__(self): + self._db = sqlite3.connect(":memory:") + self._db.executescript(_SETUP_SQL) + + def set_recomp_symbol( + self, + addr: int, + compare_type: Optional[SymbolType], + name: Optional[str], + size: Optional[int], + ): + compare_value = compare_type.value if compare_type is not None else None + self._db.execute( + "INSERT INTO `symbols` (recomp_addr, compare_type, name, size) VALUES (?,?,?,?)", + (addr, compare_value, name, size), + ) + + def get_unmatched_strings(self) -> List[str]: + """Return any strings not already identified by STRING markers.""" + + cur = self._db.execute( + "SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL", + (SymbolType.STRING.value,), + ) + + return [string for (string,) in cur.fetchall()] + + def get_one_function(self, addr: int) -> Optional[MatchInfo]: + cur = self._db.execute( + """SELECT orig_addr, recomp_addr, size, name + FROM `symbols` + WHERE compare_type = ? + AND orig_addr = ? + AND recomp_addr IS NOT NULL + AND should_skip IS FALSE + ORDER BY orig_addr + """, + (SymbolType.FUNCTION.value, addr), + ) + cur.row_factory = matchinfo_factory + return cur.fetchone() + + def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]: + cur = self._db.execute( + """SELECT orig_addr, recomp_addr, size, name + FROM `symbols` + WHERE compare_type = ? + AND orig_addr IS NOT NULL + AND recomp_addr IS NOT NULL + AND should_skip IS FALSE + ORDER BY orig_addr + """, + (compare_type.value,), + ) + cur.row_factory = matchinfo_factory + + return cur.fetchall() + + def set_function_pair(self, orig: int, recomp: int) -> bool: + """For lineref match or _entry""" + cur = self._db.execute( + "UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?", + (orig, SymbolType.FUNCTION.value, recomp), + ) + + return cur.rowcount > 0 + # TODO: Both ways required? 
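As a quick illustration of how these pieces fit together (the addresses below are invented, not real LEGO1 offsets): recomp-side symbols from the PDB are inserted first, and a lineref or `_entry` match later fills in the original address on the same row.

```python
from isledecomp.types import SymbolType
from isledecomp.compare.db import CompareDb

db = CompareDb()

# Recomp symbols come from the PDB analysis: only recomp_addr is known at this point.
db.set_recomp_symbol(0x10012340, SymbolType.FUNCTION, "MxCore::Tickle", 0x40)

# A lineref (or the _entry symbol) supplies the matching original address.
db.set_function_pair(0x100A1000, 0x10012340)

match = db.get_one_function(0x100A1000)
print(match.name, hex(match.orig_addr), hex(match.recomp_addr), match.size)
# MxCore::Tickle 0x100a1000 0x10012340 64
```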
+ + def skip_compare(self, orig: int): + self._db.execute( + "UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,) + ) + + def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool: + # Update the compare_type here too since the marker tells us what we should do + logger.debug("Looking for %s %s", compare_type.name.lower(), name) + cur = self._db.execute( + """UPDATE `symbols` + SET orig_addr = ?, compare_type = ? + WHERE name = ? + AND orig_addr IS NULL + AND (compare_type = ? OR compare_type IS NULL)""", + (addr, compare_type.value, name, compare_type.value), + ) + + return cur.rowcount > 0 + + def match_function(self, addr: int, name: str) -> bool: + did_match = self._match_on(SymbolType.FUNCTION, addr, name) + if not did_match: + logger.error("Failed to find function symbol with name: %s", name) + + return did_match + + def match_vtable(self, addr: int, name: str) -> bool: + did_match = self._match_on(SymbolType.VTABLE, addr, name) + if not did_match: + logger.error("Failed to find vtable for class: %s", name) + + return did_match + + def match_variable(self, addr: int, name: str) -> bool: + did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on( + SymbolType.POINTER, addr, name + ) + if not did_match: + logger.error("Failed to find variable: %s", name) + + return did_match + + def match_string(self, addr: int, value: str) -> bool: + did_match = self._match_on(SymbolType.STRING, addr, value) + if not did_match: + escaped = repr(value) + logger.error("Failed to find string: %s", escaped) diff --git a/tools/isledecomp/isledecomp/compare/lines.py b/tools/isledecomp/isledecomp/compare/lines.py new file mode 100644 index 00000000..ced3c117 --- /dev/null +++ b/tools/isledecomp/isledecomp/compare/lines.py @@ -0,0 +1,58 @@ +"""Database used to match (filename, line_number) pairs +between FUNCTION markers and PDB analysis.""" +import sqlite3 +import logging +from typing import Optional +from pathlib import Path +from isledecomp.dir import PathResolver + + +_SETUP_SQL = """ + DROP TABLE IF EXISTS `lineref`; + CREATE TABLE `lineref` ( + path text not null, + filename text not null, + line int not null, + addr int not null + ); + CREATE INDEX `file_line` ON `lineref` (filename, line); +""" + + +logger = logging.getLogger(__name__) + + +class LinesDb: + def __init__(self, code_dir) -> None: + self._db = sqlite3.connect(":memory:") + self._db.executescript(_SETUP_SQL) + self._path_resolver = PathResolver(code_dir) + + def add_line(self, path: str, line_no: int, addr: int): + """To be added from the LINES section of cvdump.""" + sourcepath = self._path_resolver.resolve_cvdump(path) + filename = Path(sourcepath).name.lower() + + self._db.execute( + "INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)", + (sourcepath, filename, line_no, addr), + ) + + def search_line(self, path: str, line_no: int) -> Optional[int]: + """Using path and line number from FUNCTION marker, + get the address of this function in the recomp.""" + filename = Path(path).name.lower() + cur = self._db.execute( + "SELECT path, addr FROM `lineref` WHERE filename = ? 
AND line = ?", + (filename, line_no), + ) + for source_path, addr in cur.fetchall(): + if Path(path).samefile(source_path): + return addr + + logger.error( + "Failed to find function symbol with filename and line: %s:%d", + path, + line_no, + ) + return None diff --git a/tools/isledecomp/isledecomp/cvdump/__init__.py b/tools/isledecomp/isledecomp/cvdump/__init__.py index e9d66298..635ef5cd 100644 --- a/tools/isledecomp/isledecomp/cvdump/__init__.py +++ b/tools/isledecomp/isledecomp/cvdump/__init__.py @@ -1,2 +1,3 @@ +from .analysis import CvdumpAnalysis from .parser import CvdumpParser from .runner import Cvdump diff --git a/tools/isledecomp/isledecomp/cvdump/analysis.py b/tools/isledecomp/isledecomp/cvdump/analysis.py new file mode 100644 index 00000000..720593be --- /dev/null +++ b/tools/isledecomp/isledecomp/cvdump/analysis.py @@ -0,0 +1,184 @@ +"""For collating the results from parsing cvdump.exe into a more directly useful format.""" +from typing import List, Optional, Tuple +from isledecomp.types import SymbolType +from .parser import CvdumpParser +from .demangler import demangle_string_const, demangle_vtable + + +def data_type_info(type_name: str) -> Optional[Tuple[int, bool]]: + """cvdump type aliases are listed here: + https://github.com/microsoft/microsoft-pdb/blob/master/include/cvinfo.h + For the given type, return tuple(size, is_pointer) if possible.""" + # pylint: disable=too-many-return-statements + # TODO: refactor to be as simple as possble + + # Ignore complex types. We can get the size of those from the TYPES section. + if not type_name.startswith("T"): + return None + + # if 32-bit pointer + if type_name.startswith("T_32P"): + return (4, True) + + if type_name.endswith("QUAD") or type_name.endswith("64"): + return (8, False) + + if ( + type_name.endswith("LONG") + or type_name.endswith("INT4") + or type_name.endswith("32") + ): + return (4, False) + + if type_name.endswith("SHORT") or type_name.endswith("WCHAR"): + return (2, False) + + if "CHAR" in type_name: + return (1, False) + + if type_name in ("T_NOTYPE", "T_VOID"): + return (0, False) + + return None + + +class CvdumpNode: + # pylint: disable=too-many-instance-attributes + # These two are required and allow us to identify the symbol + section: int + offset: int + # aka the mangled name from the PUBLICS section + decorated_name: Optional[str] = None + # optional "nicer" name (e.g. of a function from SYMBOLS section) + friendly_name: Optional[str] = None + # To be determined by context after inserting data, unless the decorated + # name makes this obvious. (i.e. string constants or vtables) + # We choose not to assume that section 1 (probably ".text") contains only + # functions. Smacker functions are linked to their own section "_UNSTEXT" + node_type: Optional[SymbolType] = None + # Function size can be read from the LINES section so use this over any + # other value if we have it. + # TYPES section can tell us the size of structs and other complex types. + confirmed_size: Optional[int] = None + # Estimated by reading the distance between this symbol and the one that + # follows in the same section. + # If this is the last symbol in the section, we cannot estimate a size. + estimated_size: Optional[int] = None + # Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be + # accurate. 
+ section_contribution: Optional[int] = None + + def __init__(self, section: int, offset: int) -> None: + self.section = section + self.offset = offset + + def set_decorated(self, name: str): + self.decorated_name = name + + if self.decorated_name.startswith("??_7"): + self.node_type = SymbolType.VTABLE + self.friendly_name = demangle_vtable(self.decorated_name) + + elif self.decorated_name.startswith("??_C@"): + self.node_type = SymbolType.STRING + (strlen, _) = demangle_string_const(self.decorated_name) + self.confirmed_size = strlen + + elif not self.decorated_name.startswith("?") and "@" in self.decorated_name: + # C mangled symbol. The trailing at-sign with number tells the number of bytes + # in the parameter list for __stdcall, __fastcall, or __vectorcall + # For __cdecl it is more ambiguous and we would have to know which section we are in. + # https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170#FormatC + self.node_type = SymbolType.FUNCTION + + def name(self) -> Optional[str]: + """Prefer "friendly" name if we have it. + This is what we have been using to match functions.""" + return self.friendly_name or self.decorated_name + + def size(self) -> Optional[int]: + if self.confirmed_size is not None: + return self.confirmed_size + + # Better to undershoot the size because we can identify a comparison gap easily + if self.estimated_size is not None and self.section_contribution is not None: + return min(self.estimated_size, self.section_contribution) + + # Return whichever one we have, or neither + return self.estimated_size or self.section_contribution + + +class CvdumpAnalysis: + """Collects the results from CvdumpParser into a list of nodes (i.e. symbols). + These can then be analyzed by a downstream tool.""" + + nodes = List[CvdumpNode] + + def __init__(self, parser: CvdumpParser): + """Read in as much information as we have from the parser. + The more sections we have, the better our information will be.""" + node_dict = {} + + # PUBLICS is our roadmap for everything that follows. + for pub in parser.publics: + key = (pub.section, pub.offset) + if key not in node_dict: + node_dict[key] = CvdumpNode(*key) + + node_dict[key].set_decorated(pub.name) + + for sizeref in parser.sizerefs: + key = (sizeref.section, sizeref.offset) + if key not in node_dict: + node_dict[key] = CvdumpNode(*key) + + node_dict[key].section_contribution = sizeref.size + + for glo in parser.globals: + key = (glo.section, glo.offset) + if key not in node_dict: + node_dict[key] = CvdumpNode(*key) + + node_dict[key].node_type = SymbolType.DATA + node_dict[key].friendly_name = glo.name + + if (g_info := data_type_info(glo.type)) is not None: + (size, is_pointer) = g_info + node_dict[key].confirmed_size = size + if is_pointer: + node_dict[key].node_type = SymbolType.POINTER + + for lin in parser.lines: + key = (lin.section, lin.offset) + # Here we only set if the section:offset already exists + # because our values include offsets inside of the function. 
+ if key in node_dict: + node_dict[key].node_type = SymbolType.FUNCTION + + for sym in parser.symbols: + key = (sym.section, sym.offset) + if key not in node_dict: + node_dict[key] = CvdumpNode(*key) + + if sym.type == "S_GPROC32": + node_dict[key].friendly_name = sym.name + node_dict[key].confirmed_size = sym.size + node_dict[key].node_type = SymbolType.FUNCTION + + self.nodes = [v for _, v in dict(sorted(node_dict.items())).items()] + self._estimate_size() + + def _estimate_size(self): + """Get the distance between one section:offset value and the next one + in the same section. This gives a rough estimate of the size of the symbol. + If we have information from SECTION CONTRIBUTIONS, take whichever one is + less to get the best approximate size.""" + for i in range(len(self.nodes) - 1): + this_node = self.nodes[i] + next_node = self.nodes[i + 1] + + # If they are in different sections, we can't compare them + if this_node.section != next_node.section: + continue + + this_node.estimated_size = next_node.offset - this_node.offset diff --git a/tools/isledecomp/isledecomp/cvdump/demangler.py b/tools/isledecomp/isledecomp/cvdump/demangler.py new file mode 100644 index 00000000..340f19a3 --- /dev/null +++ b/tools/isledecomp/isledecomp/cvdump/demangler.py @@ -0,0 +1,74 @@ +"""For demangling a subset of MSVC mangled symbols. +Some unofficial information about the mangling scheme is here: +https://en.wikiversity.org/wiki/Visual_C%2B%2B_name_mangling +""" +import re +from collections import namedtuple + + +class InvalidEncodedNumberError(Exception): + pass + + +_encoded_number_translate = str.maketrans("ABCDEFGHIJKLMNOP", "0123456789ABCDEF") + + +def parse_encoded_number(string: str) -> int: + # TODO: assert string ends in "@"? + if string.endswith("@"): + string = string[:-1] + + try: + return int(string.translate(_encoded_number_translate), 16) + except ValueError as e: + raise InvalidEncodedNumberError(string) from e + + +string_const_regex = re.compile( + r"\?\?_C@\_(?P[0-1])(?P\d|[A-P]+@)(?P\w+)@(?P.+)@" +) +StringConstInfo = namedtuple("StringConstInfo", "len is_utf16") + + +def demangle_string_const(symbol: str) -> StringConstInfo: + """Don't bother to decode the string text from the symbol. + We can just read it from the binary once we have the length.""" + match = string_const_regex.match(symbol) + if match is None: + # See below + return StringConstInfo(0, False) + + try: + strlen = ( + parse_encoded_number(match.group("len")) + if "@" in match.group("len") + else int(match.group("len")) + ) + except (ValueError, InvalidEncodedNumberError): + # This would be an annoying error to fail on if we get a bad symbol. + # For now, just assume a zero length string because this will probably + # raise some eyebrows during the comparison. 
+ strlen = 0 + + is_utf16 = match.group("is_utf16") == "1" + return StringConstInfo(len=strlen, is_utf16=is_utf16) + + +def demangle_vtable(symbol: str) -> str: + """Get the class name referenced in the vtable symbol.""" + + # Seek ahead 4 chars to strip off "??_7" prefix + t = symbol[4:].split("@") + # "?$" indicates a template class + if t[0].startswith("?$"): + class_name = t[0][2:] + # PA = Pointer/reference + # V or U = class or struct + if t[1].startswith("PA"): + generic = f"{t[1][3:]} *" + else: + generic = t[1][1:] + + return f"{class_name}<{generic}>" + + return t[0] diff --git a/tools/isledecomp/isledecomp/cvdump/parser.py b/tools/isledecomp/isledecomp/cvdump/parser.py index ddc2e5f7..613cd4a4 100644 --- a/tools/isledecomp/isledecomp/cvdump/parser.py +++ b/tools/isledecomp/isledecomp/cvdump/parser.py @@ -1,5 +1,5 @@ import re -from typing import Iterable +from typing import Iterable, Tuple from collections import namedtuple # e.g. `*** PUBLICS` @@ -40,17 +40,28 @@ ) -LinesEntry = namedtuple("LinesEntry", "filename line_no addr") +# User functions only +LinesEntry = namedtuple("LinesEntry", "filename line_no section offset") + +# Strings, vtables, functions +# superset of everything else +# only place you can find the C symbols (library functions, smacker, etc) PublicsEntry = namedtuple("PublicsEntry", "type section offset flags name") + +# S_GPROC32 = functions SymbolsEntry = namedtuple("SymbolsEntry", "type section offset size name") + +# (Estimated) size of any symbol SizeRefEntry = namedtuple("SizeRefEntry", "section offset size") + +# global variables GdataEntry = namedtuple("GdataEntry", "section offset type name") class CvdumpParser: def __init__(self) -> None: self._section: str = "" - self._lines_filename: str = "" + self._lines_function: Tuple[str, int] = ("", 0) self.lines = [] self.publics = [] @@ -64,19 +75,24 @@ def _lines_section(self, line: str): we are in.""" # Subheader indicates a new function and possibly a new code filename. + # Save the section here because it is not given on the lines that follow. if (match := _lines_subsection_header.match(line)) is not None: - self._lines_filename = match.group(1) + self._lines_function = ( + match.group("filename"), + int(match.group("section"), 16), + ) return - if (matches := _line_addr_pairs_findall.findall(line)) is not None: - for line_no, addr in matches: - self.lines.append( - LinesEntry( - filename=self._lines_filename, - line_no=int(line_no), - addr=int(addr, 16), - ) + # Match any pairs as we find them + for line_no, offset in _line_addr_pairs_findall.findall(line): + self.lines.append( + LinesEntry( + filename=self._lines_function[0], + line_no=int(line_no), + section=self._lines_function[1], + offset=int(offset, 16), ) + ) def _publics_section(self, line: str): """Match each line from PUBLICS and pull out the symbol information. 
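To make the size-estimation rule from `analysis.py` concrete: the gap to the next symbol in the same section gives one estimate, SECTION CONTRIBUTIONS gives another, and `CvdumpNode.size()` deliberately takes the smaller of the two. A tiny worked example with invented values:

```python
from isledecomp.cvdump.analysis import CvdumpNode

node = CvdumpNode(section=1, offset=0x100)
node.estimated_size = 0x80        # next symbol in section 1 sits at offset 0x180
node.section_contribution = 0x70  # size reported by SECTION CONTRIBUTIONS

# No confirmed size from LINES/TYPES, so size() undershoots on purpose:
assert node.size() == 0x70
```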
diff --git a/tools/isledecomp/isledecomp/parser/__init__.py b/tools/isledecomp/isledecomp/parser/__init__.py index 3034a562..14549700 100644 --- a/tools/isledecomp/isledecomp/parser/__init__.py +++ b/tools/isledecomp/isledecomp/parser/__init__.py @@ -1,2 +1,3 @@ +from .codebase import DecompCodebase from .parser import DecompParser from .linter import DecompLinter diff --git a/tools/isledecomp/isledecomp/parser/codebase.py b/tools/isledecomp/isledecomp/parser/codebase.py new file mode 100644 index 00000000..54ab4035 --- /dev/null +++ b/tools/isledecomp/isledecomp/parser/codebase.py @@ -0,0 +1,44 @@ +"""For aggregating decomp markers read from an entire directory and for a single module.""" +from typing import Iterable, Iterator, List +from .parser import DecompParser +from .node import ( + ParserSymbol, + ParserFunction, + ParserVtable, + ParserVariable, +) + + +class DecompCodebase: + def __init__(self, filenames: Iterable[str], module: str) -> None: + self._symbols: List[ParserSymbol] = [] + + parser = DecompParser() + for filename in filenames: + parser.reset() + with open(filename, "r", encoding="utf-8") as f: + parser.read_lines(f) + + for sym in parser.iter_symbols(module): + sym.filename = filename + self._symbols.append(sym) + + def iter_line_functions(self) -> Iterator[ParserFunction]: + """Return lineref functions separately from nameref. Assuming the PDB matches + the state of the source code, a line reference is a guaranteed match, even if + multiple functions share the same name. (i.e. polymorphism)""" + return filter( + lambda s: isinstance(s, ParserFunction) and not s.is_nameref(), + self._symbols, + ) + + def iter_name_functions(self) -> Iterator[ParserFunction]: + return filter( + lambda s: isinstance(s, ParserFunction) and s.is_nameref(), self._symbols + ) + + def iter_vtables(self) -> Iterator[ParserVtable]: + return filter(lambda s: isinstance(s, ParserVtable), self._symbols) + + def iter_variables(self) -> Iterator[ParserVariable]: + return filter(lambda s: isinstance(s, ParserVariable), self._symbols) diff --git a/tools/isledecomp/isledecomp/parser/parser.py b/tools/isledecomp/isledecomp/parser/parser.py index f9485e65..d4153b3c 100644 --- a/tools/isledecomp/isledecomp/parser/parser.py +++ b/tools/isledecomp/isledecomp/parser/parser.py @@ -1,6 +1,6 @@ # C++ file parser -from typing import List, Iterable, Iterator +from typing import List, Iterable, Iterator, Optional from enum import Enum from .util import ( is_blank_or_comment, @@ -122,6 +122,11 @@ def vtables(self) -> List[ParserSymbol]: def variables(self) -> List[ParserSymbol]: return [s for s in self._symbols if isinstance(s, ParserVariable)] + def iter_symbols(self, module: Optional[str] = None) -> Iterator[ParserSymbol]: + for s in self._symbols: + if module is None or s.module == module: + yield s + def _recover(self): """We hit a syntax error and need to reset temp structures""" self.state = ReaderState.SEARCH diff --git a/tools/isledecomp/isledecomp/syminfo.py b/tools/isledecomp/isledecomp/syminfo.py deleted file mode 100644 index 8388eaa5..00000000 --- a/tools/isledecomp/isledecomp/syminfo.py +++ /dev/null @@ -1,107 +0,0 @@ -import os -from isledecomp.dir import PathResolver -from isledecomp.cvdump import Cvdump - - -class RecompiledInfo: - addr = None - size = None - name = None - start = None - - -# Declare a class that parses the output of cvdump for fast access later -class SymInfo: - funcs = {} - lines = {} - names = {} - - def __init__(self, pdb, sym_recompfile, sym_logger, base_dir): - self.logger = 
sym_logger - path_resolver = PathResolver(base_dir) - - self.logger.info("Parsing %s ...", pdb) - self.logger.debug("Parsing output of cvdump.exe ...") - - cv = Cvdump(pdb).lines().symbols().publics().section_contributions().run() - - self.logger.debug("... Parsing output of cvdump.exe finished") - - contrib_dict = {(s.section, s.offset): s.size for s in cv.sizerefs} - for pub in cv.publics: - if pub.type == "S_PUB32" and (pub.section, pub.offset) in contrib_dict: - size = contrib_dict[(pub.section, pub.offset)] - - info = RecompiledInfo() - info.addr = sym_recompfile.get_abs_addr(pub.section, pub.offset) - - info.start = 0 - info.size = size - info.name = pub.name - self.names[pub.name] = info - self.funcs[pub.offset] = info - - for proc in cv.symbols: - if proc.type != "S_GPROC32": - continue - - info = RecompiledInfo() - info.addr = sym_recompfile.get_abs_addr(proc.section, proc.offset) - - info.start = 0 - info.size = proc.size - info.name = proc.name - - self.names[proc.name] = info - self.funcs[proc.offset] = info - - for sourcepath, line_no, offset in cv.lines: - sourcepath = path_resolver.resolve_cvdump(sourcepath) - - if sourcepath not in self.lines: - self.lines[sourcepath] = {} - - if line_no not in self.lines[sourcepath]: - self.lines[sourcepath][line_no] = offset - - def get_recompiled_address(self, filename, line): - recompiled_addr = None - - self.logger.debug("Looking for %s:%s", filename, line) - filename_basename = os.path.basename(filename).lower() - - for fn in self.lines: - # Sometimes a PDB is compiled with a relative path while we always have - # an absolute path. Therefore we must - try: - if os.path.basename( - fn - ).lower() == filename_basename and os.path.samefile(fn, filename): - filename = fn - break - except FileNotFoundError: - continue - - if filename in self.lines and line in self.lines[filename]: - recompiled_addr = self.lines[filename][line] - - if recompiled_addr in self.funcs: - return self.funcs[recompiled_addr] - self.logger.error( - "Failed to find function symbol with address: %x", recompiled_addr - ) - return None - self.logger.error( - "Failed to find function symbol with filename and line: %s:%s", - filename, - line, - ) - return None - - def get_recompiled_address_from_name(self, name): - self.logger.debug("Looking for %s", name) - - if name in self.names: - return self.names[name] - self.logger.error("Failed to find function symbol with name: %s", name) - return None diff --git a/tools/isledecomp/isledecomp/types.py b/tools/isledecomp/isledecomp/types.py new file mode 100644 index 00000000..4d518dd3 --- /dev/null +++ b/tools/isledecomp/isledecomp/types.py @@ -0,0 +1,12 @@ +"""Types shared by other modules""" +from enum import Enum + + +class SymbolType(Enum): + """Broadly tells us what kind of comparison is required for this symbol.""" + + FUNCTION = 1 + DATA = 2 + POINTER = 3 + STRING = 4 + VTABLE = 5 diff --git a/tools/isledecomp/tests/test_cvdump.py b/tools/isledecomp/tests/test_cvdump.py new file mode 100644 index 00000000..cfaff7a9 --- /dev/null +++ b/tools/isledecomp/tests/test_cvdump.py @@ -0,0 +1,39 @@ +import pytest +from isledecomp.cvdump.analysis import data_type_info + +# fmt: off +type_check_cases = [ + ("T_32PINT4", 4, True), + ("T_32PLONG", 4, True), + ("T_32PRCHAR", 4, True), + ("T_32PREAL32", 4, True), + ("T_32PUCHAR", 4, True), + ("T_32PUINT4", 4, True), + ("T_32PULONG", 4, True), + ("T_32PUSHORT", 4, True), + ("T_32PVOID", 4, True), + ("T_CHAR", 1, False), + ("T_INT4", 4, False), + ("T_LONG", 4, False), + ("T_NOTYPE", 0, 
False), # ? + ("T_QUAD", 8, False), + ("T_RCHAR", 1, False), + ("T_REAL32", 4, False), + ("T_REAL64", 8, False), + ("T_SHORT", 2, False), + ("T_UCHAR", 1, False), + ("T_UINT4", 4, False), + ("T_ULONG", 4, False), + ("T_UQUAD", 8, False), + ("T_USHORT", 2, False), + ("T_VOID", 0, False), # ? + ("T_WCHAR", 2, False), +] +# fmt: on + + +@pytest.mark.parametrize("type_name, size, is_pointer", type_check_cases) +def test_type_check(type_name: str, size: int, is_pointer: bool): + assert (info := data_type_info(type_name)) is not None + assert info[0] == size + assert info[1] == is_pointer diff --git a/tools/isledecomp/tests/test_demangler.py b/tools/isledecomp/tests/test_demangler.py new file mode 100644 index 00000000..5343bdcc --- /dev/null +++ b/tools/isledecomp/tests/test_demangler.py @@ -0,0 +1,55 @@ +import pytest +from isledecomp.cvdump.demangler import ( + demangle_string_const, + demangle_vtable, + parse_encoded_number, + InvalidEncodedNumberError, +) + +string_demangle_cases = [ + ("??_C@_08LIDF@December?$AA@", 8, False), + ("??_C@_0L@EGPP@english?9nz?$AA@", 11, False), + ( + "??_C@_1O@POHA@?$AA?$CI?$AAn?$AAu?$AAl?$AAl?$AA?$CJ?$AA?$AA?$AA?$AA?$AA?$AH?$AA?$AA?$AA?$AA?$AA?$AA?$AA?$9A?$AE?$;I@", + 14, + True, + ), +] + + +@pytest.mark.parametrize("symbol, strlen, is_utf16", string_demangle_cases) +def test_strings(symbol, is_utf16, strlen): + s = demangle_string_const(symbol) + assert s.len == strlen + assert s.is_utf16 == is_utf16 + + +encoded_numbers = [ + ("A@", 0), + ("AA@", 0), # would never happen? + ("P@", 15), + ("BA@", 16), + ("BCD@", 291), +] + + +@pytest.mark.parametrize("string, value", encoded_numbers) +def test_encoded_numbers(string, value): + assert parse_encoded_number(string) == value + + +def test_invalid_encoded_number(): + with pytest.raises(InvalidEncodedNumberError): + parse_encoded_number("Hello") + + +vtable_cases = [ + ("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter"), + ("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection"), + ("??_7?$MxPtrList@VLegoPathController@@@@6B@", "MxPtrList"), +] + + +@pytest.mark.parametrize("symbol, class_name", vtable_cases) +def test_vtable(symbol, class_name): + assert demangle_vtable(symbol) == class_name diff --git a/tools/reccmp/reccmp.py b/tools/reccmp/reccmp.py index c7db5e2c..b7027435 100755 --- a/tools/reccmp/reccmp.py +++ b/tools/reccmp/reccmp.py @@ -10,13 +10,11 @@ from isledecomp import ( Bin, - DecompParser, get_file_in_script_dir, OffsetPlaceholderGenerator, print_diff, - SymInfo, - walk_source_dir, ) +from isledecomp.compare import Compare as IsleCompare from capstone import Cs, CS_ARCH_X86, CS_MODE_32 import colorama @@ -261,7 +259,6 @@ def main(): args = parser.parse_args() logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s") - logger = logging.getLogger(__name__) colorama.init() @@ -294,8 +291,13 @@ def main(): svg = args.svg - with Bin(original, logger) as origfile, Bin(recomp, logger) as recompfile: - syminfo = SymInfo(syms, recompfile, logger, source) + with Bin(original, find_str=True) as origfile, Bin(recomp) as recompfile: + if verbose is not None: + # Mute logger events from compare engine + logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL) + logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL) + + isle_compare = IsleCompare(origfile, recompfile, syms, source) print() @@ -306,151 +308,120 @@ def main(): total_effective_accuracy = 0 htmlinsert = [] - # Generate basename of original file, used in locating OFFSET lines - 
basename = os.path.basename(os.path.splitext(original)[0]) + matches = [] + if verbose is not None: + match = isle_compare.get_one_function(verbose) + if match is not None: + found_verbose_target = True + matches = [match] + else: + matches = isle_compare.get_functions() - parser = DecompParser() - for srcfilename in walk_source_dir(source): - parser.reset() - with open(srcfilename, "r", encoding="utf-8") as srcfile: - parser.read_lines(srcfile) + for match in matches: + # The effective_ratio is the ratio when ignoring differing register + # allocation vs the ratio is the true ratio. + ratio = 0.0 + effective_ratio = 0.0 + if match.size: + origasm = parse_asm( + capstone_disassembler, + origfile, + match.orig_addr, + match.size, + ) + recompasm = parse_asm( + capstone_disassembler, + recompfile, + match.recomp_addr, + match.size, + ) - for fun in parser.functions: - if fun.should_skip(): - continue + diff = difflib.SequenceMatcher(None, origasm, recompasm) + ratio = diff.ratio() + effective_ratio = ratio - if fun.module != basename: - continue + if ratio != 1.0: + # Check whether we can resolve register swaps which are actually + # perfect matches modulo compiler entropy. + if can_resolve_register_differences(origasm, recompasm): + effective_ratio = 1.0 + else: + ratio = 0 - addr = fun.offset - # Verbose flag handling + percenttext = f"{(effective_ratio * 100):.2f}%" + if not plain: + if effective_ratio == 1.0: + percenttext = ( + colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL + ) + elif effective_ratio > 0.8: + percenttext = ( + colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL + ) + else: + percenttext = ( + colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL + ) + + if effective_ratio == 1.0 and ratio != 1.0: + if plain: + percenttext += "*" + else: + percenttext += colorama.Fore.RED + "*" + colorama.Style.RESET_ALL + + if args.print_rec_addr: + addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}" + else: + addrs = hex(match.orig_addr) + + if not verbose: + print( + f" {match.name} ({addrs}) is {percenttext} similar to the original" + ) + + function_count += 1 + total_accuracy += ratio + total_effective_accuracy += effective_ratio + + if match.size: + udiff = difflib.unified_diff(origasm, recompasm, n=10) + + # If verbose, print the diff for that function to the output if verbose: - if addr == verbose: - found_verbose_target = True - else: - continue - - if fun.is_nameref(): - recinfo = syminfo.get_recompiled_address_from_name(fun.name) - if not recinfo: - continue - else: - recinfo = syminfo.get_recompiled_address( - srcfilename, fun.line_number - ) - if not recinfo: - continue - - # The effective_ratio is the ratio when ignoring differing register - # allocation vs the ratio is the true ratio. - ratio = 0.0 - effective_ratio = 0.0 - if recinfo.size: - origasm = parse_asm( - capstone_disassembler, - origfile, - addr + recinfo.start, - recinfo.size, - ) - recompasm = parse_asm( - capstone_disassembler, - recompfile, - recinfo.addr + recinfo.start, - recinfo.size, - ) - - diff = difflib.SequenceMatcher(None, origasm, recompasm) - ratio = diff.ratio() - effective_ratio = ratio - - if ratio != 1.0: - # Check whether we can resolve register swaps which are actually - # perfect matches modulo compiler entropy. 
- if can_resolve_register_differences(origasm, recompasm): - effective_ratio = 1.0 - else: - ratio = 0 - - percenttext = f"{(effective_ratio * 100):.2f}%" - if not plain: if effective_ratio == 1.0: - percenttext = ( - colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL - ) - elif effective_ratio > 0.8: - percenttext = ( - colorama.Fore.YELLOW - + percenttext - + colorama.Style.RESET_ALL - ) - else: - percenttext = ( - colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL - ) - - if effective_ratio == 1.0 and ratio != 1.0: - if plain: - percenttext += "*" - else: - percenttext += ( - colorama.Fore.RED + "*" + colorama.Style.RESET_ALL - ) - - if args.print_rec_addr: - addrs = f"0x{addr:x} / 0x{recinfo.addr:x}" - else: - addrs = hex(addr) - - if not verbose: - print( - f" {recinfo.name} ({addrs}) is {percenttext} similar to the original" - ) - - function_count += 1 - total_accuracy += ratio - total_effective_accuracy += effective_ratio - - if recinfo.size: - udiff = difflib.unified_diff(origasm, recompasm, n=10) - - # If verbose, print the diff for that function to the output - if verbose: - if effective_ratio == 1.0: - ok_text = ( - "OK!" - if plain - else ( - colorama.Fore.GREEN - + "✨ OK! ✨" - + colorama.Style.RESET_ALL - ) + ok_text = ( + "OK!" + if plain + else ( + colorama.Fore.GREEN + + "✨ OK! ✨" + + colorama.Style.RESET_ALL ) - if ratio == 1.0: - print( - f"{addrs}: {recinfo.name} 100% match.\n\n{ok_text}\n\n" - ) - else: - print( - f"{addrs}: {recinfo.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n" - ) + ) + if ratio == 1.0: + print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n") else: - print_diff(udiff, plain) - print( - f"\n{recinfo.name} is only {percenttext} similar to the original, diff above" + f"{addrs}: {match.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n" ) + else: + print_diff(udiff, plain) - # If html, record the diffs to an HTML file - if html_path: - htmlinsert.append( - { - "address": f"0x{addr:x}", - "name": recinfo.name, - "matching": effective_ratio, - "diff": "\n".join(udiff), - } + print( + f"\n{match.name} is only {percenttext} similar to the original, diff above" ) + # If html, record the diffs to an HTML file + if html_path: + htmlinsert.append( + { + "address": f"0x{match.orig_addr:x}", + "name": match.name, + "matching": effective_ratio, + "diff": "\n".join(udiff), + } + ) + if html_path: gen_html(html_path, json.dumps(htmlinsert))
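Putting it together, the new comparison pipeline reduces reccmp's setup to roughly the following sketch (the binary and PDB filenames are placeholders; argument parsing, capstone setup, and the diff/report stages are omitted):

```python
import logging
from isledecomp import Bin
from isledecomp.compare import Compare as IsleCompare

logging.basicConfig(level=logging.INFO, format="[%(levelname)s] %(message)s")

# find_str=True builds the string-search index only for the original binary;
# on the recomp side, string constants are located via the PDB instead.
with Bin("LEGO1.DLL", find_str=True) as origfile, Bin("LEGO1d.dll") as recompfile:
    isle_compare = IsleCompare(origfile, recompfile, "LEGO1d.pdb", "./LEGO1")

    for match in isle_compare.get_functions():
        # Each MatchInfo carries both addresses plus an estimated size,
        # which is all that parse_asm and difflib need downstream.
        print(f"{match.name}: 0x{match.orig_addr:x} / 0x{match.recomp_addr:x}, ~{match.size} bytes")
```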