From 3f03940fcb23602353180ca5d7ec8cbece537c7d Mon Sep 17 00:00:00 2001 From: MS Date: Sat, 23 Mar 2024 18:01:40 -0400 Subject: [PATCH] Match vtables with virtual inheritance (#717) * Match vtables with virtual inheritance * Simplify vtable name check * Thunk alert --- tools/isledecomp/isledecomp/compare/core.py | 100 +++++++++++++++++- tools/isledecomp/isledecomp/compare/db.py | 64 ++++++++++- .../isledecomp/isledecomp/cvdump/demangler.py | 51 ++++++++- tools/isledecomp/isledecomp/parser/marker.py | 21 +++- tools/isledecomp/isledecomp/parser/node.py | 2 +- tools/isledecomp/isledecomp/parser/parser.py | 12 ++- tools/isledecomp/tests/test_demangler.py | 33 +++++- tools/isledecomp/tests/test_parser.py | 43 ++++++++ tools/isledecomp/tests/test_parser_util.py | 32 ++++++ tools/requirements.txt | 1 + tools/vtable/vtable.py | 15 +++ 11 files changed, 350 insertions(+), 24 deletions(-) diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py index 3d2246e4..ec49f70b 100644 --- a/tools/isledecomp/isledecomp/compare/core.py +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -86,6 +86,7 @@ def __init__( self._find_original_strings() self._match_thunks() self._match_exports() + self._find_vtordisp() def _load_cvdump(self): logger.info("Parsing %s ...", self.pdb_file) @@ -198,7 +199,7 @@ def _load_markers(self): self._db.match_variable(var.offset, var.name) for tbl in codebase.iter_vtables(): - self._db.match_vtable(tbl.offset, tbl.name) + self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class) for string in codebase.iter_strings(): # Not that we don't trust you, but we're checking the string @@ -285,10 +286,105 @@ def _match_exports(self): ): logger.debug("Matched export %s", repr(export_name)) + def _find_vtordisp(self): + """If there are any cases of virtual inheritance, we can read + through the vtables for those classes and find the vtable thunk + functions (vtordisp). + + Our approach is this: walk both vtables and check where we have a + vtordisp in the recomp table. Inspect the function at that vtable + position (in both) and check whether we jump to the same function. + + One potential pitfall here is that the virtual displacement could + differ between the thunks. We are not (yet) checking for this, so the + result is that the vtable will appear to match but we will have a diff + on the thunk in our regular function comparison. + + We could do this differently and check only the original vtable, + construct the name of the vtordisp function and match based on that.""" + + for match in self._db.get_matches_by_type(SymbolType.VTABLE): + # We need some method of identifying vtables that + # might have thunks, and this ought to work okay. + if "{for" not in match.name: + continue + + # TODO: We might want to fix this at the source (cvdump) instead. + # Any problem will be logged later when we compare the vtable. + vtable_size = 4 * (match.size // 4) + orig_table = self.orig_bin.read(match.orig_addr, vtable_size) + recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size) + + raw_addrs = zip( + [t for (t,) in struct.iter_unpack(" DiffReport: orig_raw = self.orig_bin.read(match.orig_addr, match.size) recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) + # It's unlikely that a function other than an adjuster thunk would + # start with a SUB instruction, so alert to a possible wrong + # annotation here. + # There's probably a better place to do this, but we're reading + # the function bytes here already. + try: + if orig_raw[0] == 0x2B and recomp_raw[0] != 0x2B: + logger.warning( + "Possible thunk at 0x%x (%s)", match.orig_addr, match.name + ) + except IndexError: + pass + def orig_lookup(addr: int) -> Optional[str]: m = self._db.get_by_orig(addr) if m is None: @@ -432,7 +528,7 @@ def match_text(m: Optional[MatchInfo], raw_addr: Optional[int] = None) -> str: match_type=SymbolType.VTABLE, orig_addr=match.orig_addr, recomp_addr=match.recomp_addr, - name=f"{match.name}::`vftable'", + name=match.name, udiff=unified_diff, ratio=ratio, ) diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py index 54ae8081..f758be0c 100644 --- a/tools/isledecomp/isledecomp/compare/db.py +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -4,6 +4,7 @@ import logging from typing import List, Optional from isledecomp.types import SymbolType +from isledecomp.cvdump.demangler import get_vtordisp_name _SETUP_SQL = """ DROP TABLE IF EXISTS `symbols`; @@ -249,6 +250,37 @@ def get_match_options(self, addr: int) -> Optional[dict]: for (option, value) in cur.fetchall() } + def is_vtordisp(self, recomp_addr: int) -> bool: + """Check whether this function is a vtordisp based on its + decorated name. If its demangled name is missing the vtordisp + indicator, correct that.""" + row = self._db.execute( + """SELECT name, decorated_name + FROM `symbols` + WHERE recomp_addr = ?""", + (recomp_addr,), + ).fetchone() + + if row is None: + return False + + (name, decorated_name) = row + if "`vtordisp" in name: + return True + + new_name = get_vtordisp_name(decorated_name) + if new_name is None: + return False + + self._db.execute( + """UPDATE `symbols` + SET name = ? + WHERE recomp_addr = ?""", + (new_name, recomp_addr), + ) + + return True + def _find_potential_match( self, name: str, compare_type: SymbolType ) -> Optional[int]: @@ -323,12 +355,34 @@ def match_function(self, addr: int, name: str) -> bool: return did_match - def match_vtable(self, addr: int, name: str) -> bool: - did_match = self._match_on(SymbolType.VTABLE, addr, name) - if not did_match: - logger.error("Failed to find vtable for class: %s", name) + def match_vtable( + self, addr: int, name: str, base_class: Optional[str] = None + ) -> bool: + # Only allow a match against "Class:`vftable'" + # if this is the derived class. + name = ( + f"{name}::`vftable'" + if base_class is None or base_class == name + else f"{name}::`vftable'{{for `{base_class}'}}" + ) - return did_match + row = self._db.execute( + """ + SELECT recomp_addr + FROM `symbols` + WHERE orig_addr IS NULL + AND name = ? + AND (compare_type = ?) + LIMIT 1 + """, + (name, SymbolType.VTABLE.value), + ).fetchone() + + if row is not None and self.set_pair(addr, row[0], SymbolType.VTABLE): + return True + + logger.error("Failed to find vtable for class: %s", name) + return False def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool: """Matching a static function variable by combining the variable name diff --git a/tools/isledecomp/isledecomp/cvdump/demangler.py b/tools/isledecomp/isledecomp/cvdump/demangler.py index 10984bff..9b2445da 100644 --- a/tools/isledecomp/isledecomp/cvdump/demangler.py +++ b/tools/isledecomp/isledecomp/cvdump/demangler.py @@ -5,6 +5,7 @@ import re from collections import namedtuple from typing import Optional +import pydemangler class InvalidEncodedNumberError(Exception): @@ -51,8 +52,52 @@ def demangle_string_const(symbol: str) -> Optional[StringConstInfo]: return StringConstInfo(len=strlen, is_utf16=is_utf16) +def get_vtordisp_name(symbol: str) -> Optional[str]: + # pylint: disable=c-extension-no-member + """For adjuster thunk functions, the PDB will sometimes use a name + that contains "vtordisp" but often will just reuse the name of the + function being thunked. We want to use the vtordisp name if possible.""" + name = pydemangler.demangle(symbol) + if name is None: + return None + + if "`vtordisp" not in name: + return None + + # Now we remove the parts of the friendly name that we don't need + try: + # Assuming this is the last of the function prefixes + thiscall_idx = name.index("__thiscall") + # To match the end of the `vtordisp{x,y}' string + end_idx = name.index("}'") + return name[thiscall_idx + 11 : end_idx + 2] + except ValueError: + return name + + def demangle_vtable(symbol: str) -> str: + # pylint: disable=c-extension-no-member """Get the class name referenced in the vtable symbol.""" + raw = pydemangler.demangle(symbol) + + if raw is None: + pass # TODO: This shouldn't happen if MSVC behaves + + # Remove storage class and other stuff we don't care about + return ( + raw.replace(" str: + """Parked implementation of MSVC symbol demangling. + We only use this for vtables and it works okay with the simple cases or + templates that refer to other classes/structs. Some namespace support. + Does not support backrefs, primitive types, or vtables with + virtual inheritance.""" # Seek ahead 4 chars to strip off "??_7" prefix t = symbol[4:].split("@") @@ -66,11 +111,11 @@ def demangle_vtable(symbol: str) -> str: else: generic = t[1][1:] - return f"{class_name}<{generic}>" + return f"{class_name}<{generic}>::`vftable'" # If we have two classes listed, it is a namespace hierarchy. # @@6B@ is a common generic suffix for these vtable symbols. if t[1] != "" and t[1] != "6B": - return t[1] + "::" + t[0] + return t[1] + "::" + t[0] + "::`vftable'" - return t[0] + return t[0] + "::`vftable'" diff --git a/tools/isledecomp/isledecomp/parser/marker.py b/tools/isledecomp/isledecomp/parser/marker.py index de8c6f05..8108ac46 100644 --- a/tools/isledecomp/isledecomp/parser/marker.py +++ b/tools/isledecomp/isledecomp/parser/marker.py @@ -1,5 +1,5 @@ import re -from typing import Optional +from typing import Optional, Tuple from enum import Enum @@ -29,18 +29,20 @@ class MarkerType(Enum): markerRegex = re.compile( - r"\s*//\s*(?P\w+):\s*(?P\w+)\s+(?P0x[a-f0-9]+)", + r"\s*//\s*(?P\w+):\s*(?P\w+)\s+(?P0x[a-f0-9]+) *(?P\S.+\S)?", flags=re.I, ) markerExactRegex = re.compile( - r"\s*// (?P[A-Z]+): (?P[A-Z0-9]+) (?P0x[a-f0-9]+)$" + r"\s*// (?P[A-Z]+): (?P[A-Z0-9]+) (?P0x[a-f0-9]+)(?: (?P\S.+\S))?\n?$" ) class DecompMarker: - def __init__(self, marker_type: str, module: str, offset: int) -> None: + def __init__( + self, marker_type: str, module: str, offset: int, extra: Optional[str] = None + ) -> None: try: self._type = MarkerType[marker_type.upper()] except KeyError: @@ -51,6 +53,7 @@ def __init__(self, marker_type: str, module: str, offset: int) -> None: # we will emit a syntax error. self._module: str = module.upper() self._offset: int = offset + self._extra: Optional[str] = extra @property def type(self) -> MarkerType: @@ -64,6 +67,10 @@ def module(self) -> str: def offset(self) -> int: return self._offset + @property + def extra(self) -> Optional[str]: + return self._extra + @property def category(self) -> MarkerCategory: if self.is_vtable(): @@ -81,6 +88,11 @@ def category(self) -> MarkerCategory: return MarkerCategory.ADDRESS + @property + def key(self) -> Tuple[str, str, Optional[str]]: + """For use with the MarkerDict. To detect/avoid marker collision.""" + return (self.category, self.module, self.extra) + def is_regular_function(self) -> bool: """Regular function, meaning: not an explicit byname lookup. FUNCTION markers can be _implicit_ byname. @@ -126,6 +138,7 @@ def match_marker(line: str) -> Optional[DecompMarker]: marker_type=match.group("type"), module=match.group("module"), offset=int(match.group("offset"), 16), + extra=match.group("extra"), ) diff --git a/tools/isledecomp/isledecomp/parser/node.py b/tools/isledecomp/isledecomp/parser/node.py index eeaed713..21e4c382 100644 --- a/tools/isledecomp/isledecomp/parser/node.py +++ b/tools/isledecomp/isledecomp/parser/node.py @@ -55,7 +55,7 @@ class ParserVariable(ParserSymbol): @dataclass class ParserVtable(ParserSymbol): - pass + base_class: Optional[str] = None @dataclass diff --git a/tools/isledecomp/isledecomp/parser/parser.py b/tools/isledecomp/isledecomp/parser/parser.py index 2e5daf07..b2534548 100644 --- a/tools/isledecomp/isledecomp/parser/parser.py +++ b/tools/isledecomp/isledecomp/parser/parser.py @@ -47,15 +47,16 @@ def __init__(self) -> None: def insert(self, marker: DecompMarker) -> bool: """Return True if this insert would overwrite""" - key = (marker.category, marker.module) - if key in self.markers: + if marker.key in self.markers: return True - self.markers[key] = marker + self.markers[marker.key] = marker return False - def query(self, category: MarkerCategory, module: str) -> Optional[DecompMarker]: - return self.markers.get((category, module)) + def query( + self, category: MarkerCategory, module: str, extra: Optional[str] = None + ) -> Optional[DecompMarker]: + return self.markers.get((category, module, extra)) def iter(self) -> Iterator[DecompMarker]: for _, marker in self.markers.items(): @@ -275,6 +276,7 @@ def _vtable_done(self, class_name: str = None): module=marker.module, offset=marker.offset, name=self.curly.get_prefix(class_name), + base_class=marker.extra, ) ) diff --git a/tools/isledecomp/tests/test_demangler.py b/tools/isledecomp/tests/test_demangler.py index e40d6e0c..f7c806e1 100644 --- a/tools/isledecomp/tests/test_demangler.py +++ b/tools/isledecomp/tests/test_demangler.py @@ -4,6 +4,7 @@ demangle_vtable, parse_encoded_number, InvalidEncodedNumberError, + get_vtordisp_name, ) string_demangle_cases = [ @@ -46,13 +47,37 @@ def test_invalid_encoded_number(): vtable_cases = [ - ("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter"), - ("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection"), - ("??_7?$MxPtrList@VLegoPathController@@@@6B@", "MxPtrList"), - ("??_7Renderer@Tgl@@6B@", "Tgl::Renderer"), + ("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter::`vftable'"), + ("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection::`vftable'"), + ( + "??_7?$MxPtrList@VLegoPathController@@@@6B@", + "MxPtrList::`vftable'", + ), + ("??_7Renderer@Tgl@@6B@", "Tgl::Renderer::`vftable'"), + ("??_7LegoExtraActor@@6B0@@", "LegoExtraActor::`vftable'{for `LegoExtraActor'}"), + ( + "??_7LegoExtraActor@@6BLegoAnimActor@@@", + "LegoExtraActor::`vftable'{for `LegoAnimActor'}", + ), + ( + "??_7LegoAnimActor@@6B?$LegoContainer@PAM@@@", + "LegoAnimActor::`vftable'{for `LegoContainer'}", + ), ] @pytest.mark.parametrize("symbol, class_name", vtable_cases) def test_vtable(symbol, class_name): assert demangle_vtable(symbol) == class_name + + +def test_vtordisp(): + """Make sure we can accurately detect an adjuster thunk symbol""" + assert get_vtordisp_name("") is None + assert get_vtordisp_name("?ClassName@LegoExtraActor@@UBEPBDXZ") is None + assert ( + get_vtordisp_name("?ClassName@LegoExtraActor@@$4PPPPPPPM@A@BEPBDXZ") is not None + ) + + # A function called vtordisp + assert get_vtordisp_name("?vtordisp@LegoExtraActor@@UBEPBDXZ") is None diff --git a/tools/isledecomp/tests/test_parser.py b/tools/isledecomp/tests/test_parser.py index 772a39c4..bbc3b739 100644 --- a/tools/isledecomp/tests/test_parser.py +++ b/tools/isledecomp/tests/test_parser.py @@ -711,3 +711,46 @@ def test_header_function_declaration(parser): assert len(parser.alerts) == 1 assert parser.alerts[0].code == ParserError.NO_IMPLEMENTATION + + +def test_extra(parser): + """Allow a fourth field in the decomp annotation. Its use will vary + depending on the marker type. Currently this is only used to identify + a vtable with virtual inheritance.""" + + # Intentionally using non-vtable markers here. + # We might want to emit a parser warning for unnecessary extra info. + parser.read_lines( + [ + "// GLOBAL: TEST 0x5555 Haha", + "int g_variable = 0;", + "// FUNCTION: TEST 0x1234 Something", + "void Test() { g_variable++; }", + "// LIBRARY: TEST 0x8080 Printf", + "// _printf", + ] + ) + + # We don't use this information (yet) but this is all fine. + assert len(parser.alerts) == 0 + + +def test_virtual_inheritance(parser): + """Indicate the base class for a vtable where the class uses + virtual inheritance.""" + parser.read_lines( + [ + "// VTABLE: HELLO 0x1234", + "// VTABLE: HELLO 0x1238 Greetings", + "// VTABLE: HELLO 0x123c Howdy", + "class HiThere : public virtual Greetings {", + "};", + ] + ) + + assert len(parser.alerts) == 0 + assert len(parser.vtables) == 3 + assert parser.vtables[0].base_class is None + assert parser.vtables[1].base_class == "Greetings" + assert parser.vtables[2].base_class == "Howdy" + assert all(v.name == "HiThere" for v in parser.vtables) diff --git a/tools/isledecomp/tests/test_parser_util.py b/tools/isledecomp/tests/test_parser_util.py index 8a403710..9936c5bc 100644 --- a/tools/isledecomp/tests/test_parser_util.py +++ b/tools/isledecomp/tests/test_parser_util.py @@ -65,6 +65,14 @@ def test_is_blank_or_comment(line: str, expected: bool): # TODO: These match but shouldn't. # (False, False, '// FUNCTION: LEGO1 0'), # (False, False, '// FUNCTION: LEGO1 0x'), + # Extra field + (True, True, "// VTABLE: HELLO 0x1234 Extra"), + # Extra with spaces + (True, True, "// VTABLE: HELLO 0x1234 Whatever"), + # Extra, no space (if the first non-hex character is not in [a-f]) + (True, False, "// VTABLE: HELLO 0x1234Hello"), + # Extra, many spaces + (True, False, "// VTABLE: HELLO 0x1234 Hello"), ] @@ -174,3 +182,27 @@ def test_get_variable_name(line: str, name: str): @pytest.mark.parametrize("line, string", string_match_cases) def test_get_string_contents(line: str, string: str): assert get_string_contents(line) == string + + +def test_marker_extra_spaces(): + """The extra field can contain spaces""" + marker = match_marker("// VTABLE: TEST 0x1234 S p a c e s") + assert marker.extra == "S p a c e s" + + # Trailing spaces removed + marker = match_marker("// VTABLE: TEST 0x8888 spaces ") + assert marker.extra == "spaces" + + # Trailing newline removed if present + marker = match_marker("// VTABLE: TEST 0x5555 newline\n") + assert marker.extra == "newline" + + +def test_marker_trailing_spaces(): + """Should ignore trailing spaces. (Invalid extra field) + Offset field not truncated, extra field set to None.""" + + marker = match_marker("// VTABLE: TEST 0x1234 ") + assert marker is not None + assert marker.offset == 0x1234 + assert marker.extra is None diff --git a/tools/requirements.txt b/tools/requirements.txt index 5875ae35..b23ca4fb 100644 --- a/tools/requirements.txt +++ b/tools/requirements.txt @@ -5,3 +5,4 @@ colorama isledecomp pystache pyyaml +git+https://github.com/wbenny/pydemangler.git diff --git a/tools/vtable/vtable.py b/tools/vtable/vtable.py index eb09412c..c5dc27a7 100755 --- a/tools/vtable/vtable.py +++ b/tools/vtable/vtable.py @@ -89,6 +89,21 @@ def main(): print_summary(vtable_count, problem_count) + # Now compare adjuster thunk functions, if there are any. + # These matches are generated by the compare engine. + # They should always match 100%. If not, there is a problem + # with the inheritance or an overriden function. + for fun_match in engine.get_functions(): + if "`vtordisp" not in fun_match.name: + continue + + diff = engine.compare_address(fun_match.orig_addr) + if diff.ratio < 1.0: + problem_count += 1 + print( + f"Problem with adjuster thunk {fun_match.name} (0x{fun_match.orig_addr:x} / 0x{fun_match.recomp_addr:x})" + ) + return 1 if problem_count > 0 else 0