diff --git a/tools/isledecomp/isledecomp/bin.py b/tools/isledecomp/isledecomp/bin.py index b103a143..514f4a8e 100644 --- a/tools/isledecomp/isledecomp/bin.py +++ b/tools/isledecomp/isledecomp/bin.py @@ -112,6 +112,7 @@ def __init__(self, filename: str, find_str: bool = False) -> None: self._relocated_addrs = set() self.imports = [] self.thunks = [] + self.exports: List[Tuple[int, str]] = [] def __enter__(self): logger.debug("Bin %s Enter", self.filename) @@ -137,6 +138,11 @@ def __enter__(self): (entry,) = struct.unpack(" Section: section = next( filter(lambda section: section.match_name(name), self.sections), diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py index ca48cb63..3d2246e4 100644 --- a/tools/isledecomp/isledecomp/compare/core.py +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -85,6 +85,7 @@ def __init__( self._load_markers() self._find_original_strings() self._match_thunks() + self._match_exports() def _load_cvdump(self): logger.info("Parsing %s ...", self.pdb_file) @@ -166,12 +167,11 @@ def _load_cvdump(self): self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry) def _load_markers(self): - # Guess at module name from PDB file name - # reccmp checks the original binary filename; we could use this too - (module, _) = os.path.splitext(os.path.basename(self.pdb_file)) + # Assume module name is the base filename of the original binary. + (module, _) = os.path.splitext(os.path.basename(self.orig_bin.filename)) codefiles = list(walk_source_dir(self.code_dir)) - codebase = DecompCodebase(codefiles, module) + codebase = DecompCodebase(codefiles, module.upper()) # Match lineref functions first because this is a guaranteed match. # If we have two functions that share the same name, and one is @@ -274,6 +274,17 @@ def _match_thunks(self): # function in the first place. self._db.skip_compare(thunk_from_orig) + def _match_exports(self): + # invert for name lookup + orig_exports = {y: x for (x, y) in self.orig_bin.exports} + + for recomp_addr, export_name in self.recomp_bin.exports: + orig_addr = orig_exports.get(export_name) + if orig_addr is not None and self._db.set_pair_tentative( + orig_addr, recomp_addr + ): + logger.debug("Matched export %s", repr(export_name)) + def _compare_function(self, match: MatchInfo) -> DiffReport: orig_raw = self.orig_bin.read(match.orig_addr, match.size) recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size) diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py index 7e0546b8..54ae8081 100644 --- a/tools/isledecomp/isledecomp/compare/db.py +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -86,7 +86,7 @@ def set_recomp_symbol( ): # Ignore collisions here. The same recomp address can have # multiple names (e.g. _strlwr and __strlwr) - if self.recomp_used(addr): + if self._recomp_used(addr): return compare_value = compare_type.value if compare_type is not None else None @@ -166,18 +166,18 @@ def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]: return cur.fetchall() - def orig_used(self, addr: int) -> bool: + def _orig_used(self, addr: int) -> bool: cur = self._db.execute("SELECT 1 FROM symbols WHERE orig_addr = ?", (addr,)) return cur.fetchone() is not None - def recomp_used(self, addr: int) -> bool: + def _recomp_used(self, addr: int) -> bool: cur = self._db.execute("SELECT 1 FROM symbols WHERE recomp_addr = ?", (addr,)) return cur.fetchone() is not None def set_pair( self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None ) -> bool: - if self.orig_used(orig): + if self._orig_used(orig): logger.error("Original address %s not unique!", hex(orig)) return False @@ -189,6 +189,32 @@ def set_pair( return cur.rowcount > 0 + def set_pair_tentative( + self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None + ) -> bool: + """Declare a match for the original and recomp addresses given, but only if: + 1. The original address is not used elsewhere (as with set_pair) + 2. The recomp address has not already been matched + If the compare_type is given, update this also, but only if NULL in the db. + + The purpose here is to set matches found via some automated analysis + but to not overwrite a match provided by the human operator.""" + if self._orig_used(orig): + # Probable and expected situation. Just ignore it. + return False + + compare_value = compare_type.value if compare_type is not None else None + + cur = self._db.execute( + """UPDATE `symbols` + SET orig_addr = ?, compare_type = coalesce(compare_type, ?) + WHERE recomp_addr = ? + AND orig_addr IS NULL""", + (orig, compare_value, recomp), + ) + + return cur.rowcount > 0 + def set_function_pair(self, orig: int, recomp: int) -> bool: """For lineref match or _entry""" return self.set_pair(orig, recomp, SymbolType.FUNCTION) diff --git a/tools/isledecomp/tests/test_islebin.py b/tools/isledecomp/tests/test_islebin.py index ffb3c494..14fb85b4 100644 --- a/tools/isledecomp/tests/test_islebin.py +++ b/tools/isledecomp/tests/test_islebin.py @@ -144,3 +144,9 @@ def test_imports(import_ref: Tuple[str, str, int], binfile: IsleBin): @pytest.mark.parametrize("thunk_ref", THUNKS) def test_thunks(thunk_ref: Tuple[int, int], binfile: IsleBin): assert thunk_ref in binfile.thunks + + +def test_exports(binfile: IsleBin): + assert len(binfile.exports) == 130 + assert (0x1003BFB0, b"??0LegoBackgroundColor@@QAE@PBD0@Z") in binfile.exports + assert (0x10091EE0, b"_DllMain@12") in binfile.exports