Pointer substitution can use offset into variable (#841)

This commit is contained in:
MS 2024-04-23 17:06:43 -04:00 committed by GitHub
parent 9025d5ed06
commit 41be78ed1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 91 additions and 19 deletions

View file

@ -56,7 +56,7 @@ class ParseAsm:
def __init__(
self,
relocate_lookup: Optional[Callable[[int], bool]] = None,
name_lookup: Optional[Callable[[int], str]] = None,
name_lookup: Optional[Callable[[int, bool], str]] = None,
bin_lookup: Optional[Callable[[int, int], Optional[bytes]]] = None,
) -> None:
self.relocate_lookup = relocate_lookup
@ -86,13 +86,15 @@ def float_replace(self, addr: int, data_size: int) -> Optional[str]:
return None
def lookup(self, addr: int, use_cache: bool = True) -> Optional[str]:
def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
"""Return a replacement name for this address if we find one."""
if use_cache and (cached := self.replacements.get(addr, None)) is not None:
return cached
if callable(self.name_lookup):
if (name := self.name_lookup(addr)) is not None:
if (name := self.name_lookup(addr, exact)) is not None:
if use_cache:
self.replacements[addr] = name
@ -210,7 +212,7 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# If we have a name for this address, use it. If not,
# do not create a new placeholder. We will instead
# fall through to generic jump handling below.
potential_name = self.lookup(op_str_address)
potential_name = self.lookup(op_str_address, exact=True)
if potential_name is not None:
return (inst.mnemonic, potential_name)

View file

@ -2,6 +2,7 @@
import logging
import difflib
import struct
import uuid
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional
from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError
@ -71,6 +72,9 @@ def __init__(
self.recomp_bin = recomp_bin
self.pdb_file = pdb_file
self.code_dir = code_dir
# Controls whether we dump the asm output to a file
self.debug: bool = False
self.runid: str = uuid.uuid4().hex[:8]
self._lines_db = LinesDb(code_dir)
self._db = CompareDb()
@ -452,6 +456,16 @@ def _find_vtordisp(self):
)
self._db.set_function_pair(orig_addr, recomp_addr)
def _dump_asm(self, orig_combined, recomp_combined):
"""Append the provided assembly output to the debug files"""
with open(f"orig-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in orig_combined:
f.write(f"{addr}: {line}\n")
with open(f"recomp-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in recomp_combined:
f.write(f"{addr}: {line}\n")
def _compare_function(self, match: MatchInfo) -> DiffReport:
# Detect when the recomp function size would cause us to read
# enough bytes from the original function that we cross into
@ -478,20 +492,34 @@ def _compare_function(self, match: MatchInfo) -> DiffReport:
except IndexError:
pass
def orig_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_orig(addr)
def orig_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_orig(addr, exact)
if m is None:
return None
if m.orig_addr == addr:
return m.match_name()
def recomp_lookup(addr: int) -> Optional[str]:
m = self._db.get_by_recomp(addr)
offset = addr - m.orig_addr
if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
def recomp_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_recomp(addr, exact)
if m is None:
return None
if m.recomp_addr == addr:
return m.match_name()
offset = addr - m.recomp_addr
if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
orig_should_replace = create_reloc_lookup(self.orig_bin)
recomp_should_replace = create_reloc_lookup(self.recomp_bin)
@ -512,6 +540,9 @@ def recomp_lookup(addr: int) -> Optional[str]:
orig_combined = orig_parse.parse_asm(orig_raw, match.orig_addr)
recomp_combined = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)
if self.debug:
self._dump_asm(orig_combined, recomp_combined)
# Detach addresses from asm lines for the text diff.
orig_asm = [x[1] for x in orig_combined]
recomp_asm = [x[1] for x in recomp_combined]

View file

@ -53,7 +53,7 @@ def __init__(
self.name = name
self.size = size
def match_name(self) -> str:
def match_name(self) -> Optional[str]:
"""Combination of the name and compare type.
Intended for name substitution in the diff. If there is a diff,
it will be more obvious what this symbol indicates."""
@ -64,6 +64,12 @@ def match_name(self) -> str:
name = repr(self.name) if ctype == "STRING" else self.name
return f"{name} ({ctype})"
def offset_name(self, ofs: int) -> Optional[str]:
if self.name is None:
return None
return f"{self.name}+{ofs} (OFFSET)"
def matchinfo_factory(_, row):
return MatchInfo(*row)
@ -135,7 +141,32 @@ def get_one_match(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
def _get_closest_orig(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(orig_addr) FROM `symbols`
WHERE ? >= orig_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def _get_closest_recomp(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(recomp_addr) FROM `symbols`
WHERE ? >= recomp_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def get_by_orig(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._orig_used(addr):
addr = self._get_closest_orig(addr)
if addr is None:
return None
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE orig_addr = ?
@ -145,7 +176,12 @@ def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
def get_by_recomp(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._recomp_used(addr):
addr = self._get_closest_recomp(addr)
if addr is None:
return None
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE recomp_addr = ?

View file

@ -113,7 +113,7 @@ def relocate_lookup(addr: int) -> bool:
def test_name_replace(start, end):
"""Make sure the name lookup function is called if present"""
def substitute(_: int) -> str:
def substitute(_: int, __: bool) -> str:
return "_substitute_"
p = ParseAsm(name_lookup=substitute)
@ -137,7 +137,7 @@ def test_replacement_numbering():
"""If we can use the name lookup for the first address but not the second,
the second replacement should be <OFFSET2> not <OFFSET1>."""
def substitute_1234(addr: int) -> Optional[str]:
def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
@ -171,7 +171,7 @@ def test_jump_to_function():
assume this is the case for all jumps. Only replace the jump with a name
if we can find it using our lookup."""
def substitute_1234(addr: int) -> Optional[str]:
def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
@ -212,7 +212,7 @@ def test_float_variable():
"""If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler."""
def name_lookup(addr: int) -> Optional[str]:
def name_lookup(addr: int, _: bool) -> Optional[str]:
return "g_myFloatVariable" if addr == 0x1234 else None
p = ParseAsm(name_lookup=name_lookup)
@ -234,7 +234,7 @@ def relocate_lookup(addr: int) -> bool:
return addr in (0x1234, 0x5555)
# Only 0x5555 is a "known" address
def name_lookup(addr: int) -> Optional[str]:
def name_lookup(addr: int, _: bool) -> Optional[str]:
return "hello" if addr == 0x5555 else None
p = ParseAsm(relocate_lookup=relocate_lookup, name_lookup=name_lookup)
@ -263,7 +263,7 @@ def test_absolute_indirect():
we have it, but there are some circumstances where we want to replace
with the pointer's name (i.e. an import function)."""
def name_lookup(addr: int) -> Optional[str]:
def name_lookup(addr: int, _: bool) -> Optional[str]:
return {
0x1234: "Hello",
0x4321: "xyz",

View file

@ -241,6 +241,9 @@ def main():
isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
if args.loglevel == logging.DEBUG:
isle_compare.debug = True
print()
### Compare one or none.