mirror of
https://github.com/isledecomp/isle.git
synced 2024-11-22 15:48:09 -05:00
Read floating point constants up front (#868)
* Read floating point constants before sanitize * Fix roadmap
This commit is contained in:
parent
7c6c68d6f9
commit
e7670f9a81
7 changed files with 97 additions and 47 deletions
|
@ -2,7 +2,7 @@
|
||||||
import struct
|
import struct
|
||||||
import bisect
|
import bisect
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import List, Optional, Tuple
|
from typing import Iterator, List, Optional, Tuple
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
|
@ -77,6 +77,18 @@ def match_name(self, name: str) -> bool:
|
||||||
def contains_vaddr(self, vaddr: int) -> bool:
|
def contains_vaddr(self, vaddr: int) -> bool:
|
||||||
return self.virtual_address <= vaddr < self.virtual_address + self.extent
|
return self.virtual_address <= vaddr < self.virtual_address + self.extent
|
||||||
|
|
||||||
|
def read_virtual(self, vaddr: int, size: int) -> memoryview:
|
||||||
|
ofs = vaddr - self.virtual_address
|
||||||
|
|
||||||
|
# Negative index will read from the end, which we don't want
|
||||||
|
if ofs < 0:
|
||||||
|
raise InvalidVirtualAddressError
|
||||||
|
|
||||||
|
try:
|
||||||
|
return self.view[ofs : ofs + size]
|
||||||
|
except IndexError as ex:
|
||||||
|
raise InvalidVirtualAddressError from ex
|
||||||
|
|
||||||
def addr_is_uninitialized(self, vaddr: int) -> bool:
|
def addr_is_uninitialized(self, vaddr: int) -> bool:
|
||||||
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
|
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
|
||||||
the characteristics field so instead we determine it this way."""
|
the characteristics field so instead we determine it this way."""
|
||||||
|
@ -109,6 +121,7 @@ def __init__(self, filename: str, find_str: bool = False) -> None:
|
||||||
self._section_vaddr: List[int] = []
|
self._section_vaddr: List[int] = []
|
||||||
self.find_str = find_str
|
self.find_str = find_str
|
||||||
self._potential_strings = {}
|
self._potential_strings = {}
|
||||||
|
self._relocations = set()
|
||||||
self._relocated_addrs = set()
|
self._relocated_addrs = set()
|
||||||
self.imports = []
|
self.imports = []
|
||||||
self.thunks = []
|
self.thunks = []
|
||||||
|
@ -279,11 +292,49 @@ def _populate_relocations(self):
|
||||||
# We are now interested in the relocated addresses themselves. Seek to the
|
# We are now interested in the relocated addresses themselves. Seek to the
|
||||||
# address where there is a relocation, then read the four bytes into our set.
|
# address where there is a relocation, then read the four bytes into our set.
|
||||||
reloc_addrs.sort()
|
reloc_addrs.sort()
|
||||||
|
self._relocations = set(reloc_addrs)
|
||||||
|
|
||||||
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
|
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
|
||||||
section = self.get_section_by_index(section_id)
|
section = self.get_section_by_index(section_id)
|
||||||
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
|
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
|
||||||
self._relocated_addrs.add(relocated_addr)
|
self._relocated_addrs.add(relocated_addr)
|
||||||
|
|
||||||
|
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
|
||||||
|
"""Floating point instructions that refer to a memory address can
|
||||||
|
point to constant values. Search the code sections to find FP
|
||||||
|
instructions and check whether the pointer address refers to
|
||||||
|
read-only data."""
|
||||||
|
|
||||||
|
# TODO: Should check any section that has code, not just .text
|
||||||
|
text = self.get_section_by_name(".text")
|
||||||
|
rdata = self.get_section_by_name(".rdata")
|
||||||
|
|
||||||
|
# These are the addresses where a relocation occurs.
|
||||||
|
# Meaning: it points to an absolute address of something
|
||||||
|
for addr in self._relocations:
|
||||||
|
if not text.contains_vaddr(addr):
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Read the two bytes before the relocated address.
|
||||||
|
# We will check against possible float opcodes
|
||||||
|
raw = text.read_virtual(addr - 2, 6)
|
||||||
|
(opcode, opcode_ext, const_addr) = struct.unpack("<BBL", raw)
|
||||||
|
|
||||||
|
# Skip right away if this is not const data
|
||||||
|
if not rdata.contains_vaddr(const_addr):
|
||||||
|
continue
|
||||||
|
|
||||||
|
if opcode_ext in (0x5, 0xD, 0x15, 0x1D, 0x25, 0x2D, 0x35, 0x3D):
|
||||||
|
if opcode in (0xD8, 0xD9):
|
||||||
|
# dword ptr -- single precision
|
||||||
|
(float_value,) = struct.unpack("<f", self.read(const_addr, 4))
|
||||||
|
yield (const_addr, 4, float_value)
|
||||||
|
|
||||||
|
elif opcode in (0xDC, 0xDD):
|
||||||
|
# qword ptr -- double precision
|
||||||
|
(float_value,) = struct.unpack("<d", self.read(const_addr, 8))
|
||||||
|
yield (const_addr, 8, float_value)
|
||||||
|
|
||||||
def _populate_imports(self):
|
def _populate_imports(self):
|
||||||
"""Parse .idata to find imported DLLs and their functions."""
|
"""Parse .idata to find imported DLLs and their functions."""
|
||||||
idata_ofs = self.get_section_offset_by_name(".idata")
|
idata_ofs = self.get_section_offset_by_name(".idata")
|
||||||
|
|
|
@ -35,16 +35,6 @@ def from_hex(string: str) -> Optional[int]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_float(b: bytes) -> Optional[float]:
|
|
||||||
if len(b) == 4:
|
|
||||||
return struct.unpack("<f", b)[0]
|
|
||||||
|
|
||||||
if len(b) == 8:
|
|
||||||
return struct.unpack("<d", b)[0]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_to_dword(b: bytes) -> Optional[int]:
|
def bytes_to_dword(b: bytes) -> Optional[int]:
|
||||||
if len(b) == 4:
|
if len(b) == 4:
|
||||||
return struct.unpack("<L", b)[0]
|
return struct.unpack("<L", b)[0]
|
||||||
|
@ -74,18 +64,6 @@ def is_relocated(self, addr: int) -> bool:
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def float_replace(self, addr: int, data_size: int) -> Optional[str]:
|
|
||||||
if callable(self.bin_lookup):
|
|
||||||
float_bytes = self.bin_lookup(addr, data_size)
|
|
||||||
if float_bytes is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
float_value = bytes_to_float(float_bytes)
|
|
||||||
if float_value is not None:
|
|
||||||
return f"{float_value} (FLOAT)"
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def lookup(
|
def lookup(
|
||||||
self, addr: int, use_cache: bool = True, exact: bool = False
|
self, addr: int, use_cache: bool = True, exact: bool = False
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
|
@ -165,25 +143,6 @@ def hex_replace_indirect(self, match: re.Match) -> str:
|
||||||
|
|
||||||
return match.group(0).replace(match.group(1), self.replace(value))
|
return match.group(0).replace(match.group(1), self.replace(value))
|
||||||
|
|
||||||
def hex_replace_float(self, match: re.Match) -> str:
|
|
||||||
"""Special case for replacements on float instructions.
|
|
||||||
If the pointer is a float constant, read it from the binary."""
|
|
||||||
value = int(match.group(1), 16)
|
|
||||||
|
|
||||||
# If we can find a variable name for this pointer, use it.
|
|
||||||
placeholder = self.lookup(value)
|
|
||||||
|
|
||||||
# Read what's under the pointer and show the decimal value.
|
|
||||||
if placeholder is None:
|
|
||||||
float_size = 8 if "qword" in match.string else 4
|
|
||||||
placeholder = self.float_replace(value, float_size)
|
|
||||||
|
|
||||||
# If we can't read the float, use a regular placeholder.
|
|
||||||
if placeholder is None:
|
|
||||||
placeholder = self.replace(value)
|
|
||||||
|
|
||||||
return match.group(0).replace(match.group(1), placeholder)
|
|
||||||
|
|
||||||
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
|
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
|
||||||
# For jumps or calls, if the entire op_str is a hex number, the value
|
# For jumps or calls, if the entire op_str is a hex number, the value
|
||||||
# is a relative offset.
|
# is a relative offset.
|
||||||
|
@ -224,9 +183,6 @@ def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
|
||||||
if inst.mnemonic == "call":
|
if inst.mnemonic == "call":
|
||||||
# Special handling for absolute indirect CALL.
|
# Special handling for absolute indirect CALL.
|
||||||
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
|
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
|
||||||
elif inst.mnemonic.startswith("f"):
|
|
||||||
# If floating point instruction
|
|
||||||
op_str = ptr_replace_regex.sub(self.hex_replace_float, inst.op_str)
|
|
||||||
else:
|
else:
|
||||||
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
|
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
|
||||||
|
|
||||||
|
|
|
@ -82,6 +82,7 @@ def __init__(
|
||||||
self._load_cvdump()
|
self._load_cvdump()
|
||||||
self._load_markers()
|
self._load_markers()
|
||||||
self._find_original_strings()
|
self._find_original_strings()
|
||||||
|
self._find_float_const()
|
||||||
self._match_imports()
|
self._match_imports()
|
||||||
self._match_exports()
|
self._match_exports()
|
||||||
self._match_thunks()
|
self._match_thunks()
|
||||||
|
@ -249,6 +250,18 @@ def _find_original_strings(self):
|
||||||
|
|
||||||
self._db.match_string(addr, string)
|
self._db.match_string(addr, string)
|
||||||
|
|
||||||
|
def _find_float_const(self):
|
||||||
|
"""Add floating point constants in each binary to the database.
|
||||||
|
We are not matching anything right now because these values are not
|
||||||
|
deduped like strings."""
|
||||||
|
for addr, size, float_value in self.orig_bin.find_float_consts():
|
||||||
|
self._db.set_orig_symbol(addr, SymbolType.FLOAT, str(float_value), size)
|
||||||
|
|
||||||
|
for addr, size, float_value in self.recomp_bin.find_float_consts():
|
||||||
|
self._db.set_recomp_symbol(
|
||||||
|
addr, SymbolType.FLOAT, str(float_value), None, size
|
||||||
|
)
|
||||||
|
|
||||||
def _match_imports(self):
|
def _match_imports(self):
|
||||||
"""We can match imported functions based on the DLL name and
|
"""We can match imported functions based on the DLL name and
|
||||||
function symbol name."""
|
function symbol name."""
|
||||||
|
|
|
@ -84,6 +84,23 @@ def __init__(self):
|
||||||
self._db = sqlite3.connect(":memory:")
|
self._db = sqlite3.connect(":memory:")
|
||||||
self._db.executescript(_SETUP_SQL)
|
self._db.executescript(_SETUP_SQL)
|
||||||
|
|
||||||
|
def set_orig_symbol(
|
||||||
|
self,
|
||||||
|
addr: int,
|
||||||
|
compare_type: Optional[SymbolType],
|
||||||
|
name: Optional[str],
|
||||||
|
size: Optional[int],
|
||||||
|
):
|
||||||
|
# Ignore collisions here.
|
||||||
|
if self._orig_used(addr):
|
||||||
|
return
|
||||||
|
|
||||||
|
compare_value = compare_type.value if compare_type is not None else None
|
||||||
|
self._db.execute(
|
||||||
|
"INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
|
||||||
|
(addr, compare_value, name, size),
|
||||||
|
)
|
||||||
|
|
||||||
def set_recomp_symbol(
|
def set_recomp_symbol(
|
||||||
self,
|
self,
|
||||||
addr: int,
|
addr: int,
|
||||||
|
|
|
@ -10,3 +10,4 @@ class SymbolType(Enum):
|
||||||
POINTER = 3
|
POINTER = 3
|
||||||
STRING = 4
|
STRING = 4
|
||||||
VTABLE = 5
|
VTABLE = 5
|
||||||
|
FLOAT = 6
|
||||||
|
|
|
@ -189,6 +189,7 @@ def substitute_1234(addr: int, _: bool) -> Optional[str]:
|
||||||
assert op_str == "0x5555"
|
assert op_str == "0x5555"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="changed implementation")
|
||||||
def test_float_replacement():
|
def test_float_replacement():
|
||||||
"""Floating point constants often appear as pointers to data.
|
"""Floating point constants often appear as pointers to data.
|
||||||
A good example is ViewROI::IntrinsicImportance and the subclass override
|
A good example is ViewROI::IntrinsicImportance and the subclass override
|
||||||
|
@ -208,6 +209,7 @@ def bin_lookup(addr: int, _: int) -> Optional[bytes]:
|
||||||
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
|
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="changed implementation")
|
||||||
def test_float_variable():
|
def test_float_variable():
|
||||||
"""If there is a variable at the address referenced by a float instruction,
|
"""If there is a variable at the address referenced by a float instruction,
|
||||||
use the name instead of calling into the float replacement handler."""
|
use the name instead of calling into the float replacement handler."""
|
||||||
|
|
|
@ -10,6 +10,7 @@
|
||||||
from typing import Iterator, List, Optional, Tuple
|
from typing import Iterator, List, Optional, Tuple
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from isledecomp import Bin as IsleBin
|
from isledecomp import Bin as IsleBin
|
||||||
|
from isledecomp.bin import InvalidVirtualAddressError
|
||||||
from isledecomp.cvdump import Cvdump
|
from isledecomp.cvdump import Cvdump
|
||||||
from isledecomp.compare import Compare as IsleCompare
|
from isledecomp.compare import Compare as IsleCompare
|
||||||
from isledecomp.types import SymbolType
|
from isledecomp.types import SymbolType
|
||||||
|
@ -87,7 +88,7 @@ def print_sections(sections):
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta"]
|
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta", "flo"]
|
||||||
|
|
||||||
|
|
||||||
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
|
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
|
||||||
|
@ -456,7 +457,16 @@ def to_roadmap_row(match):
|
||||||
module_name,
|
module_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
results = list(map(to_roadmap_row, engine.get_all()))
|
def roadmap_row_generator(matches):
|
||||||
|
for match in matches:
|
||||||
|
try:
|
||||||
|
yield to_roadmap_row(match)
|
||||||
|
except InvalidVirtualAddressError:
|
||||||
|
# This is here to work around the fact that we have RVA
|
||||||
|
# values (i.e. not real virtual addrs) in our compare db.
|
||||||
|
pass
|
||||||
|
|
||||||
|
results = list(roadmap_row_generator(engine.get_all()))
|
||||||
|
|
||||||
if args.order is not None:
|
if args.order is not None:
|
||||||
suggest_order(results, module_map, args.order)
|
suggest_order(results, module_map, args.order)
|
||||||
|
|
Loading…
Reference in a new issue