# Effective-match helpers added to
# tools/isledecomp/isledecomp/compare/asm/fixes.py by commit c8840117
# ("More effective match strategies (#804)").
#
# Only the definitions that the patch shows in full are reproduced here.
# ALLOWED_JUMP_SWAPS, patch_jump, can_cmp_swap and patch_cmp_swaps belong to
# the same module but appear only as partial hunk context, so they are left
# out rather than guessed at.
import re
from typing import List, Set, Tuple

# One opcode tuple (tag, i1, i2, j1, j2); the caller obtains these from
# difflib's SequenceMatcher.get_opcodes().
DiffOpcode = Tuple[str, int, int, int, int]

# Finds a register name that directly follows a space or an opening bracket.
# The character classes deliberately contain no commas: "[s,d]" would also
# accept a literal "," as a register letter.
REG_FIND = re.compile(r"(?: |\[)(e?[a-d]x|e?[sd]i|[a-d][lh]|e?[bs]p)")


def effective_match_possible(orig_asm: List[str], recomp_asm: List[str]) -> bool:
    """Return True if an effective match can be attempted at all.

    The comparison works on instruction text only, so both listings must
    contain the same number of instructions.
    """
    # TODO: Requiring the same multiset of mnemonics on both sides was
    # considered, but that would exclude the jump swaps that accompany a
    # change of cmp operand order.
    return len(orig_asm) == len(recomp_asm)


def find_regs_used(inst: str) -> List[str]:
    """Return every register name REG_FIND locates in `inst`, in order."""
    return REG_FIND.findall(inst)


def find_regs_changed(a: str, b: str) -> List[Tuple[str, str]]:
    """Pair up, position by position, the registers found in `a` and `b`.

    This is not a precise comparison: it depends on the input being two
    instructions that would match *except* for the register choice.

    A real list is returned (not a bare `zip` iterator) so the result can be
    tested for membership more than once.
    """
    return list(zip(REG_FIND.findall(a), REG_FIND.findall(b)))


