Detect when we exceed original function size (#736)

This commit is contained in:
MS 2024-03-26 21:01:37 -04:00 committed by GitHub
parent 32bc6c4264
commit 064feab51a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 26 additions and 1 deletions

View file

@ -384,7 +384,16 @@ def _find_vtordisp(self):
self._db.set_function_pair(orig_addr, recomp_addr) self._db.set_function_pair(orig_addr, recomp_addr)
def _compare_function(self, match: MatchInfo) -> DiffReport: def _compare_function(self, match: MatchInfo) -> DiffReport:
orig_raw = self.orig_bin.read(match.orig_addr, match.size) # Detect when the recomp function size would cause us to read
# enough bytes from the original function that we cross into
# the next annotated function.
next_orig = self._db.get_next_orig_addr(match.orig_addr)
if next_orig is not None:
orig_size = min(next_orig - match.orig_addr, match.size)
else:
orig_size = match.size
orig_raw = self.orig_bin.read(match.orig_addr, orig_size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
# It's unlikely that a function other than an adjuster thunk would # It's unlikely that a function other than an adjuster thunk would

View file

@ -73,6 +73,7 @@ def matchinfo_factory(_, row):
class CompareDb: class CompareDb:
# pylint: disable=too-many-public-methods
def __init__(self): def __init__(self):
self._db = sqlite3.connect(":memory:") self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL) self._db.executescript(_SETUP_SQL)
@ -348,6 +349,21 @@ def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
return self.set_pair(addr, recomp_addr, compare_type) return self.set_pair(addr, recomp_addr, compare_type)
def get_next_orig_addr(self, addr: int) -> Optional[int]:
"""Return the original address (matched or not) that follows
the one given. If our recomp function size would cause us to read
too many bytes for the original function, we can adjust it."""
result = self._db.execute(
"""SELECT orig_addr
FROM `symbols`
WHERE orig_addr > ?
ORDER BY orig_addr
LIMIT 1""",
(addr,),
).fetchone()
return result[0] if result is not None else None
def match_function(self, addr: int, name: str) -> bool: def match_function(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.FUNCTION, addr, name) did_match = self._match_on(SymbolType.FUNCTION, addr, name)
if not did_match: if not did_match: