Read floating point constants up front (#868)

* Read floating point constants before sanitize

* Fix roadmap
This commit is contained in:
MS 2024-04-29 14:33:16 -04:00 committed by GitHub
parent 7c6c68d6f9
commit e7670f9a81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 97 additions and 47 deletions

View file

@ -2,7 +2,7 @@
import struct
import bisect
from functools import cached_property
from typing import List, Optional, Tuple
from typing import Iterator, List, Optional, Tuple
from dataclasses import dataclass
from collections import namedtuple
@ -77,6 +77,18 @@ def match_name(self, name: str) -> bool:
def contains_vaddr(self, vaddr: int) -> bool:
    """Return True if `vaddr` falls inside this section's address range.

    The range is half-open: it starts at `self.virtual_address` (inclusive)
    and spans `self.extent` bytes, so the end address itself is excluded.
    """
    start = self.virtual_address
    end = start + self.extent
    return start <= vaddr < end
def read_virtual(self, vaddr: int, size: int) -> memoryview:
    """Return `size` bytes of this section's data starting at `vaddr`.

    The address is translated to an offset into `self.view` by subtracting
    `self.virtual_address`.

    Raises:
        InvalidVirtualAddressError: if the requested range starts before the
            section or runs past the end of the data held in `self.view`.
    """
    ofs = vaddr - self.virtual_address

    # Negative index will read from the end, which we don't want.
    # Slicing never raises IndexError (an out-of-range slice is silently
    # clamped), so the upper bound has to be checked explicitly as well;
    # otherwise callers would receive fewer bytes than they asked for.
    if ofs < 0 or ofs + size > len(self.view):
        raise InvalidVirtualAddressError

    return self.view[ofs : ofs + size]
def addr_is_uninitialized(self, vaddr: int) -> bool:
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
the characteristics field so instead we determine it this way."""
@ -109,6 +121,7 @@ def __init__(self, filename: str, find_str: bool = False) -> None:
self._section_vaddr: List[int] = []
self.find_str = find_str
self._potential_strings = {}
self._relocations = set()
self._relocated_addrs = set()
self.imports = []
self.thunks = []
@ -279,11 +292,49 @@ def _populate_relocations(self):
# We are now interested in the relocated addresses themselves. Seek to the
# address where there is a relocation, then read the four bytes into our set.
reloc_addrs.sort()
self._relocations = set(reloc_addrs)
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
section = self.get_section_by_index(section_id)
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
self._relocated_addrs.add(relocated_addr)
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
    """Floating point instructions that refer to a memory address can
    point to constant values. Search the code sections to find FP
    instructions and check whether the pointer address refers to
    read-only data.

    Yields:
        Tuples of (address of the constant, size in bytes, decoded value).
        The size is 4 for single precision and 8 for double precision.
        Results are produced in ascending order of relocation address.
    """
    # TODO: Should check any section that has code, not just .text
    text = self.get_section_by_name(".text")
    rdata = self.get_section_by_name(".rdata")

    # ModR/M bytes with mod=00 and r/m=101, i.e. a bare 32-bit displacement,
    # for each of the eight possible values of the reg field.
    disp32_modrm = (0x5, 0xD, 0x15, 0x1D, 0x25, 0x2D, 0x35, 0x3D)

    # These are the addresses where a relocation occurs.
    # Meaning: it points to an absolute address of something.
    # Iterate in sorted order so the output does not depend on set ordering.
    for addr in sorted(self._relocations):
        if not text.contains_vaddr(addr):
            continue

        # Read the two bytes before the relocated address.
        # We will check against possible float opcodes.
        # A relocation in the first two bytes of the section (or one whose
        # operand is cut off by the end of the data) cannot be the operand
        # of a two-byte FP opcode, so skip it instead of aborting the scan.
        try:
            raw = text.read_virtual(addr - 2, 6)
        except InvalidVirtualAddressError:
            continue
        if len(raw) != 6:
            continue

        (opcode, opcode_ext, const_addr) = struct.unpack("<BBL", raw)

        # Skip right away if this is not const data
        if not rdata.contains_vaddr(const_addr):
            continue

        if opcode_ext not in disp32_modrm:
            continue

        if opcode in (0xD8, 0xD9):
            # dword ptr -- single precision
            (float_value,) = struct.unpack("<f", self.read(const_addr, 4))
            yield (const_addr, 4, float_value)
        elif opcode in (0xDC, 0xDD):
            # qword ptr -- double precision
            (float_value,) = struct.unpack("<d", self.read(const_addr, 8))
            yield (const_addr, 8, float_value)
def _populate_imports(self):
"""Parse .idata to find imported DLLs and their functions."""
idata_ofs = self.get_section_offset_by_name(".idata")

View file

@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
return None
def bytes_to_float(b: bytes) -> Optional[float]:
    """Decode little-endian floating point data.

    Four bytes are read as single precision and eight bytes as double
    precision. Any other length returns None.
    """
    formats_by_length = {4: "<f", 8: "<d"}
    fmt = formats_by_length.get(len(b))
    if fmt is None:
        return None

    (value,) = struct.unpack(fmt, b)
    return value
def bytes_to_dword(b: bytes) -> Optional[int]:
if len(b) == 4:
return struct.unpack("<L", b)[0]
@ -74,18 +64,6 @@ def is_relocated(self, addr: int) -> bool:
return False
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
    """Build a placeholder showing the float value stored at `addr`.

    The bytes are fetched through `self.bin_lookup` and decoded with
    `bytes_to_float`. Returns None when no lookup callable is set, when the
    lookup yields no data, or when the data cannot be decoded.
    """
    if not callable(self.bin_lookup):
        return None

    raw = self.bin_lookup(addr, data_size)
    if raw is None:
        return None

    value = bytes_to_float(raw)
    return None if value is None else f"{value} (FLOAT)"
def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
@ -165,25 +143,6 @@ def hex_replace_indirect(self, match: re.Match) -> str:
return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_float(self, match: re.Match) -> str:
    """Special case for replacements on float instructions.

    The captured hex pointer in the matched text is replaced by the first
    of these that is available:
      1. the name found by `self.lookup`,
      2. the float value read through `self.float_replace`,
      3. the regular placeholder from `self.replace`.
    """
    pointer_text = match.group(1)
    pointer = int(pointer_text, 16)

    # If we can find a variable name for this pointer, use it.
    label = self.lookup(pointer)

    if label is None:
        # The operand is read as a double when the text being searched
        # mentions "qword", and as a single-precision float otherwise.
        width = 8 if "qword" in match.string else 4
        label = self.float_replace(pointer, width)

    if label is None:
        # If we can't read the float, use a regular placeholder.
        label = self.replace(pointer)

    return match.group(0).replace(pointer_text, label)
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
@ -224,9 +183,6 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
if inst.mnemonic == "call":
# Special handling for absolute indirect CALL.
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
elif inst.mnemonic.startswith("f"):
# If floating point instruction
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
else:
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)

View file

@ -82,6 +82,7 @@ def __init__(
self._load_cvdump()
self._load_markers()
self._find_original_strings()
self._find_float_const()
self._match_imports()
self._match_exports()
self._match_thunks()
@ -249,6 +250,18 @@ def _find_original_strings(self):
self._db.match_string(addr, string)
def _find_float_const(self):
    """Add floating point constants in each binary to the database.
    We are not matching anything right now because these values are not
    deduped like strings."""
    database = self._db

    # Each constant is stored under the string form of its decoded value.
    for vaddr, nbytes, value in self.orig_bin.find_float_consts():
        database.set_orig_symbol(vaddr, SymbolType.FLOAT, str(value), nbytes)

    for vaddr, nbytes, value in self.recomp_bin.find_float_consts():
        database.set_recomp_symbol(
            vaddr, SymbolType.FLOAT, str(value), None, nbytes
        )
def _match_imports(self):
"""We can match imported functions based on the DLL name and
function symbol name."""

View file

@ -84,6 +84,23 @@ def __init__(self):
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
def set_orig_symbol(
    self,
    addr: int,
    compare_type: Optional[SymbolType],
    name: Optional[str],
    size: Optional[int],
):
    """Insert a row into `symbols` keyed by an address in the original binary.

    Nothing is inserted when `self._orig_used(addr)` reports that the
    address is already taken. `compare_type` is stored as its enum value,
    or as NULL when it is None.
    """
    # Ignore collisions here.
    if self._orig_used(addr):
        return

    compare_value = None if compare_type is None else compare_type.value
    row = (addr, compare_value, name, size)
    self._db.execute(
        "INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
        row,
    )
def set_recomp_symbol(
self,
addr: int,

View file

@ -10,3 +10,4 @@ class SymbolType(Enum):
POINTER = 3
STRING = 4
VTABLE = 5
FLOAT = 6

View file

@ -189,6 +189,7 @@ def substitute_1234(addr: int, _: bool) -> Optional[str]:
assert op_str == "0x5555"
@pytest.mark.skip(reason="changed implementation")
def test_float_replacement():
"""Floating point constants often appear as pointers to data.
A good example is ViewROI::IntrinsicImportance and the subclass override
@ -208,6 +209,7 @@ def bin_lookup(addr: int, _: int) -> Optional[bytes]:
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
@pytest.mark.skip(reason="changed implementation")
def test_float_variable():
"""If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler."""

View file

@ -10,6 +10,7 @@
from typing import Iterator, List, Optional, Tuple
from collections import namedtuple
from isledecomp import Bin as IsleBin
from isledecomp.bin import InvalidVirtualAddressError
from isledecomp.cvdump import Cvdump
from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType
@ -87,7 +88,7 @@ def print_sections(sections):
print()
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta"]
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta", "flo"]
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
@ -456,7 +457,16 @@ def to_roadmap_row(match):
module_name,
)
results = list(map(to_roadmap_row, engine.get_all()))
def roadmap_row_generator(matches):
    """Yield a roadmap row for each match, skipping any match for which
    `to_roadmap_row` raises InvalidVirtualAddressError."""
    for match in matches:
        # Only the conversion is guarded, so an exception raised at the
        # yield point by the consumer is not silently swallowed.
        try:
            row = to_roadmap_row(match)
        except InvalidVirtualAddressError:
            # This is here to work around the fact that we have RVA
            # values (i.e. not real virtual addrs) in our compare db.
            continue

        yield row
results = list(roadmap_row_generator(engine.get_all()))
if args.order is not None:
suggest_order(results, module_map, args.order)