def bad_register_swaps(
    swaps: Set[int], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
    """Return the indices in `swaps` whose register swap must be refused.

    `swaps` holds recomp indices of instructions that match orig except for
    the registers used. For now the only rule is about `push`: a push whose
    register differs is accepted only if the same pair of registers (in
    either order) was also exchanged by some non-push instruction in `swaps`.

    Like the naive match that produces `swaps`, this assumes an instruction
    has the same index in orig and in recomp.
    """
    pushes = {j for j in swaps if recomp_asm[j].startswith("push")}

    # Register pairs exchanged by every excused diff that is not a push.
    # Collected once instead of once per push instruction.
    excused_pairs: Set[Tuple[str, str]] = set()
    for k in swaps - pushes:
        excused_pairs.update(find_regs_changed(orig_asm[k], recomp_asm[k]))

    rejects = set()
    for j in pushes:
        reg = (orig_asm[j].partition(" ")[2], recomp_asm[j].partition(" ")[2])

        # An operand that parses as a hex number is an immediate, not a
        # register, so there is nothing to verify for this push.
        try:
            int(reg[0], 16)
        except ValueError:
            pass
        else:
            continue

        if reg not in excused_pairs and reg[::-1] not in excused_pairs:
            rejects.add(j)

    return rejects


# Instructions that result in a change to the first operand
MODIFIER_INSTRUCTIONS = ("adc", "add", "lea", "mov", "neg", "sbb", "sub", "pop", "xor")


def instruction_alters_regs(inst: str, regs: Set[str]) -> bool:
    """Return True if `inst` is treated as modifying any register in `regs`.

    That is the case when the mnemonic is one of MODIFIER_INSTRUCTIONS and
    its first operand is in `regs`, or when the instruction is a `call` and
    `regs` contains "eax".
    """
    mnemonic, _, op_str = inst.partition(" ")
    first_operand = op_str.partition(", ")[0]

    if mnemonic == "call":
        return "eax" in regs

    return mnemonic in MODIFIER_INSTRUCTIONS and first_operand in regs


def relocate_instructions(
    codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
    """Return recomp indices of instructions that were merely relocated.

    An instruction inserted into recomp counts as relocated when an
    identical instruction was deleted from orig (according to the diff
    opcodes) and no orig instruction between the two positions alters a
    register that the moved instruction uses. Each deleted instruction can
    be paired with at most one inserted instruction.
    """
    deletes = {
        i for code, i1, i2, _, __ in codes for i in range(i1, i2) if code == "delete"
    }
    inserts = [
        j for code, _, __, j1, j2 in codes for j in range(j1, j2) if code == "insert"
    ]

    relocated = set()

    for j in inserts:
        line = recomp_asm[j]
        recomp_regs_used = set(find_regs_used(line))
        for i in deletes:
            # Exact text match only.
            # TODO: This grabs the first deleted instruction that matches.
            # We should probably use the nearest index instead, if it matters
            if orig_asm[i] != line:
                continue

            # To account for a move in either direction.
            # NOTE(review): the window is [min(i, j), max(i, j)) over orig,
            # so it includes orig_asm[i] itself when the instruction moved
            # later (i < j) -- confirm that this is intended.
            reloc_start = min(i, j)
            reloc_end = max(i, j)
            if not any(
                instruction_alters_regs(orig_asm[k], recomp_regs_used)
                for k in range(reloc_start, reloc_end)
            ):
                relocated.add(j)
                # Safe although we are iterating `deletes`: we leave the
                # loop immediately after mutating the set.
                deletes.remove(i)
                break

    return relocated


DWORD_REGS = ("eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp")
WORD_REGS = ("ax", "bx", "cx", "dx", "si", "di", "bp", "sp")
BYTE_REGS = ("ah", "al", "bh", "bl", "ch", "cl", "dh", "dl")

# Widest registers first, so "eax" is replaced before its substring "ax".
_REGISTER_PLACEHOLDERS = (
    (DWORD_REGS, "~reg4"),
    (WORD_REGS, "~reg2"),
    (BYTE_REGS, "~reg1"),
)


def _scrub_registers(inst: str) -> str:
    """Replace every register name in `inst` with a per-size placeholder."""
    # TODO: hardly the most elegant way to do this.
    for regs, placeholder in _REGISTER_PLACEHOLDERS:
        for reg in regs:
            inst = inst.replace(reg, placeholder)
    return inst


def naive_register_replacement(orig_asm: List[str], recomp_asm: List[str]) -> Set[int]:
    """Return recomp indices that equal orig once register names are ignored.

    All registers of the same size are replaced with a placeholder string,
    then orig and recomp are compared again index by index. Lines that were
    already identical are included in the result.

    Effective match already requires both listings to have the same length;
    if they do not, only the common prefix is compared. Empty input yields
    an empty set.
    """
    return {
        j
        for j, (orig_inst, recomp_inst) in enumerate(zip(orig_asm, recomp_asm))
        if _scrub_registers(orig_inst) == _scrub_registers(recomp_inst)
    }


def find_effective_match(
    codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> bool:
    """Check whether the two sequences of instructions are an effective match.
    Meaning: do they differ only by instruction order or register selection?

    `codes` are the diff opcodes between `orig_asm` and `recomp_asm`. The
    result is True when every recomp line covered by an "insert" or
    "replace" opcode is explained by at least one strategy: a cmp operand
    swap, an accepted register swap, or a relocated instruction.
    """
    if not effective_match_possible(orig_asm, recomp_asm):
        return False

    already_equal = {
        j for code, _, __, j1, j2 in codes for j in range(j1, j2) if code == "equal"
    }

    # We need to come up with some answer for each of these lines
    recomp_lines_disputed = {
        j
        for code, _, __, j1, j2 in codes
        for j in range(j1, j2)
        if code in ("insert", "replace")
    }

    cmp_swaps = patch_cmp_swaps(codes, orig_asm, recomp_asm)
    # This naive result includes lines that already match, so remove those
    naive_swaps = naive_register_replacement(orig_asm, recomp_asm).difference(
        already_equal
    )
    relocates = relocate_instructions(codes, orig_asm, recomp_asm)

    bad_swaps = bad_register_swaps(naive_swaps, orig_asm, recomp_asm)

    corrections = set().union(
        naive_swaps.difference(bad_swaps),
        cmp_swaps,
        relocates,
    )

    return corrections.issuperset(recomp_lines_disputed)