Match vtables with virtual inheritance (#717)

* Match vtables with virtual inheritance

* Simplify vtable name check

* Thunk alert
This commit is contained in:
MS 2024-03-23 18:01:40 -04:00 committed by GitHub
parent b279e8b8b9
commit 3f03940fcb
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 350 additions and 24 deletions

View file

@ -86,6 +86,7 @@ def __init__(
self._find_original_strings() self._find_original_strings()
self._match_thunks() self._match_thunks()
self._match_exports() self._match_exports()
self._find_vtordisp()
def _load_cvdump(self): def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file) logger.info("Parsing %s ...", self.pdb_file)
@ -198,7 +199,7 @@ def _load_markers(self):
self._db.match_variable(var.offset, var.name) self._db.match_variable(var.offset, var.name)
for tbl in codebase.iter_vtables(): for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name) self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class)
for string in codebase.iter_strings(): for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string # Not that we don't trust you, but we're checking the string
@ -285,10 +286,105 @@ def _match_exports(self):
): ):
logger.debug("Matched export %s", repr(export_name)) logger.debug("Matched export %s", repr(export_name))
def _find_vtordisp(self):
"""If there are any cases of virtual inheritance, we can read
through the vtables for those classes and find the vtable thunk
functions (vtordisp).
Our approach is this: walk both vtables and check where we have a
vtordisp in the recomp table. Inspect the function at that vtable
position (in both) and check whether we jump to the same function.
One potential pitfall here is that the virtual displacement could
differ between the thunks. We are not (yet) checking for this, so the
result is that the vtable will appear to match but we will have a diff
on the thunk in our regular function comparison.
We could do this differently and check only the original vtable,
construct the name of the vtordisp function and match based on that."""
for match in self._db.get_matches_by_type(SymbolType.VTABLE):
# We need some method of identifying vtables that
# might have thunks, and this ought to work okay.
if "{for" not in match.name:
continue
# TODO: We might want to fix this at the source (cvdump) instead.
# Any problem will be logged later when we compare the vtable.
vtable_size = 4 * (match.size // 4)
orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)
raw_addrs = zip(
[t for (t,) in struct.iter_unpack("<L", orig_table)],
[t for (t,) in struct.iter_unpack("<L", recomp_table)],
)
# Now walk both vtables looking for thunks.
for orig_addr, recomp_addr in raw_addrs:
if not self._db.is_vtordisp(recomp_addr):
continue
thunk_fn = self.get_by_recomp(recomp_addr)
# Read the function bytes here.
# In practice, the adjuster thunk will be under 16 bytes.
# If we have thunks of unequal size, we can still tell whether
# they are thunking the same function by grabbing the
# JMP instruction at the end.
thunk_presumed_size = max(thunk_fn.size, 16)
# Strip off MSVC padding 0xcc bytes.
# This should be safe to do; it is highly unlikely that
# the MSB of the jump displacement would be 0xcc. (huge jump)
orig_thunk_bin = self.orig_bin.read(
orig_addr, thunk_presumed_size
).rstrip(b"\xcc")
recomp_thunk_bin = self.recomp_bin.read(
recomp_addr, thunk_presumed_size
).rstrip(b"\xcc")
# Read jump opcode and displacement (last 5 bytes)
(orig_jmp, orig_disp) = struct.unpack("<Bi", orig_thunk_bin[-5:])
(recomp_jmp, recomp_disp) = struct.unpack("<Bi", recomp_thunk_bin[-5:])
# Make sure it's a JMP
if orig_jmp != 0xE9 or recomp_jmp != 0xE9:
continue
# Calculate jump destination from the end of the JMP instruction
# i.e. the end of the function
orig_actual = orig_addr + len(orig_thunk_bin) + orig_disp
recomp_actual = recomp_addr + len(recomp_thunk_bin) + recomp_disp
# If they are thunking the same function, then this must be a match.
if self.is_pointer_match(orig_actual, recomp_actual):
if len(orig_thunk_bin) != len(recomp_thunk_bin):
logger.warning(
"Adjuster thunk %s (0x%x) is not exact",
thunk_fn.name,
orig_addr,
)
self._db.set_function_pair(orig_addr, recomp_addr)
def _compare_function(self, match: MatchInfo) -> DiffReport: def _compare_function(self, match: MatchInfo) -> DiffReport:
orig_raw = self.orig_bin.read(match.orig_addr, match.size) orig_raw = self.orig_bin.read(match.orig_addr, match.size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
# It's unlikely that a function other than an adjuster thunk would
# start with a SUB instruction, so alert to a possible wrong
# annotation here.
# There's probably a better place to do this, but we're reading
# the function bytes here already.
try:
if orig_raw[0] == 0x2B and recomp_raw[0] != 0x2B:
logger.warning(
"Possible thunk at 0x%x (%s)", match.orig_addr, match.name
)
except IndexError:
pass
def orig_lookup(addr: int) -> Optional[str]: def orig_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_orig(addr) m = self._db.get_by_orig(addr)
if m is None: if m is None:
@ -432,7 +528,7 @@ def match_text(m: Optional[MatchInfo], raw_addr: Optional[int] = None) -> str:
match_type=SymbolType.VTABLE, match_type=SymbolType.VTABLE,
orig_addr=match.orig_addr, orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr, recomp_addr=match.recomp_addr,
name=f"{match.name}::`vftable'", name=match.name,
udiff=unified_diff, udiff=unified_diff,
ratio=ratio, ratio=ratio,
) )

