Swap cmp operands for effective match (#783)

This commit is contained in:
MS 2024-04-07 16:57:41 -04:00 committed by GitHub
parent 1bfe47357b
commit 70912d16c6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 110 additions and 1 deletions

View file

@ -0,0 +1,106 @@
from difflib import SequenceMatcher
from typing import List
ALLOWED_JUMP_SWAPS = (
("ja", "jb"),
("jae", "jbe"),
("jb", "ja"),
("jbe", "jae"),
("jg", "jl"),
("jge", "jle"),
("jl", "jg"),
("jle", "jge"),
("je", "je"),
("jne", "jne"),
)
def jump_swap_ok(a: str, b: str) -> bool:
"""For the instructions a,b, are they both jump instructions
that are compatible with a swapped cmp operand order?"""
# Grab the mnemonic
(jmp_a, _, __) = a.partition(" ")
(jmp_b, _, __) = b.partition(" ")
return (jmp_a, jmp_b) in ALLOWED_JUMP_SWAPS
def is_operand_swap(a: str, b: str) -> bool:
"""This is a hack to avoid parsing the operands. It's not as simple as
breaking on the comma because templates or string literals interfere
with this. Instead we check:
1. Do both strings use the exact same set of characters?
2. If we do break on ', ', is the first token of each different?
2 is needed to catch an edge case like:
cmp eax, dword ptr [ecx + 0x1234]
cmp ecx, dword ptr [eax + 0x1234]
"""
return a.partition(", ")[0] != b.partition(", ")[0] and sorted(a) == sorted(b)
def can_cmp_swap(orig: List[str], recomp: List[str]) -> bool:
# Make sure we have 1 cmp and 1 jmp for both
if len(orig) != 2 or len(recomp) != 2:
return False
if not orig[0].startswith("cmp") or not recomp[0].startswith("cmp"):
return False
if not orig[1].startswith("j") or not recomp[1].startswith("j"):
return False
# Checking two things:
# Are the cmp operands flipped?
# Is the jump instruction compatible with a flip?
return is_operand_swap(orig[0], recomp[0]) and jump_swap_ok(orig[1], recomp[1])
def patch_jump(a: str, b: str) -> str:
"""For jump instructions a, b, return `(mnemonic_a) (operand_b)`.
The reason to do it this way (instead of just returning `a`) is that
the jump instructions might use different displacement offsets
or labels. If we just replace `b` with `a`, this diff would be
incorrectly eliminated."""
(mnemonic_a, _, __) = a.partition(" ")
(_, __, operand_b) = b.partition(" ")
return mnemonic_a + " " + operand_b
def patch_cmp_swaps(
sm: SequenceMatcher, orig_asm: List[str], recomp_asm: List[str]
) -> bool:
"""Can we resolve the diffs between orig and recomp by patching
swapped cmp instructions?
For example:
cmp eax, ebx cmp ebx, eax
je .label je .label
cmp eax, ebx cmp ebx, eax
ja .label jb .label
"""
# Copy the instructions so we can patch
# TODO: If we change our strategy to allow multiple rounds of patching,
# we should modify the recomp array directly.
new_asm = recomp_asm[::]
codes = sm.get_opcodes()
for code, i1, i2, j1, j2 in codes:
# To save us the trouble of finding "compatible" cmp instructions
# use the diff information we already have.
if code != "replace":
continue
# If the ranges in orig and recomp are not equal, use the shorter one
for i, j in zip(range(i1, i2), range(j1, j2)):
if can_cmp_swap(orig_asm[i : i + 2], recomp_asm[j : j + 2]):
# Patch cmp
new_asm[j] = orig_asm[i]
# Patch the jump if necessary
new_asm[j + 1] = patch_jump(orig_asm[i + 1], recomp_asm[j + 1])
return orig_asm == new_asm

View file

@ -11,6 +11,7 @@
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm, can_resolve_register_differences
from isledecomp.compare.asm.fixes import patch_cmp_swaps
from .db import CompareDb, MatchInfo
from .diff import combined_diff
from .lines import LinesDb
@ -470,7 +471,9 @@ def recomp_lookup(addr: int) -> Optional[str]:
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
is_effective_match = can_resolve_register_differences(orig_asm, recomp_asm)
is_effective_match = patch_cmp_swaps(
diff, orig_asm, recomp_asm
) or can_resolve_register_differences(orig_asm, recomp_asm)
unified_diff = combined_diff(
diff, orig_combined, recomp_combined, context_size=10
)