mirror of
https://github.com/isledecomp/isle.git
synced 2024-11-22 15:48:09 -05:00
Read floating point constants up front (#868)
* Read floating point constants before sanitize * Fix roadmap
This commit is contained in:
parent
7c6c68d6f9
commit
e7670f9a81
7 changed files with 97 additions and 47 deletions
|
@ -2,7 +2,7 @@
|
|||
import struct
|
||||
import bisect
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Iterator, List, Optional, Tuple
|
||||
from dataclasses import dataclass
|
||||
from collections import namedtuple
|
||||
|
||||
|
@ -77,6 +77,18 @@ def match_name(self, name: str) -> bool:
|
|||
def contains_vaddr(self, vaddr: int) -> bool:
    """Report whether `vaddr` falls inside this section's address range.

    The range is half-open: it begins at `virtual_address` and spans
    `extent` bytes, so the address one past the end is not contained.
    """
    first = self.virtual_address
    past_end = first + self.extent
    return first <= vaddr < past_end
|
||||
|
||||
def read_virtual(self, vaddr: int, size: int) -> memoryview:
    """Return a view of exactly `size` bytes starting at virtual address `vaddr`.

    Args:
        vaddr: Virtual address of the first byte to read.
        size: Number of bytes to read.

    Raises:
        InvalidVirtualAddressError: If the requested range starts before the
            section, has a negative size, or is not fully covered by the
            data held in `self.view`.
    """
    ofs = vaddr - self.virtual_address
    end = ofs + size

    # A negative index would read from the end of the view, which we don't
    # want. Slicing past the end does not raise IndexError either: it just
    # returns a shorter view. Both cases must be checked explicitly.
    if ofs < 0 or size < 0 or end > len(self.view):
        raise InvalidVirtualAddressError

    return self.view[ofs:end]
|
||||
|
||||
def addr_is_uninitialized(self, vaddr: int) -> bool:
|
||||
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
|
||||
the characteristics field so instead we determine it this way."""
|
||||
|
@ -109,6 +121,7 @@ def __init__(self, filename: str, find_str: bool = False) -> None:
|
|||
self._section_vaddr: List[int] = []
|
||||
self.find_str = find_str
|
||||
self._potential_strings = {}
|
||||
self._relocations = set()
|
||||
self._relocated_addrs = set()
|
||||
self.imports = []
|
||||
self.thunks = []
|
||||
|
@ -279,11 +292,49 @@ def _populate_relocations(self):
|
|||
# We are now interested in the relocated addresses themselves. Seek to the
|
||||
# address where there is a relocation, then read the four bytes into our set.
|
||||
reloc_addrs.sort()
|
||||
self._relocations = set(reloc_addrs)
|
||||
|
||||
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
|
||||
section = self.get_section_by_index(section_id)
|
||||
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
|
||||
self._relocated_addrs.add(relocated_addr)
|
||||
|
||||
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
    """Find floating point constants referenced from code.

    Floating point instructions that refer to a memory address can
    point to constant values. Search the code section to find FP
    instructions and check whether the pointer address refers to
    read-only data.

    Yields:
        Tuples of (constant address, size in bytes, decoded value). The
        same constant is yielded once per instruction that refers to it.
    """

    # TODO: Should check any section that has code, not just .text
    text = self.get_section_by_name(".text")
    rdata = self.get_section_by_name(".rdata")

    # Opcode byte -> (struct format, operand size in bytes).
    float_formats = {
        0xD8: ("<f", 4),  # dword ptr -- single precision
        0xD9: ("<f", 4),
        0xDC: ("<d", 8),  # qword ptr -- double precision
        0xDD: ("<d", 8),
    }

    # These are the addresses where a relocation occurs.
    # Meaning: it points to an absolute address of something
    for addr in self._relocations:
        if not text.contains_vaddr(addr):
            continue

        # We need the two bytes before the relocated address plus the four
        # byte pointer itself. A relocation too close to either edge of the
        # section cannot belong to a complete instruction, so skip it
        # instead of letting the read fail and abort the whole scan.
        if not (text.contains_vaddr(addr - 2) and text.contains_vaddr(addr + 3)):
            continue

        raw = text.read_virtual(addr - 2, 6)
        if len(raw) != 6:
            # The section data can be shorter than its address range.
            continue

        opcode, opcode_ext, const_addr = struct.unpack("<BBL", raw)

        # Skip right away if this is not const data
        if not rdata.contains_vaddr(const_addr):
            continue

        # Accept 0x05, 0x0D, 0x15, ... 0x3D: the top two bits are clear and
        # the low three bits are 101 (a bare 32-bit address operand).
        # The middle three bits are deliberately not checked.
        # NOTE(review): some of those encodings for 0xD9/0xDD may be
        # non-float operations (control word / environment) -- confirm.
        if opcode_ext & 0xC7 != 0x05:
            continue

        fmt_and_size = float_formats.get(opcode)
        if fmt_and_size is None:
            continue

        fmt, size = fmt_and_size
        # NOTE(review): assumes self.read() returns exactly `size` bytes
        # for an address inside .rdata -- confirm.
        (float_value,) = struct.unpack(fmt, self.read(const_addr, size))
        yield (const_addr, size, float_value)
|
||||
|
||||
def _populate_imports(self):
|
||||
"""Parse .idata to find imported DLLs and their functions."""
|
||||
idata_ofs = self.get_section_offset_by_name(".idata")
|
||||
|
|
|
@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
|
|||
return None
|
||||
|
||||
|
||||
def bytes_to_float(b: bytes) -> Optional[float]:
    """Decode `b` as a little-endian float, choosing precision by length.

    Four bytes are unpacked as single precision and eight bytes as double
    precision. Any other length returns None.
    """
    fmt = {4: "<f", 8: "<d"}.get(len(b))
    if fmt is None:
        return None

    (value,) = struct.unpack(fmt, b)
    return value
|
||||
|
||||
|
||||
def bytes_to_dword(b: bytes) -> Optional[int]:
|
||||
if len(b) == 4:
|
||||
return struct.unpack("<L", b)[0]
|
||||
|
@ -74,18 +64,6 @@ def is_relocated(self, addr: int) -> bool:
|
|||
|
||||
return False
|
||||
|
||||
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
    """Build a placeholder showing the float value stored at `addr`.

    Uses `self.bin_lookup` to fetch `data_size` bytes and decodes them with
    `bytes_to_float`. Returns None when there is no callable lookup, the
    lookup yields nothing, or the bytes cannot be decoded as a float.
    """
    if not callable(self.bin_lookup):
        return None

    raw = self.bin_lookup(addr, data_size)
    if raw is None:
        return None

    value = bytes_to_float(raw)
    return None if value is None else f"{value} (FLOAT)"
|
||||
|
||||
def lookup(
|
||||
self, addr: int, use_cache: bool = True, exact: bool = False
|
||||
) -> Optional[str]:
|
||||
|
@ -165,25 +143,6 @@ def hex_replace_indirect(self, match: re.Match) -> str:
|
|||
|
||||
return match.group(0).replace(match.group(1), self.replace(value))
|
||||
|
||||
def hex_replace_float(self, match: re.Match) -> str:
    """Substitute the pointer in a float instruction's operand.

    Preference order for the replacement text:
      1. the name returned by `self.lookup` for the pointer, if any;
      2. the float value read via `self.float_replace` (8 bytes when the
         operand string mentions "qword", otherwise 4);
      3. the regular placeholder from `self.replace`.
    """
    hex_text = match.group(1)
    pointer = int(hex_text, 16)

    # A known variable name for this pointer takes priority.
    substitute = self.lookup(pointer)

    if substitute is None:
        # No name: show the decimal value stored under the pointer.
        width = 8 if "qword" in match.string else 4
        substitute = self.float_replace(pointer, width)

    if substitute is None:
        # The float could not be read, so use a regular placeholder.
        substitute = self.replace(pointer)

    return match.group(0).replace(hex_text, substitute)
|
||||
|
||||
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
|
||||
# For jumps or calls, if the entire op_str is a hex number, the value
|
||||
# is a relative offset.
|
||||
|
@ -224,9 +183,6 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
|
|||
if inst.mnemonic == "call":
|
||||
# Special handling for absolute indirect CALL.
|
||||
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
|
||||
elif inst.mnemonic.startswith("f"):
|
||||
# If floating point instruction
|
||||
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
|
||||
else:
|
||||
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
|
||||
|
||||
|
|
|
@ -82,6 +82,7 @@ def __init__(
|
|||
self._load_cvdump()
|
||||
self._load_markers()
|
||||
self._find_original_strings()
|
||||
self._find_float_const()
|
||||
self._match_imports()
|
||||
self._match_exports()
|
||||
self._match_thunks()
|
||||
|
@ -249,6 +250,18 @@ def _find_original_strings(self):
|
|||
|
||||
self._db.match_string(addr, string)
|
||||
|
||||
def _find_float_const(self):
    """Record the floating point constants of both binaries in the database.

    Each constant is stored as a FLOAT symbol named after the string form of
    its value. Nothing is matched between the two binaries here because
    these values are not deduped like strings.
    """
    database = self._db
    float_type = SymbolType.FLOAT

    for const_addr, const_size, value in self.orig_bin.find_float_consts():
        database.set_orig_symbol(const_addr, float_type, str(value), const_size)

    for const_addr, const_size, value in self.recomp_bin.find_float_consts():
        database.set_recomp_symbol(
            const_addr, float_type, str(value), None, const_size
        )
|
||||
|
||||
def _match_imports(self):
|
||||
"""We can match imported functions based on the DLL name and
|
||||
function symbol name."""
|
||||
|
|
|
@ -84,6 +84,23 @@ def __init__(self):
|
|||
self._db = sqlite3.connect(":memory:")
|
||||
self._db.executescript(_SETUP_SQL)
|
||||
|
||||
def set_orig_symbol(
    self,
    addr: int,
    compare_type: Optional[SymbolType],
    name: Optional[str],
    size: Optional[int],
):
    """Insert a `symbols` row for an address in the original binary.

    Does nothing when `self._orig_used(addr)` reports the address is
    already taken. `compare_type` is stored by its `.value`, or as NULL
    when it is None.
    """
    # Ignore collisions here.
    if self._orig_used(addr):
        return

    compare_value = None
    if compare_type is not None:
        compare_value = compare_type.value

    row = (addr, compare_value, name, size)
    self._db.execute(
        "INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
        row,
    )
|
||||
|
||||
def set_recomp_symbol(
|
||||
self,
|
||||
addr: int,
|
||||
|
|
|
@ -10,3 +10,4 @@ class SymbolType(Enum):
|
|||
POINTER = 3
|
||||
STRING = 4
|
||||
VTABLE = 5
|
||||
FLOAT = 6
|
||||
|
|
|
@ -189,6 +189,7 @@ def substitute_1234(addr: int, _: bool) -> Optional[str]:
|
|||
assert op_str == "0x5555"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="changed implementation")
|
||||
def test_float_replacement():
|
||||
"""Floating point constants often appear as pointers to data.
|
||||
A good example is ViewROI::IntrinsicImportance and the subclass override
|
||||
|
@ -208,6 +209,7 @@ def bin_lookup(addr: int, _: int) -> Optional[bytes]:
|
|||
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="changed implementation")
|
||||
def test_float_variable():
|
||||
"""If there is a variable at the address referenced by a float instruction,
|
||||
use the name instead of calling into the float replacement handler."""
|
||||
|
|
|
@ -10,6 +10,7 @@
|
|||
from typing import Iterator, List, Optional, Tuple
|
||||
from collections import namedtuple
|
||||
from isledecomp import Bin as IsleBin
|
||||
from isledecomp.bin import InvalidVirtualAddressError
|
||||
from isledecomp.cvdump import Cvdump
|
||||
from isledecomp.compare import Compare as IsleCompare
|
||||
from isledecomp.types import SymbolType
|
||||
|
@ -87,7 +88,7 @@ def print_sections(sections):
|
|||
print()
|
||||
|
||||
|
||||
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta"]
|
||||
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta", "flo"]
|
||||
|
||||
|
||||
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
|
||||
|
@ -456,7 +457,16 @@ def to_roadmap_row(match):
|
|||
module_name,
|
||||
)
|
||||
|
||||
results = list(map(to_roadmap_row, engine.get_all()))
|
||||
def roadmap_row_generator(matches):
    """Yield a roadmap row for each match, dropping unreadable ones.

    Matches for which `to_roadmap_row` raises InvalidVirtualAddressError
    are silently skipped.
    """
    for entry in matches:
        try:
            row = to_roadmap_row(entry)
        except InvalidVirtualAddressError:
            # This is here to work around the fact that we have RVA
            # values (i.e. not real virtual addrs) in our compare db.
            continue

        yield row
|
||||
|
||||
results = list(roadmap_row_generator(engine.get_all()))
|
||||
|
||||
if args.order is not None:
|
||||
suggest_order(results, module_map, args.order)
|
||||
|
|
Loading…
Reference in a new issue