More effective match strategies (#804)

* More effective match strategies

* Basic check on instruction relocation

* More targeted check for relocation
This commit is contained in:
MS 2024-04-14 17:08:42 -04:00 committed by GitHub
parent 540bcc61ad
commit c8840117be
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 213 additions and 18 deletions

View file

@ -1,6 +1,9 @@
from difflib import SequenceMatcher
from typing import List
import re
from typing import List, Tuple, Set
# One diff opcode as yielded by SequenceMatcher.get_opcodes():
# (tag, i1, i2, j1, j2), where i* index orig and j* index recomp.
DiffOpcode = Tuple[str, int, int, int, int]

# Captures an x86 register name that directly follows a space or "[".
# The character classes list letters only: inside [...] a comma is a literal
# comma, so writing `[s,d]` would also accept text such as ",i" or "a,".
REG_FIND = re.compile(r"(?: |\[)(e?[a-d]x|e?[sd]i|[a-d][lh]|e?[bs]p)")
ALLOWED_JUMP_SWAPS = (
("ja", "jb"),
@ -69,8 +72,8 @@ def patch_jump(a: str, b: str) -> str:
def patch_cmp_swaps(
sm: SequenceMatcher, orig_asm: List[str], recomp_asm: List[str]
) -> bool:
codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
"""Can we resolve the diffs between orig and recomp by patching
swapped cmp instructions?
For example:
@ -81,12 +84,7 @@ def patch_cmp_swaps(
ja .label jb .label
"""
# Copy the instructions so we can patch
# TODO: If we change our strategy to allow multiple rounds of patching,
# we should modify the recomp array directly.
new_asm = recomp_asm[::]
codes = sm.get_opcodes()
fixed_lines = set()
for code, i1, i2, j1, j2 in codes:
# To save us the trouble of finding "compatible" cmp instructions
@ -98,9 +96,207 @@ def patch_cmp_swaps(
for i, j in zip(range(i1, i2), range(j1, j2)):
if can_cmp_swap(orig_asm[i : i + 2], recomp_asm[j : j + 2]):
# Patch cmp
new_asm[j] = orig_asm[i]
fixed_lines.add(j)
# Patch the jump if necessary
new_asm[j + 1] = patch_jump(orig_asm[i + 1], recomp_asm[j + 1])
patched = patch_jump(orig_asm[i + 1], recomp_asm[j + 1])
# We only register a fix if it actually matches
if orig_asm[i + 1] == patched:
fixed_lines.add(j + 1)
return orig_asm == new_asm
return fixed_lines
def effective_match_possible(orig_asm: List[str], recomp_asm: List[str]) -> bool:
    """Report whether an effective match can be attempted at all.

    The effective-match checks work purely on instruction text, so both
    listings must contain the same number of instructions.
    """
    # NOTE: Requiring the sorted mnemonic lists to be identical was considered
    # as an additional filter, but it would exclude jump swaps made for
    # cmp operand order.
    return len(orig_asm) == len(recomp_asm)
def find_regs_used(inst: str) -> List[str]:
    """Return every register name REG_FIND captures in `inst`, in order."""
    # REG_FIND has a single capture group, so group(1) is the register name.
    return [found.group(1) for found in REG_FIND.finditer(inst)]
def find_regs_changed(a: str, b: str) -> List[Tuple[str, str]]:
    """For instructions a, b, return the pairs of registers that were used.

    The n-th register found in `a` is paired with the n-th register found
    in `b`; pairing stops at the shorter of the two.

    This is not a very precise way to compare the instructions, so it depends
    on the input being two instructions that would match *except* for
    the register choice.
    """
    # Return a real list, as annotated. A bare zip object is a one-shot
    # iterator: a failed `in` test consumes it, so any second membership
    # test on the same result would always come back False.
    return list(zip(REG_FIND.findall(a), REG_FIND.findall(b)))
def bad_register_swaps(
    swaps: Set[int], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
    """Return the indices in `swaps` whose register swap should be rejected.

    The recomp indices in `swaps` tell which instructions are a match for
    orig except for the registers used. For now, a swap is rejected when it
    is a `push` whose (orig, recomp) register pair was not used, in either
    order, by any other non-push register swap in `swaps`.

    Assumes the same index addresses the corresponding instruction in both
    `orig_asm` and `recomp_asm`, as the naive match does.
    """
    rejects = set()

    # Foreach `push` instruction where we have excused the diff
    pushes = [j for j in swaps if recomp_asm[j].startswith("push")]
    # Every other excused diff that is *not* a push. Same for every push,
    # so compute it once.
    non_pushes = swaps.difference(pushes)

    for j in pushes:
        # Get the operands in each
        reg = (orig_asm[j].partition(" ")[2], recomp_asm[j].partition(" ")[2])

        # If the orig operand parses as a hex number it isn't a register
        # at all, so ignore it
        try:
            int(reg[0], 16)
        except ValueError:
            pass
        else:
            continue

        flipped = reg[::-1]
        for k in non_pushes:
            # Materialise the pairs before testing membership twice: if
            # find_regs_changed hands back a one-shot iterator, the first
            # `in` would exhaust it and the flipped check could never pass.
            pairs = set(find_regs_changed(orig_asm[k], recomp_asm[k]))
            if reg in pairs or flipped in pairs:
                break
        else:
            # No other swap explains this push's register choice
            rejects.add(j)

    return rejects
# Mnemonics of instructions that result in a change to their first operand
MODIFIER_INSTRUCTIONS = (
    "adc",
    "add",
    "lea",
    "mov",
    "neg",
    "sbb",
    "sub",
    "pop",
    "xor",
)


def instruction_alters_regs(inst: str, regs: Set[str]) -> bool:
    """Tell whether `inst` is treated as changing any register in `regs`.

    True when the mnemonic is listed in MODIFIER_INSTRUCTIONS and the first
    operand is one of `regs`, or when the instruction is a `call` and
    "eax" is one of `regs`.
    """
    mnemonic, _, operands = inst.partition(" ")

    if mnemonic == "call":
        return "eax" in regs

    if mnemonic not in MODIFIER_INSTRUCTIONS:
        return False

    destination = operands.partition(", ")[0]
    return destination in regs
def relocate_instructions(
    codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
    """Find recomp instructions that look like relocated orig instructions.

    Gathers the orig indices covered by "delete" opcodes and the recomp
    indices covered by "insert" opcodes, then pairs an inserted line with a
    deleted line that has exactly the same text. A pair is accepted only if
    no orig instruction in the index range between the two positions alters
    a register that the moved line uses. Each deleted line can be paired at
    most once.

    Returns the recomp indices where such a relocation was accepted.
    """
    unclaimed_deletes = {
        idx
        for tag, lo, hi, _, __ in codes
        if tag == "delete"
        for idx in range(lo, hi)
    }
    inserted = [
        idx
        for tag, _, __, lo, hi in codes
        if tag == "insert"
        for idx in range(lo, hi)
    ]

    relocated = set()
    for j in inserted:
        line = recomp_asm[j]
        regs_in_line = set(find_regs_used(line))

        # TODO: This takes the first deleted line that qualifies.
        # We should probably use the nearest index instead, if it matters
        for i in unclaimed_deletes:
            # Only an exact textual match is considered
            if orig_asm[i] != line:
                continue

            # Order the two positions to account for a move in either direction
            first, last = (i, j) if i < j else (j, i)
            blocked = any(
                instruction_alters_regs(orig_asm[k], regs_in_line)
                for k in range(first, last)
            )
            if blocked:
                continue

            relocated.add(j)
            # Leaving the loop right after the removal keeps the set from
            # being iterated after it changed size.
            unclaimed_deletes.remove(i)
            break

    return relocated
DWORD_REGS = ("eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp")
WORD_REGS = ("ax", "bx", "cx", "dx", "si", "di", "bp", "sp")
BYTE_REGS = ("ah", "al", "bh", "bl", "ch", "cl", "dh", "dl")


def _scrub_registers(inst: str) -> str:
    """Replace each register name in `inst` with a size-only placeholder."""
    # Widest names first, so "eax" becomes "~reg4" before "ax" is looked for.
    # This is plain substring replacement, not operand parsing.
    for names, placeholder in (
        (DWORD_REGS, "~reg4"),
        (WORD_REGS, "~reg2"),
        (BYTE_REGS, "~reg1"),
    ):
        for name in names:
            inst = inst.replace(name, placeholder)
    return inst


def naive_register_replacement(orig_asm: List[str], recomp_asm: List[str]) -> Set[int]:
    """Replace all registers of the same size with a placeholder string.
    After doing that, compare orig and recomp again.
    Return indices from recomp that are now equal to the same index in orig.

    Lines that were already identical are included in the result.
    Effective match requires orig and recomp to have the same number of
    instructions; if they differ anyway, only the indices present in both
    are compared. Two empty listings give an empty set.
    """
    # Scrub line by line and pair the lines up directly. Joining everything
    # into one string and splitting it again would turn an empty listing
    # into a single empty line and report a match at index 0.
    return {
        j
        for j, (orig_line, recomp_line) in enumerate(zip(orig_asm, recomp_asm))
        if _scrub_registers(orig_line) == _scrub_registers(recomp_line)
    }
def find_effective_match(
    codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> bool:
    """Decide whether orig and recomp are an effective match.

    Meaning: do they differ only by instruction order or register selection?
    Every recomp line covered by an "insert" or "replace" opcode must be
    explained by a cmp swap, an allowed register swap, or a relocation.
    """
    if not effective_match_possible(orig_asm, recomp_asm):
        return False

    # Sort the recomp indices by what the diff says about them
    already_equal = set()
    disputed = set()  # We need to come up with some answer for each of these
    for tag, _, __, j1, j2 in codes:
        if tag == "equal":
            already_equal.update(range(j1, j2))
        elif tag in ("insert", "replace"):
            disputed.update(range(j1, j2))

    cmp_swaps = patch_cmp_swaps(codes, orig_asm, recomp_asm)

    # The naive result includes lines that already match, so drop those
    naive_swaps = naive_register_replacement(orig_asm, recomp_asm) - already_equal

    relocates = relocate_instructions(codes, orig_asm, recomp_asm)
    bad_swaps = bad_register_swaps(naive_swaps, orig_asm, recomp_asm)

    accepted = (naive_swaps - bad_swaps) | cmp_swaps | relocates
    return accepted >= disputed

View file

@ -10,8 +10,8 @@
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences
from isledecomp.compare.asm.fixes import patch_cmp_swaps
from isledecomp.compare.asm import ParseAsm
from isledecomp.compare.asm.fixes import find_effective_match
from .db import CompareDb, MatchInfo
from .diff import combined_diff
from .lines import LinesDb
@ -493,9 +493,8 @@ def recomp_lookup(addr: int) -> Optional[str]:
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
is_effective_match = patch_cmp_swaps(
diff, orig_asm, recomp_asm
) or can_resolve_register_differences(orig_asm, recomp_asm)
codes = diff.get_opcodes()
is_effective_match = find_effective_match(codes, orig_asm, recomp_asm)
unified_diff = combined_diff(
diff, orig_combined, recomp_combined, context_size=10
)