View file

@ -4,6 +4,7 @@
import logging import logging
from typing import List, Optional from typing import List, Optional
from isledecomp.types import SymbolType from isledecomp.types import SymbolType
from isledecomp.cvdump.demangler import get_vtordisp_name
_SETUP_SQL = """ _SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`; DROP TABLE IF EXISTS `symbols`;
@ -249,6 +250,37 @@ def get_match_options(self, addr: int) -> Optional[dict]:
for (option, value) in cur.fetchall() for (option, value) in cur.fetchall()
} }
def is_vtordisp(self, recomp_addr: int) -> bool:
"""Check whether this function is a vtordisp based on its
decorated name. If its demangled name is missing the vtordisp
indicator, correct that."""
row = self._db.execute(
"""SELECT name, decorated_name
FROM `symbols`
WHERE recomp_addr = ?""",
(recomp_addr,),
).fetchone()
if row is None:
return False
(name, decorated_name) = row
if "`vtordisp" in name:
return True
new_name = get_vtordisp_name(decorated_name)
if new_name is None:
return False
self._db.execute(
"""UPDATE `symbols`
SET name = ?
WHERE recomp_addr = ?""",
(new_name, recomp_addr),
)
return True
def _find_potential_match( def _find_potential_match(
self, name: str, compare_type: SymbolType self, name: str, compare_type: SymbolType
) -> Optional[int]: ) -> Optional[int]:
@ -323,12 +355,34 @@ def match_function(self, addr: int, name: str) -> bool:
return did_match return did_match
def match_vtable(self, addr: int, name: str) -> bool: def match_vtable(
did_match = self._match_on(SymbolType.VTABLE, addr, name) self, addr: int, name: str, base_class: Optional[str] = None
if not did_match: ) -> bool:
logger.error("Failed to find vtable for class: %s", name) # Only allow a match against "Class:`vftable'"
# if this is the derived class.
name = (
f"{name}::`vftable'"
if base_class is None or base_class == name
else f"{name}::`vftable'{{for `{base_class}'}}"
)
return did_match row = self._db.execute(
"""
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND name = ?
AND (compare_type = ?)
LIMIT 1
""",
(name, SymbolType.VTABLE.value),
).fetchone()
if row is not None and self.set_pair(addr, row[0], SymbolType.VTABLE):
return True
logger.error("Failed to find vtable for class: %s", name)
return False
def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool: def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool:
"""Matching a static function variable by combining the variable name """Matching a static function variable by combining the variable name

View file

@ -5,6 +5,7 @@
import re import re
from collections import namedtuple from collections import namedtuple
from typing import Optional from typing import Optional
import pydemangler
class InvalidEncodedNumberError(Exception): class InvalidEncodedNumberError(Exception):
@ -51,8 +52,52 @@ def demangle_string_const(symbol: str) -> Optional[StringConstInfo]:
return StringConstInfo(len=strlen, is_utf16=is_utf16) return StringConstInfo(len=strlen, is_utf16=is_utf16)
def get_vtordisp_name(symbol: str) -> Optional[str]:
# pylint: disable=c-extension-no-member
"""For adjuster thunk functions, the PDB will sometimes use a name
that contains "vtordisp" but often will just reuse the name of the
function being thunked. We want to use the vtordisp name if possible."""
name = pydemangler.demangle(symbol)
if name is None:
return None
if "`vtordisp" not in name:
return None
# Now we remove the parts of the friendly name that we don't need
try:
# Assuming this is the last of the function prefixes
thiscall_idx = name.index("__thiscall")
# To match the end of the `vtordisp{x,y}' string
end_idx = name.index("}'")
return name[thiscall_idx + 11 : end_idx + 2]
except ValueError:
return name
def demangle_vtable(symbol: str) -> str: def demangle_vtable(symbol: str) -> str:
# pylint: disable=c-extension-no-member
"""Get the class name referenced in the vtable symbol.""" """Get the class name referenced in the vtable symbol."""
raw = pydemangler.demangle(symbol)
if raw is None:
pass # TODO: This shouldn't happen if MSVC behaves
# Remove storage class and other stuff we don't care about
return (
raw.replace("<class ", "<")
.replace("<struct ", "<")
.replace("const ", "")
.replace("volatile ", "")
)
def demangle_vtable_ourselves(symbol: str) -> str:
"""Parked implementation of MSVC symbol demangling.
We only use this for vtables and it works okay with the simple cases or
templates that refer to other classes/structs. Some namespace support.
Does not support backrefs, primitive types, or vtables with
virtual inheritance."""
# Seek ahead 4 chars to strip off "??_7" prefix # Seek ahead 4 chars to strip off "??_7" prefix
t = symbol[4:].split("@") t = symbol[4:].split("@")
@ -66,11 +111,11 @@ def demangle_vtable(symbol: str) -> str:
else: else:
generic = t[1][1:] generic = t[1][1:]
return f"{class_name}<{generic}>" return f"{class_name}<{generic}>::`vftable'"
# If we have two classes listed, it is a namespace hierarchy. # If we have two classes listed, it is a namespace hierarchy.
# @@6B@ is a common generic suffix for these vtable symbols. # @@6B@ is a common generic suffix for these vtable symbols.
if t[1] != "" and t[1] != "6B": if t[1] != "" and t[1] != "6B":
return t[1] + "::" + t[0] return t[1] + "::" + t[0] + "::`vftable'"
return t[0] return t[0] + "::`vftable'"

View file

@ -1,5 +1,5 @@
import re import re
from typing import Optional from typing import Optional, Tuple
from enum import Enum from enum import Enum
@ -29,18 +29,20 @@ class MarkerType(Enum):
markerRegex = re.compile( markerRegex = re.compile(
r"\s*//\s*(?P<type>\w+):\s*(?P<module>\w+)\s+(?P<offset>0x[a-f0-9]+)", r"\s*//\s*(?P<type>\w+):\s*(?P<module>\w+)\s+(?P<offset>0x[a-f0-9]+) *(?P<extra>\S.+\S)?",
flags=re.I, flags=re.I,
) )
markerExactRegex = re.compile( markerExactRegex = re.compile(
r"\s*// (?P<type>[A-Z]+): (?P<module>[A-Z0-9]+) (?P<offset>0x[a-f0-9]+)$" r"\s*// (?P<type>[A-Z]+): (?P<module>[A-Z0-9]+) (?P<offset>0x[a-f0-9]+)(?: (?P<extra>\S.+\S))?\n?$"
) )
class DecompMarker: class DecompMarker:
def __init__(self, marker_type: str, module: str, offset: int) -> None: def __init__(
self, marker_type: str, module: str, offset: int, extra: Optional[str] = None
) -> None:
try: try:
self._type = MarkerType[marker_type.upper()] self._type = MarkerType[marker_type.upper()]
except KeyError: except KeyError:
@ -51,6 +53,7 @@ def __init__(self, marker_type: str, module: str, offset: int) -> None:
# we will emit a syntax error. # we will emit a syntax error.
self._module: str = module.upper() self._module: str = module.upper()
self._offset: int = offset self._offset: int = offset
self._extra: Optional[str] = extra
@property @property
def type(self) -> MarkerType: def type(self) -> MarkerType:
@ -64,6 +67,10 @@ def module(self) -> str:
def offset(self) -> int: def offset(self) -> int:
return self._offset return self._offset
@property
def extra(self) -> Optional[str]:
return self._extra
@property @property
def category(self) -> MarkerCategory: def category(self) -> MarkerCategory:
if self.is_vtable(): if self.is_vtable():
@ -81,6 +88,11 @@ def category(self) -> MarkerCategory:
return MarkerCategory.ADDRESS return MarkerCategory.ADDRESS
@property
def key(self) -> Tuple[str, str, Optional[str]]:
"""For use with the MarkerDict. To detect/avoid marker collision."""
return (self.category, self.module, self.extra)
def is_regular_function(self) -> bool: def is_regular_function(self) -> bool:
"""Regular function, meaning: not an explicit byname lookup. FUNCTION """Regular function, meaning: not an explicit byname lookup. FUNCTION
markers can be _implicit_ byname. markers can be _implicit_ byname.
@ -126,6 +138,7 @@ def match_marker(line: str) -> Optional[DecompMarker]:
marker_type=match.group("type"), marker_type=match.group("type"),
module=match.group("module"), module=match.group("module"),
offset=int(match.group("offset"), 16), offset=int(match.group("offset"), 16),
extra=match.group("extra"),
) )

View file

@ -55,7 +55,7 @@ class ParserVariable(ParserSymbol):
@dataclass @dataclass
class ParserVtable(ParserSymbol): class ParserVtable(ParserSymbol):
pass base_class: Optional[str] = None
@dataclass @dataclass

View file

@ -47,15 +47,16 @@ def __init__(self) -> None:
def insert(self, marker: DecompMarker) -> bool: def insert(self, marker: DecompMarker) -> bool:
"""Return True if this insert would overwrite""" """Return True if this insert would overwrite"""
key = (marker.category, marker.module) if marker.key in self.markers:
if key in self.markers:
return True return True
self.markers[key] = marker self.markers[marker.key] = marker
return False return False
def query(self, category: MarkerCategory, module: str) -> Optional[DecompMarker]: def query(
return self.markers.get((category, module)) self, category: MarkerCategory, module: str, extra: Optional[str] = None
) -> Optional[DecompMarker]:
return self.markers.get((category, module, extra))
def iter(self) -> Iterator[DecompMarker]: def iter(self) -> Iterator[DecompMarker]:
for _, marker in self.markers.items(): for _, marker in self.markers.items():
@ -275,6 +276,7 @@ def _vtable_done(self, class_name: str = None):
module=marker.module, module=marker.module,
offset=marker.offset, offset=marker.offset,
name=self.curly.get_prefix(class_name), name=self.curly.get_prefix(class_name),
base_class=marker.extra,
) )
) )

View file

@ -4,6 +4,7 @@
demangle_vtable, demangle_vtable,
parse_encoded_number, parse_encoded_number,
InvalidEncodedNumberError, InvalidEncodedNumberError,
get_vtordisp_name,
) )
string_demangle_cases = [ string_demangle_cases = [
@ -46,13 +47,37 @@ def test_invalid_encoded_number():
vtable_cases = [ vtable_cases = [
("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter"), ("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter::`vftable'"),
("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection<LegoWorld *>"), ("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection<LegoWorld *>::`vftable'"),
("??_7?$MxPtrList@VLegoPathController@@@@6B@", "MxPtrList<LegoPathController>"), (
("??_7Renderer@Tgl@@6B@", "Tgl::Renderer"), "??_7?$MxPtrList@VLegoPathController@@@@6B@",
"MxPtrList<LegoPathController>::`vftable'",
),
("??_7Renderer@Tgl@@6B@", "Tgl::Renderer::`vftable'"),
("??_7LegoExtraActor@@6B0@@", "LegoExtraActor::`vftable'{for `LegoExtraActor'}"),
(
"??_7LegoExtraActor@@6BLegoAnimActor@@@",
"LegoExtraActor::`vftable'{for `LegoAnimActor'}",
),
(
"??_7LegoAnimActor@@6B?$LegoContainer@PAM@@@",
"LegoAnimActor::`vftable'{for `LegoContainer<float *>'}",
),
] ]
@pytest.mark.parametrize("symbol, class_name", vtable_cases) @pytest.mark.parametrize("symbol, class_name", vtable_cases)
def test_vtable(symbol, class_name): def test_vtable(symbol, class_name):
assert demangle_vtable(symbol) == class_name assert demangle_vtable(symbol) == class_name
def test_vtordisp():
"""Make sure we can accurately detect an adjuster thunk symbol"""
assert get_vtordisp_name("") is None
assert get_vtordisp_name("?ClassName@LegoExtraActor@@UBEPBDXZ") is None
assert (
get_vtordisp_name("?ClassName@LegoExtraActor@@$4PPPPPPPM@A@BEPBDXZ") is not None
)
# A function called vtordisp
assert get_vtordisp_name("?vtordisp@LegoExtraActor@@UBEPBDXZ") is None

View file

@ -711,3 +711,46 @@ def test_header_function_declaration(parser):
assert len(parser.alerts) == 1 assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.NO_IMPLEMENTATION assert parser.alerts[0].code == ParserError.NO_IMPLEMENTATION
def test_extra(parser):
"""Allow a fourth field in the decomp annotation. Its use will vary
depending on the marker type. Currently this is only used to identify
a vtable with virtual inheritance."""
# Intentionally using non-vtable markers here.
# We might want to emit a parser warning for unnecessary extra info.
parser.read_lines(
[
"// GLOBAL: TEST 0x5555 Haha",
"int g_variable = 0;",
"// FUNCTION: TEST 0x1234 Something",
"void Test() { g_variable++; }",
"// LIBRARY: TEST 0x8080 Printf",
"// _printf",
]
)
# We don't use this information (yet) but this is all fine.
assert len(parser.alerts) == 0
def test_virtual_inheritance(parser):
"""Indicate the base class for a vtable where the class uses
virtual inheritance."""
parser.read_lines(
[
"// VTABLE: HELLO 0x1234",
"// VTABLE: HELLO 0x1238 Greetings",
"// VTABLE: HELLO 0x123c Howdy",
"class HiThere : public virtual Greetings {",
"};",
]
)
assert len(parser.alerts) == 0
assert len(parser.vtables) == 3
assert parser.vtables[0].base_class is None
assert parser.vtables[1].base_class == "Greetings"
assert parser.vtables[2].base_class == "Howdy"
assert all(v.name == "HiThere" for v in parser.vtables)

View file

@ -65,6 +65,14 @@ def test_is_blank_or_comment(line: str, expected: bool):
# TODO: These match but shouldn't. # TODO: These match but shouldn't.
# (False, False, '// FUNCTION: LEGO1 0'), # (False, False, '// FUNCTION: LEGO1 0'),
# (False, False, '// FUNCTION: LEGO1 0x'), # (False, False, '// FUNCTION: LEGO1 0x'),
# Extra field
(True, True, "// VTABLE: HELLO 0x1234 Extra"),
# Extra with spaces
(True, True, "// VTABLE: HELLO 0x1234 Whatever<SubClass *>"),
# Extra, no space (if the first non-hex character is not in [a-f])
(True, False, "// VTABLE: HELLO 0x1234Hello"),
# Extra, many spaces
(True, False, "// VTABLE: HELLO 0x1234 Hello"),
] ]
@ -174,3 +182,27 @@ def test_get_variable_name(line: str, name: str):
@pytest.mark.parametrize("line, string", string_match_cases) @pytest.mark.parametrize("line, string", string_match_cases)
def test_get_string_contents(line: str, string: str): def test_get_string_contents(line: str, string: str):
assert get_string_contents(line) == string assert get_string_contents(line) == string
def test_marker_extra_spaces():
"""The extra field can contain spaces"""
marker = match_marker("// VTABLE: TEST 0x1234 S p a c e s")
assert marker.extra == "S p a c e s"
# Trailing spaces removed
marker = match_marker("// VTABLE: TEST 0x8888 spaces ")
assert marker.extra == "spaces"
# Trailing newline removed if present
marker = match_marker("// VTABLE: TEST 0x5555 newline\n")
assert marker.extra == "newline"
def test_marker_trailing_spaces():
"""Should ignore trailing spaces. (Invalid extra field)
Offset field not truncated, extra field set to None."""
marker = match_marker("// VTABLE: TEST 0x1234 ")
assert marker is not None
assert marker.offset == 0x1234
assert marker.extra is None

View file

@ -5,3 +5,4 @@ colorama
isledecomp isledecomp
pystache pystache
pyyaml pyyaml
git+https://github.com/wbenny/pydemangler.git

View file

@ -89,6 +89,21 @@ def main():
print_summary(vtable_count, problem_count) print_summary(vtable_count, problem_count)
# Now compare adjuster thunk functions, if there are any.
# These matches are generated by the compare engine.
# They should always match 100%. If not, there is a problem
# with the inheritance or an overriden function.
for fun_match in engine.get_functions():
if "`vtordisp" not in fun_match.name:
continue
diff = engine.compare_address(fun_match.orig_addr)
if diff.ratio < 1.0:
problem_count += 1
print(
f"Problem with adjuster thunk {fun_match.name} (0x{fun_match.orig_addr:x} / 0x{fun_match.recomp_addr:x})"
)
return 1 if problem_count > 0 else 0 return 1 if problem_count > 0 else 0