Pointer substitution can use offset into variable (#841)

This commit is contained in:
MS 2024-04-23 17:06:43 -04:00 committed by GitHub
parent 9025d5ed06
commit 41be78ed1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 91 additions and 19 deletions

View file

@ -56,7 +56,7 @@ class ParseAsm:
def __init__( def __init__(
self, self,
relocate_lookup: Optional[Callable[[int], bool]] = None, relocate_lookup: Optional[Callable[[int], bool]] = None,
name_lookup: Optional[Callable[[int], str]] = None, name_lookup: Optional[Callable[[int, bool], str]] = None,
bin_lookup: Optional[Callable[[int, int], Optional[bytes]]] = None, bin_lookup: Optional[Callable[[int, int], Optional[bytes]]] = None,
) -> None: ) -> None:
self.relocate_lookup = relocate_lookup self.relocate_lookup = relocate_lookup
@ -86,13 +86,15 @@ def float_replace(self, addr: int, data_size: int) -> Optional[str]:
return None return None
def lookup(self, addr: int, use_cache: bool = True) -> Optional[str]: def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
"""Return a replacement name for this address if we find one.""" """Return a replacement name for this address if we find one."""
if use_cache and (cached := self.replacements.get(addr, None)) is not None: if use_cache and (cached := self.replacements.get(addr, None)) is not None:
return cached return cached
if callable(self.name_lookup): if callable(self.name_lookup):
if (name := self.name_lookup(addr)) is not None: if (name := self.name_lookup(addr, exact)) is not None:
if use_cache: if use_cache:
self.replacements[addr] = name self.replacements[addr] = name
@ -210,7 +212,7 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# If we have a name for this address, use it. If not, # If we have a name for this address, use it. If not,
# do not create a new placeholder. We will instead # do not create a new placeholder. We will instead
# fall through to generic jump handling below. # fall through to generic jump handling below.
potential_name = self.lookup(op_str_address) potential_name = self.lookup(op_str_address, exact=True)
if potential_name is not None: if potential_name is not None:
return (inst.mnemonic, potential_name) return (inst.mnemonic, potential_name)

View file

@ -2,6 +2,7 @@
import logging import logging
import difflib import difflib
import struct import struct
import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional from typing import Callable, Iterable, List, Optional
from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError
@ -71,6 +72,9 @@ def __init__(
self.recomp_bin = recomp_bin self.recomp_bin = recomp_bin
self.pdb_file = pdb_file self.pdb_file = pdb_file
self.code_dir = code_dir self.code_dir = code_dir
# Controls whether we dump the asm output to a file
self.debug: bool = False
self.runid: str = uuid.uuid4().hex[:8]
self._lines_db = LinesDb(code_dir) self._lines_db = LinesDb(code_dir)
self._db = CompareDb() self._db = CompareDb()
@ -452,6 +456,16 @@ def _find_vtordisp(self):
) )
self._db.set_function_pair(orig_addr, recomp_addr) self._db.set_function_pair(orig_addr, recomp_addr)
def _dump_asm(self, orig_combined, recomp_combined):
"""Append the provided assembly output to the debug files"""
with open(f"orig-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in orig_combined:
f.write(f"{addr}: {line}\n")
with open(f"recomp-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in recomp_combined:
f.write(f"{addr}: {line}\n")
def _compare_function(self, match: MatchInfo) -> DiffReport: def _compare_function(self, match: MatchInfo) -> DiffReport:
# Detect when the recomp function size would cause us to read # Detect when the recomp function size would cause us to read
# enough bytes from the original function that we cross into # enough bytes from the original function that we cross into
@ -478,19 +492,33 @@ def _compare_function(self, match: MatchInfo) -> DiffReport:
except IndexError: except IndexError:
pass pass
def orig_lookup(addr: int) -> Optional[str]: def orig_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_orig(addr) m = self._db.get_by_orig(addr, exact)
if m is None: if m is None:
return None return None
return m.match_name() if m.orig_addr == addr:
return m.match_name()
def recomp_lookup(addr: int) -> Optional[str]: offset = addr - m.orig_addr
m = self._db.get_by_recomp(addr) if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
def recomp_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_recomp(addr, exact)
if m is None: if m is None:
return None return None
return m.match_name() if m.recomp_addr == addr:
return m.match_name()
offset = addr - m.recomp_addr
if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
orig_should_replace = create_reloc_lookup(self.orig_bin) orig_should_replace = create_reloc_lookup(self.orig_bin)
recomp_should_replace = create_reloc_lookup(self.recomp_bin) recomp_should_replace = create_reloc_lookup(self.recomp_bin)
@ -512,6 +540,9 @@ def recomp_lookup(addr: int) -> Optional[str]:
orig_combined = orig_parse.parse_asm(orig_raw, match.orig_addr) orig_combined = orig_parse.parse_asm(orig_raw, match.orig_addr)
recomp_combined = recomp_parse.parse_asm(recomp_raw, match.recomp_addr) recomp_combined = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)
if self.debug:
self._dump_asm(orig_combined, recomp_combined)
# Detach addresses from asm lines for the text diff. # Detach addresses from asm lines for the text diff.
orig_asm = [x[1] for x in orig_combined] orig_asm = [x[1] for x in orig_combined]
recomp_asm = [x[1] for x in recomp_combined] recomp_asm = [x[1] for x in recomp_combined]

View file

@ -53,7 +53,7 @@ def __init__(
self.name = name self.name = name
self.size = size self.size = size
def match_name(self) -> str: def match_name(self) -> Optional[str]:
"""Combination of the name and compare type. """Combination of the name and compare type.
Intended for name substitution in the diff. If there is a diff, Intended for name substitution in the diff. If there is a diff,
it will be more obvious what this symbol indicates.""" it will be more obvious what this symbol indicates."""
@ -64,6 +64,12 @@ def match_name(self) -> str:
name = repr(self.name) if ctype == "STRING" else self.name name = repr(self.name) if ctype == "STRING" else self.name
return f"{name} ({ctype})" return f"{name} ({ctype})"
def offset_name(self, ofs: int) -> Optional[str]:
if self.name is None:
return None
return f"{self.name}+{ofs} (OFFSET)"
def matchinfo_factory(_, row): def matchinfo_factory(_, row):
return MatchInfo(*row) return MatchInfo(*row)
@ -135,7 +141,32 @@ def get_one_match(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory cur.row_factory = matchinfo_factory
return cur.fetchone() return cur.fetchone()
def get_by_orig(self, addr: int) -> Optional[MatchInfo]: def _get_closest_orig(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(orig_addr) FROM `symbols`
WHERE ? >= orig_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def _get_closest_recomp(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(recomp_addr) FROM `symbols`
WHERE ? >= recomp_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def get_by_orig(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._orig_used(addr):
addr = self._get_closest_orig(addr)
if addr is None:
return None
cur = self._db.execute( cur = self._db.execute(
"""SELECT * FROM `match_info` """SELECT * FROM `match_info`
WHERE orig_addr = ? WHERE orig_addr = ?
@ -145,7 +176,12 @@ def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
cur.row_factory = matchinfo_factory cur.row_factory = matchinfo_factory
return cur.fetchone() return cur.fetchone()
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]: def get_by_recomp(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._recomp_used(addr):
addr = self._get_closest_recomp(addr)
if addr is None:
return None
cur = self._db.execute( cur = self._db.execute(
"""SELECT * FROM `match_info` """SELECT * FROM `match_info`
WHERE recomp_addr = ? WHERE recomp_addr = ?

View file

@ -113,7 +113,7 @@ def relocate_lookup(addr: int) -> bool:
def test_name_replace(start, end): def test_name_replace(start, end):
"""Make sure the name lookup function is called if present""" """Make sure the name lookup function is called if present"""
def substitute(_: int) -> str: def substitute(_: int, __: bool) -> str:
return "_substitute_" return "_substitute_"
p = ParseAsm(name_lookup=substitute) p = ParseAsm(name_lookup=substitute)
@ -137,7 +137,7 @@ def test_replacement_numbering():
"""If we can use the name lookup for the first address but not the second, """If we can use the name lookup for the first address but not the second,
the second replacement should be <OFFSET2> not <OFFSET1>.""" the second replacement should be <OFFSET2> not <OFFSET1>."""
def substitute_1234(addr: int) -> Optional[str]: def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234) p = ParseAsm(name_lookup=substitute_1234)
@ -171,7 +171,7 @@ def test_jump_to_function():
assume this is the case for all jumps. Only replace the jump with a name assume this is the case for all jumps. Only replace the jump with a name
if we can find it using our lookup.""" if we can find it using our lookup."""
def substitute_1234(addr: int) -> Optional[str]: def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234) p = ParseAsm(name_lookup=substitute_1234)
@ -212,7 +212,7 @@ def test_float_variable():
"""If there is a variable at the address referenced by a float instruction, """If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler.""" use the name instead of calling into the float replacement handler."""
def name_lookup(addr: int) -> Optional[str]: def name_lookup(addr: int, _: bool) -> Optional[str]:
return "g_myFloatVariable" if addr == 0x1234 else None return "g_myFloatVariable" if addr == 0x1234 else None
p = ParseAsm(name_lookup=name_lookup) p = ParseAsm(name_lookup=name_lookup)
@ -234,7 +234,7 @@ def relocate_lookup(addr: int) -> bool:
return addr in (0x1234, 0x5555) return addr in (0x1234, 0x5555)
# Only 0x5555 is a "known" address # Only 0x5555 is a "known" address
def name_lookup(addr: int) -> Optional[str]: def name_lookup(addr: int, _: bool) -> Optional[str]:
return "hello" if addr == 0x5555 else None return "hello" if addr == 0x5555 else None
p = ParseAsm(relocate_lookup=relocate_lookup, name_lookup=name_lookup) p = ParseAsm(relocate_lookup=relocate_lookup, name_lookup=name_lookup)
@ -263,7 +263,7 @@ def test_absolute_indirect():
we have it, but there are some circumstances where we want to replace we have it, but there are some circumstances where we want to replace
with the pointer's name (i.e. an import function).""" with the pointer's name (i.e. an import function)."""
def name_lookup(addr: int) -> Optional[str]: def name_lookup(addr: int, _: bool) -> Optional[str]:
return { return {
0x1234: "Hello", 0x1234: "Hello",
0x4321: "xyz", 0x4321: "xyz",

View file

@ -241,6 +241,9 @@ def main():
isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir) isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
if args.loglevel == logging.DEBUG:
isle_compare.debug = True
print() print()
### Compare one or none. ### Compare one or none.