Read floating point constants up front (#868)

* Read floating point constants before sanitize

* Fix roadmap
This commit is contained in:
MS 2024-04-29 14:33:16 -04:00 committed by GitHub
parent 7c6c68d6f9
commit e7670f9a81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 97 additions and 47 deletions

View file

@ -2,7 +2,7 @@
import struct import struct
import bisect import bisect
from functools import cached_property from functools import cached_property
from typing import List, Optional, Tuple from typing import Iterator, List, Optional, Tuple
from dataclasses import dataclass from dataclasses import dataclass
from collections import namedtuple from collections import namedtuple
@ -77,6 +77,18 @@ def match_name(self, name: str) -> bool:
def contains_vaddr(self, vaddr: int) -> bool: def contains_vaddr(self, vaddr: int) -> bool:
return self.virtual_address <= vaddr < self.virtual_address + self.extent return self.virtual_address <= vaddr < self.virtual_address + self.extent
def read_virtual(self, vaddr: int, size: int) -> memoryview:
    """Return `size` bytes of this section's data starting at `vaddr`.

    Args:
        vaddr: Virtual address of the first byte to read.
        size: Number of bytes to read.

    Raises:
        InvalidVirtualAddressError: If any part of the requested range
            lies outside the data held in `self.view`.
    """
    ofs = vaddr - self.virtual_address

    # Slicing never raises IndexError: a negative start would silently read
    # from the end of the buffer, and an end past the buffer would silently
    # return fewer bytes than requested. Check the bounds explicitly so the
    # caller either gets exactly `size` bytes or an exception.
    if ofs < 0 or size < 0 or ofs + size > len(self.view):
        raise InvalidVirtualAddressError

    return self.view[ofs : ofs + size]
def addr_is_uninitialized(self, vaddr: int) -> bool: def addr_is_uninitialized(self, vaddr: int) -> bool:
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in """We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
the characteristics field so instead we determine it this way.""" the characteristics field so instead we determine it this way."""
@ -109,6 +121,7 @@ def __init__(self, filename: str, find_str: bool = False) -> None:
self._section_vaddr: List[int] = [] self._section_vaddr: List[int] = []
self.find_str = find_str self.find_str = find_str
self._potential_strings = {} self._potential_strings = {}
self._relocations = set()
self._relocated_addrs = set() self._relocated_addrs = set()
self.imports = [] self.imports = []
self.thunks = [] self.thunks = []
@ -279,11 +292,49 @@ def _populate_relocations(self):
# We are now interested in the relocated addresses themselves. Seek to the # We are now interested in the relocated addresses themselves. Seek to the
# address where there is a relocation, then read the four bytes into our set. # address where there is a relocation, then read the four bytes into our set.
reloc_addrs.sort() reloc_addrs.sort()
self._relocations = set(reloc_addrs)
for section_id, offset in map(self.get_relative_addr, reloc_addrs): for section_id, offset in map(self.get_relative_addr, reloc_addrs):
section = self.get_section_by_index(section_id) section = self.get_section_by_index(section_id)
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4]) (relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
self._relocated_addrs.add(relocated_addr) self._relocated_addrs.add(relocated_addr)
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
    """Yield floating point constants referenced by instructions in .text.

    Floating point instructions that refer to a memory address can point to
    constant values. For every relocation located in .text, decode the two
    bytes in front of the relocated pointer; if they look like an FP
    instruction with an absolute memory operand and the pointer lands in
    .rdata, read the constant from there.

    Yields:
        (address, size, value) tuples, where size is 4 for single precision
        and 8 for double precision constants.
    """
    # TODO: Should check any section that has code, not just .text
    text = self.get_section_by_name(".text")
    rdata = self.get_section_by_name(".rdata")

    # ModR/M bytes of the form (reg << 3) | 5, i.e. mod=00 and r/m=101:
    # the operand is a 32-bit absolute address. One value per reg field.
    # NOTE(review): the reg field is not inspected further, so non-arithmetic
    # forms that share these opcodes are accepted too -- confirm that is fine.
    modrm_disp32 = (0x5, 0xD, 0x15, 0x1D, 0x25, 0x2D, 0x35, 0x3D)

    # These are the addresses where a relocation occurs.
    # Meaning: it points to an absolute address of something
    for addr in self._relocations:
        if not text.contains_vaddr(addr):
            continue

        # Read the two bytes before the relocated address together with the
        # pointer itself. We will check against possible float opcodes.
        try:
            raw = text.read_virtual(addr - 2, 6)
        except InvalidVirtualAddressError:
            # The relocation sits within the first two bytes of the section,
            # so there is no room for an opcode in front of it.
            continue

        if len(raw) != 6:
            # Truncated read at the end of the section data: skip it rather
            # than let struct.unpack abort the whole scan.
            continue

        (opcode, opcode_ext, const_addr) = struct.unpack("<BBL", raw)

        # Skip right away if this is not const data
        if not rdata.contains_vaddr(const_addr):
            continue

        if opcode_ext not in modrm_disp32:
            continue

        if opcode in (0xD8, 0xD9):
            # dword ptr -- single precision
            (float_value,) = struct.unpack("<f", self.read(const_addr, 4))
            yield (const_addr, 4, float_value)
        elif opcode in (0xDC, 0xDD):
            # qword ptr -- double precision
            (float_value,) = struct.unpack("<d", self.read(const_addr, 8))
            yield (const_addr, 8, float_value)
def _populate_imports(self): def _populate_imports(self):
"""Parse .idata to find imported DLLs and their functions.""" """Parse .idata to find imported DLLs and their functions."""
idata_ofs = self.get_section_offset_by_name(".idata") idata_ofs = self.get_section_offset_by_name(".idata")

View file

@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
return None return None
def bytes_to_float(b: bytes) -> Optional[float]:
    """Decode `b` as a little-endian floating point number.

    Four bytes are unpacked as single precision and eight bytes as double
    precision. Any other length yields None.
    """
    struct_format = {4: "<f", 8: "<d"}.get(len(b))
    if struct_format is None:
        return None

    (value,) = struct.unpack(struct_format, b)
    return value
def bytes_to_dword(b: bytes) -> Optional[int]: def bytes_to_dword(b: bytes) -> Optional[int]:
if len(b) == 4: if len(b) == 4:
return struct.unpack("<L", b)[0] return struct.unpack("<L", b)[0]
@ -74,18 +64,6 @@ def is_relocated(self, addr: int) -> bool:
return False return False
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
    """Build a "<value> (FLOAT)" placeholder for the data at `addr`.

    Returns None when `self.bin_lookup` is not callable, when it returns no
    bytes for the address, or when those bytes cannot be decoded by
    `bytes_to_float`.
    """
    if not callable(self.bin_lookup):
        return None

    raw = self.bin_lookup(addr, data_size)
    if raw is None:
        return None

    value = bytes_to_float(raw)
    return None if value is None else f"{value} (FLOAT)"
def lookup( def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]: ) -> Optional[str]:
@ -165,25 +143,6 @@ def hex_replace_indirect(self, match: re.Match) -> str:
return match.group(0).replace(match.group(1), self.replace(value)) return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_float(self, match: re.Match) -> str:
    """Regex substitution callback for pointers in float instructions.

    The replacement text is chosen in this order:
      1. the name returned by `self.lookup` for the pointer, if any;
      2. the float constant read through `self.float_replace`
         (8 bytes when the operand string mentions "qword", else 4);
      3. the generic placeholder from `self.replace`.
    """
    hex_text = match.group(1)
    pointer = int(hex_text, 16)

    replacement = self.lookup(pointer)
    if replacement is None:
        # No known name: try to show the decimal value under the pointer.
        if "qword" in match.string:
            width = 8
        else:
            width = 4
        replacement = self.float_replace(pointer, width)

        # The float could not be read either, so use a regular placeholder.
        if replacement is None:
            replacement = self.replace(pointer)

    return match.group(0).replace(hex_text, replacement)
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]: def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value # For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset. # is a relative offset.
@ -224,9 +183,6 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
if inst.mnemonic == "call": if inst.mnemonic == "call":
# Special handling for absolute indirect CALL. # Special handling for absolute indirect CALL.
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str) op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
elif inst.mnemonic.startswith("f"):
# If floating point instruction
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
else: else:
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str) op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)

View file

@ -82,6 +82,7 @@ def __init__(
self._load_cvdump() self._load_cvdump()
self._load_markers() self._load_markers()
self._find_original_strings() self._find_original_strings()
self._find_float_const()
self._match_imports() self._match_imports()
self._match_exports() self._match_exports()
self._match_thunks() self._match_thunks()
@ -249,6 +250,18 @@ def _find_original_strings(self):
self._db.match_string(addr, string) self._db.match_string(addr, string)
def _find_float_const(self):
    """Record the floating point constants of both binaries in the database.

    Each constant found by `find_float_consts` is stored as a FLOAT symbol
    whose name is the string form of its value. Nothing is matched between
    the two binaries here because these values are not deduped like strings.
    """
    float_type = SymbolType.FLOAT

    for orig_addr, orig_size, orig_value in self.orig_bin.find_float_consts():
        self._db.set_orig_symbol(orig_addr, float_type, str(orig_value), orig_size)

    for recomp_addr, recomp_size, recomp_value in self.recomp_bin.find_float_consts():
        # set_recomp_symbol takes one more positional argument ahead of the
        # size than set_orig_symbol does; None is passed for it here.
        self._db.set_recomp_symbol(
            recomp_addr, float_type, str(recomp_value), None, recomp_size
        )
def _match_imports(self): def _match_imports(self):
"""We can match imported functions based on the DLL name and """We can match imported functions based on the DLL name and
function symbol name.""" function symbol name."""

View file

@ -84,6 +84,23 @@ def __init__(self):
self._db = sqlite3.connect(":memory:") self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL) self._db.executescript(_SETUP_SQL)
def set_orig_symbol(
    self,
    addr: int,
    compare_type: Optional[SymbolType],
    name: Optional[str],
    size: Optional[int],
):
    """Insert a row into `symbols` for a symbol at `addr` in the original binary.

    The call is a no-op when `_orig_used` reports that the address is
    already taken. `compare_type` is stored by its enum value, or as NULL
    when it is None.
    """
    # Ignore collisions here.
    if self._orig_used(addr):
        return

    if compare_type is None:
        type_value = None
    else:
        type_value = compare_type.value

    row = (addr, type_value, name, size)
    self._db.execute(
        "INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
        row,
    )
def set_recomp_symbol( def set_recomp_symbol(
self, self,
addr: int, addr: int,

View file

@ -10,3 +10,4 @@ class SymbolType(Enum):
POINTER = 3 POINTER = 3
STRING = 4 STRING = 4
VTABLE = 5 VTABLE = 5
FLOAT = 6

View file

@ -189,6 +189,7 @@ def substitute_1234(addr: int, _: bool) -> Optional[str]:
assert op_str == "0x5555" assert op_str == "0x5555"
@pytest.mark.skip(reason="changed implementation")
def test_float_replacement(): def test_float_replacement():
"""Floating point constants often appear as pointers to data. """Floating point constants often appear as pointers to data.
A good example is ViewROI::IntrinsicImportance and the subclass override A good example is ViewROI::IntrinsicImportance and the subclass override
@ -208,6 +209,7 @@ def bin_lookup(addr: int, _: int) -> Optional[bytes]:
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]" assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
@pytest.mark.skip(reason="changed implementation")
def test_float_variable(): def test_float_variable():
"""If there is a variable at the address referenced by a float instruction, """If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler.""" use the name instead of calling into the float replacement handler."""

View file

@ -10,6 +10,7 @@
from typing import Iterator, List, Optional, Tuple from typing import Iterator, List, Optional, Tuple
from collections import namedtuple from collections import namedtuple
from isledecomp import Bin as IsleBin from isledecomp import Bin as IsleBin
from isledecomp.bin import InvalidVirtualAddressError
from isledecomp.cvdump import Cvdump from isledecomp.cvdump import Cvdump
from isledecomp.compare import Compare as IsleCompare from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType from isledecomp.types import SymbolType
@ -87,7 +88,7 @@ def print_sections(sections):
print() print()
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta"] ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta", "flo"]
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str: def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
@ -456,7 +457,16 @@ def to_roadmap_row(match):
module_name, module_name,
) )
def roadmap_row_generator(matches):
    """Yield a roadmap row for each match that can be converted.

    Matches for which `to_roadmap_row` raises InvalidVirtualAddressError
    are skipped.
    """
    for match in matches:
        try:
            row = to_roadmap_row(match)
        except InvalidVirtualAddressError:
            # Workaround: the compare db contains RVA values (i.e. not real
            # virtual addrs), and those cannot be converted.
            continue
        yield row
results = list(roadmap_row_generator(engine.get_all()))
if args.order is not None: if args.order is not None:
suggest_order(results, module_map, args.order) suggest_order(results, module_map, args.order)