Reccmp comparison engine refactor (#405)

* Reccmp comparison engine refactor

* Remove redundant references to 'entry' symbol
This commit is contained in:
MS 2024-01-04 18:12:55 -05:00 committed by GitHub
parent eeb980fa0f
commit ce68a7b1f4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 987 additions and 279 deletions

View file

@ -40,14 +40,6 @@
// LIBRARY: LEGO1 0x1008b640
// _rand
// entry
// LIBRARY: ISLE 0x4082e0
// _WinMainCRTStartup
// entry
// LIBRARY: LEGO1 0x1008c860
// __DllMainCRTStartup@12
// LIBRARY: ISLE 0x409110
// __mtinit

View file

@ -1,5 +1,4 @@
from .bin import *
from .dir import *
from .parser import *
from .syminfo import *
from .utils import *

View file

@ -1,3 +1,4 @@
import logging
import struct
from typing import List, Optional
from dataclasses import dataclass
@ -51,12 +52,17 @@ class ImageSectionHeader:
number_of_line_numbers: int
characteristics: int
@property
def extent(self):
"""Get the highest possible offset of this section"""
return max(self.size_of_raw_data, self.virtual_size)
def match_name(self, name: str) -> bool:
return self.name == struct.pack("8s", name.encode("ascii"))
def contains_vaddr(self, vaddr: int) -> bool:
ofs = vaddr - self.virtual_address
return 0 <= ofs < max(self.size_of_raw_data, self.virtual_size)
return 0 <= ofs < self.extent
def addr_is_uninitialized(self, vaddr: int) -> bool:
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
@ -71,25 +77,29 @@ def addr_is_uninitialized(self, vaddr: int) -> bool:
)
logger = logging.getLogger(__name__)
class Bin:
"""Parses a PE format EXE and allows reading data from a virtual address.
Reference: https://learn.microsoft.com/en-us/windows/win32/debug/pe-format"""
# pylint: disable=too-many-instance-attributes
def __init__(self, filename: str, logger=None) -> None:
self.logger = logger
self._debuglog(f'Parsing headers of "{filename}"... ')
def __init__(self, filename: str, find_str: bool = False) -> None:
logger.debug('Parsing headers of "%s"... ', filename)
self.filename = filename
self.file = None
self.imagebase = None
self.entry = None
self.sections: List[ImageSectionHeader] = []
self.last_section = None
self.find_str = find_str
self._potential_strings = {}
self._relocated_addrs = set()
def __enter__(self):
self._debuglog(f"Bin {self.filename} Enter")
logger.debug("Bin %s Enter", self.filename)
self.file = open(self.filename, "rb")
(mz_str,) = struct.unpack("2s", self.file.read(2))
@ -123,28 +133,71 @@ def __enter__(self):
self._populate_relocations()
# This is a (semi) expensive lookup that is not necesssary in every case.
# We can find strings in the original if we have coverage using STRING markers.
# For the recomp, we can find strings using the PDB.
if self.find_str:
self._prepare_string_search()
text_section = self._get_section_by_name(".text")
self.last_section = text_section
self._debuglog("... Parsing finished")
logger.debug("... Parsing finished")
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
self._debuglog(f"Bin {self.filename} Exit")
logger.debug("Bin %s Exit", self.filename)
if self.file:
self.file.close()
def _debuglog(self, msg):
"""Write to the logger, if present"""
if self.logger is not None:
self.logger.debug(msg)
def get_relocated_addresses(self) -> List[int]:
return sorted(self._relocated_addrs)
def find_string(self, target: str) -> Optional[int]:
# Pad with null terminator to make sure we don't
# match on a subset of the full string
if not target.endswith(b"\x00"):
target += b"\x00"
c = target[0]
if c not in self._potential_strings:
return None
for addr in self._potential_strings[c]:
if target == self.read(addr, len(target)):
return addr
return None
def is_relocated_addr(self, vaddr) -> bool:
return vaddr in self._relocated_addrs
def _prepare_string_search(self):
"""We are intersted in deduplicated string constants found in the
.rdata and .data sections. For each relocated address in these sections,
read the first byte and save the address if that byte is an ASCII character.
When we search for an arbitrary string later, we can narrow down the list
of potential locations by a lot."""
def is_ascii(b):
return b" " <= b < b"\x7f"
sect_data = self._get_section_by_name(".data")
sect_rdata = self._get_section_by_name(".rdata")
potentials = filter(
lambda a: sect_data.contains_vaddr(a) or sect_rdata.contains_vaddr(a),
self.get_relocated_addresses(),
)
for addr in potentials:
c = self.read(addr, 1)
if c is not None and is_ascii(c):
k = ord(c)
if k not in self._potential_strings:
self._potential_strings[k] = set()
self._potential_strings[k].add(addr)
def _populate_relocations(self):
"""The relocation table in .reloc gives each virtual address where the next four
bytes are, itself, another virtual address. During loading, these values will be
@ -212,6 +265,9 @@ def _get_section_by_name(self, name: str):
return section
def get_section_extent_by_index(self, index: int) -> int:
return self.sections[index - 1].extent
def get_section_offset_by_index(self, index: int) -> int:
"""The symbols output from cvdump gives addresses in this format: AAAA.BBBBBBBB
where A is the index (1-based) into the section table and B is the local offset.
@ -242,6 +298,15 @@ def get_raw_addr(self, vaddr: int) -> int:
+ self.last_section.pointer_to_raw_data
)
def is_valid_section(self, section: int) -> bool:
"""The PDB will refer to sections that are not listed in the headers
and so should ignore these references."""
try:
_ = self.sections[section - 1]
return True
except IndexError:
return False
def is_valid_vaddr(self, vaddr: int) -> bool:
"""Does this virtual address point to anything in the exe?"""
section = next(

View file

@ -0,0 +1 @@
from .core import Compare

View file

@ -0,0 +1,149 @@
import os
import logging
from typing import List, Optional
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from .db import CompareDb, MatchInfo
from .lines import LinesDb
logger = logging.getLogger(__name__)
class Compare:
# pylint: disable=too-many-instance-attributes
def __init__(self, orig_bin, recomp_bin, pdb_file, code_dir):
self.orig_bin = orig_bin
self.recomp_bin = recomp_bin
self.pdb_file = pdb_file
self.code_dir = code_dir
self._lines_db = LinesDb(code_dir)
self._db = CompareDb()
self._load_cvdump()
self._load_markers()
self._find_original_strings()
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
cv = (
Cvdump(self.pdb_file)
.lines()
.globals()
.publics()
.symbols()
.section_contributions()
.run()
)
res = CvdumpAnalysis(cv)
for sym in res.nodes:
# The PDB might contain sections that do not line up with the
# actual binary. The symbol "__except_list" is one example.
# In these cases, just skip this symbol and move on because
# we can't do much with it.
if not self.recomp_bin.is_valid_section(sym.section):
continue
addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)
# If this symbol is the final one in its section, we were not able to
# estimate its size because we didn't have the total size of that section.
# We can get this estimate now and assume that the final symbol occupies
# the remainder of the section.
if sym.estimated_size is None:
sym.estimated_size = (
self.recomp_bin.get_section_extent_by_index(sym.section)
- sym.offset
)
if sym.node_type == SymbolType.STRING:
string_info = demangle_string_const(sym.decorated_name)
# TODO: skip unicode for now. will need to handle these differently.
if string_info.is_utf16:
continue
raw = self.recomp_bin.read(addr, sym.size())
try:
sym.friendly_name = raw.decode("latin1")
except UnicodeDecodeError:
pass
self._db.set_recomp_symbol(addr, sym.node_type, sym.name(), sym.size())
for lineref in cv.lines:
addr = self.recomp_bin.get_abs_addr(lineref.section, lineref.offset)
self._lines_db.add_line(lineref.filename, lineref.line_no, addr)
# The _entry symbol is referenced in the PE header so we get this match for free.
self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)
def _load_markers(self):
# Guess at module name from PDB file name
# reccmp checks the original binary filename; we could use this too
(module, _) = os.path.splitext(os.path.basename(self.pdb_file))
codefiles = list(walk_source_dir(self.code_dir))
codebase = DecompCodebase(codefiles, module)
# Match lineref functions first because this is a guaranteed match.
# If we have two functions that share the same name, and one is
# a lineref, we can match the nameref correctly because the lineref
# was already removed from consideration.
for fun in codebase.iter_line_functions():
recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number)
if recomp_addr is not None:
self._db.set_function_pair(fun.offset, recomp_addr)
if fun.should_skip():
self._db.skip_compare(fun.offset)
for fun in codebase.iter_name_functions():
self._db.match_function(fun.offset, fun.name)
if fun.should_skip():
self._db.skip_compare(fun.offset)
for var in codebase.iter_variables():
self._db.match_variable(var.offset, var.name)
for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name)
def _find_original_strings(self):
"""Go to the original binary and look for the specified string constants
to find a match. This is a (relatively) expensive operation so we only
look at strings that we have not already matched via a STRING annotation."""
for string in self._db.get_unmatched_strings():
addr = self.orig_bin.find_string(string.encode("latin1"))
if addr is None:
escaped = repr(string)
logger.debug("Failed to find this string in the original: %s", escaped)
continue
self._db.match_string(addr, string)
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
"""i.e. verbose mode for reccmp"""
return self._db.get_one_function(addr)
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches(SymbolType.FUNCTION)
def compare_functions(self):
pass
def compare_variables(self):
pass
def compare_pointers(self):
pass
def compare_strings(self):
pass
def compare_vtables(self):
pass

View file

@ -0,0 +1,149 @@
"""Wrapper for database (here an in-memory sqlite database) that collects the
addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3
import logging
from collections import namedtuple
from typing import List, Optional
from isledecomp.types import SymbolType
_SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`;
CREATE TABLE `symbols` (
compare_type int,
orig_addr int,
recomp_addr int,
name text,
size int,
should_skip int default(FALSE)
);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (compare_type, name);
"""
MatchInfo = namedtuple("MatchInfo", "orig_addr, recomp_addr, size, name")
def matchinfo_factory(_, row):
return MatchInfo(*row)
logger = logging.getLogger(__name__)
class CompareDb:
def __init__(self):
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
def set_recomp_symbol(
self,
addr: int,
compare_type: Optional[SymbolType],
name: Optional[str],
size: Optional[int],
):
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (recomp_addr, compare_type, name, size) VALUES (?,?,?,?)",
(addr, compare_value, name, size),
)
def get_unmatched_strings(self) -> List[str]:
"""Return any strings not already identified by STRING markers."""
cur = self._db.execute(
"SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL",
(SymbolType.STRING.value,),
)
return [string for (string,) in cur.fetchall()]
def get_one_function(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name
FROM `symbols`
WHERE compare_type = ?
AND orig_addr = ?
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
(SymbolType.FUNCTION.value, addr),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_matches(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT orig_addr, recomp_addr, size, name
FROM `symbols`
WHERE compare_type = ?
AND orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
AND should_skip IS FALSE
ORDER BY orig_addr
""",
(compare_type.value,),
)
cur.row_factory = matchinfo_factory
return cur.fetchall()
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
cur = self._db.execute(
"UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
(orig, SymbolType.FUNCTION.value, recomp),
)
return cur.rowcount > 0
# TODO: Both ways required?
def skip_compare(self, orig: int):
self._db.execute(
"UPDATE `symbols` SET should_skip = TRUE WHERE orig_addr = ?", (orig,)
)
def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
# Update the compare_type here too since the marker tells us what we should do
logger.debug("Looking for %s %s", compare_type.name.lower(), name)
cur = self._db.execute(
"""UPDATE `symbols`
SET orig_addr = ?, compare_type = ?
WHERE name = ?
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type IS NULL)""",
(addr, compare_type.value, name, compare_type.value),
)
return cur.rowcount > 0
def match_function(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.FUNCTION, addr, name)
if not did_match:
logger.error("Failed to find function symbol with name: %s", name)
return did_match
def match_vtable(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.VTABLE, addr, name)
if not did_match:
logger.error("Failed to find vtable for class: %s", name)
return did_match
def match_variable(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(
SymbolType.POINTER, addr, name
)
if not did_match:
logger.error("Failed to find variable: %s", name)
return did_match
def match_string(self, addr: int, value: str) -> bool:
did_match = self._match_on(SymbolType.STRING, addr, value)
if not did_match:
escaped = repr(value)
logger.error("Failed to find string: %s", escaped)

View file

@ -0,0 +1,58 @@
"""Database used to match (filename, line_number) pairs
between FUNCTION markers and PDB analysis."""
import sqlite3
import logging
from typing import Optional
from pathlib import Path
from isledecomp.dir import PathResolver
_SETUP_SQL = """
DROP TABLE IF EXISTS `lineref`;
CREATE TABLE `lineref` (
path text not null,
filename text not null,
line int not null,
addr int not null
);
CREATE INDEX `file_line` ON `lineref` (filename, line);
"""
logger = logging.getLogger(__name__)
class LinesDb:
def __init__(self, code_dir) -> None:
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
self._path_resolver = PathResolver(code_dir)
def add_line(self, path: str, line_no: int, addr: int):
"""To be added from the LINES section of cvdump."""
sourcepath = self._path_resolver.resolve_cvdump(path)
filename = Path(sourcepath).name.lower()
self._db.execute(
"INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)",
(sourcepath, filename, line_no, addr),
)
def search_line(self, path: str, line_no: int) -> Optional[int]:
"""Using path and line number from FUNCTION marker,
get the address of this function in the recomp."""
filename = Path(path).name.lower()
cur = self._db.execute(
"SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?",
(filename, line_no),
)
for source_path, addr in cur.fetchall():
if Path(path).samefile(source_path):
return addr
logger.error(
"Failed to find function symbol with filename and line: %s:%d",
path,
line_no,
)
return None

View file

@ -1,2 +1,3 @@
from .analysis import CvdumpAnalysis
from .parser import CvdumpParser
from .runner import Cvdump

View file

@ -0,0 +1,184 @@
"""For collating the results from parsing cvdump.exe into a more directly useful format."""
from typing import List, Optional, Tuple
from isledecomp.types import SymbolType
from .parser import CvdumpParser
from .demangler import demangle_string_const, demangle_vtable
def data_type_info(type_name: str) -> Optional[Tuple[int, bool]]:
"""cvdump type aliases are listed here:
https://github.com/microsoft/microsoft-pdb/blob/master/include/cvinfo.h
For the given type, return tuple(size, is_pointer) if possible."""
# pylint: disable=too-many-return-statements
# TODO: refactor to be as simple as possble
# Ignore complex types. We can get the size of those from the TYPES section.
if not type_name.startswith("T"):
return None
# if 32-bit pointer
if type_name.startswith("T_32P"):
return (4, True)
if type_name.endswith("QUAD") or type_name.endswith("64"):
return (8, False)
if (
type_name.endswith("LONG")
or type_name.endswith("INT4")
or type_name.endswith("32")
):
return (4, False)
if type_name.endswith("SHORT") or type_name.endswith("WCHAR"):
return (2, False)
if "CHAR" in type_name:
return (1, False)
if type_name in ("T_NOTYPE", "T_VOID"):
return (0, False)
return None
class CvdumpNode:
# pylint: disable=too-many-instance-attributes
# These two are required and allow us to identify the symbol
section: int
offset: int
# aka the mangled name from the PUBLICS section
decorated_name: Optional[str] = None
# optional "nicer" name (e.g. of a function from SYMBOLS section)
friendly_name: Optional[str] = None
# To be determined by context after inserting data, unless the decorated
# name makes this obvious. (i.e. string constants or vtables)
# We choose not to assume that section 1 (probably ".text") contains only
# functions. Smacker functions are linked to their own section "_UNSTEXT"
node_type: Optional[SymbolType] = None
# Function size can be read from the LINES section so use this over any
# other value if we have it.
# TYPES section can tell us the size of structs and other complex types.
confirmed_size: Optional[int] = None
# Estimated by reading the distance between this symbol and the one that
# follows in the same section.
# If this is the last symbol in the section, we cannot estimate a size.
estimated_size: Optional[int] = None
# Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be
# accurate.
section_contribution: Optional[int] = None
def __init__(self, section: int, offset: int) -> None:
self.section = section
self.offset = offset
def set_decorated(self, name: str):
self.decorated_name = name
if self.decorated_name.startswith("??_7"):
self.node_type = SymbolType.VTABLE
self.friendly_name = demangle_vtable(self.decorated_name)
elif self.decorated_name.startswith("??_C@"):
self.node_type = SymbolType.STRING
(strlen, _) = demangle_string_const(self.decorated_name)
self.confirmed_size = strlen
elif not self.decorated_name.startswith("?") and "@" in self.decorated_name:
# C mangled symbol. The trailing at-sign with number tells the number of bytes
# in the parameter list for __stdcall, __fastcall, or __vectorcall
# For __cdecl it is more ambiguous and we would have to know which section we are in.
# https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170#FormatC
self.node_type = SymbolType.FUNCTION
def name(self) -> Optional[str]:
"""Prefer "friendly" name if we have it.
This is what we have been using to match functions."""
return self.friendly_name or self.decorated_name
def size(self) -> Optional[int]:
if self.confirmed_size is not None:
return self.confirmed_size
# Better to undershoot the size because we can identify a comparison gap easily
if self.estimated_size is not None and self.section_contribution is not None:
return min(self.estimated_size, self.section_contribution)
# Return whichever one we have, or neither
return self.estimated_size or self.section_contribution
class CvdumpAnalysis:
"""Collects the results from CvdumpParser into a list of nodes (i.e. symbols).
These can then be analyzed by a downstream tool."""
nodes = List[CvdumpNode]
def __init__(self, parser: CvdumpParser):
"""Read in as much information as we have from the parser.
The more sections we have, the better our information will be."""
node_dict = {}
# PUBLICS is our roadmap for everything that follows.
for pub in parser.publics:
key = (pub.section, pub.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].set_decorated(pub.name)
for sizeref in parser.sizerefs:
key = (sizeref.section, sizeref.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].section_contribution = sizeref.size
for glo in parser.globals:
key = (glo.section, glo.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].node_type = SymbolType.DATA
node_dict[key].friendly_name = glo.name
if (g_info := data_type_info(glo.type)) is not None:
(size, is_pointer) = g_info
node_dict[key].confirmed_size = size
if is_pointer:
node_dict[key].node_type = SymbolType.POINTER
for lin in parser.lines:
key = (lin.section, lin.offset)
# Here we only set if the section:offset already exists
# because our values include offsets inside of the function.
if key in node_dict:
node_dict[key].node_type = SymbolType.FUNCTION
for sym in parser.symbols:
key = (sym.section, sym.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
if sym.type == "S_GPROC32":
node_dict[key].friendly_name = sym.name
node_dict[key].confirmed_size = sym.size
node_dict[key].node_type = SymbolType.FUNCTION
self.nodes = [v for _, v in dict(sorted(node_dict.items())).items()]
self._estimate_size()
def _estimate_size(self):
"""Get the distance between one section:offset value and the next one
in the same section. This gives a rough estimate of the size of the symbol.
If we have information from SECTION CONTRIBUTIONS, take whichever one is
less to get the best approximate size."""
for i in range(len(self.nodes) - 1):
this_node = self.nodes[i]
next_node = self.nodes[i + 1]
# If they are in different sections, we can't compare them
if this_node.section != next_node.section:
continue
this_node.estimated_size = next_node.offset - this_node.offset

View file

@ -0,0 +1,74 @@
"""For demangling a subset of MSVC mangled symbols.
Some unofficial information about the mangling scheme is here:
https://en.wikiversity.org/wiki/Visual_C%2B%2B_name_mangling
"""
import re
from collections import namedtuple
class InvalidEncodedNumberError(Exception):
pass
_encoded_number_translate = str.maketrans("ABCDEFGHIJKLMNOP", "0123456789ABCDEF")
def parse_encoded_number(string: str) -> int:
# TODO: assert string ends in "@"?
if string.endswith("@"):
string = string[:-1]
try:
return int(string.translate(_encoded_number_translate), 16)
except ValueError as e:
raise InvalidEncodedNumberError(string) from e
string_const_regex = re.compile(
r"\?\?_C@\_(?P<is_utf16>[0-1])(?P<len>\d|[A-P]+@)(?P<hash>\w+)@(?P<value>.+)@"
)
StringConstInfo = namedtuple("StringConstInfo", "len is_utf16")
def demangle_string_const(symbol: str) -> StringConstInfo:
"""Don't bother to decode the string text from the symbol.
We can just read it from the binary once we have the length."""
match = string_const_regex.match(symbol)
if match is None:
# See below
return StringConstInfo(0, False)
try:
strlen = (
parse_encoded_number(match.group("len"))
if "@" in match.group("len")
else int(match.group("len"))
)
except (ValueError, InvalidEncodedNumberError):
# This would be an annoying error to fail on if we get a bad symbol.
# For now, just assume a zero length string because this will probably
# raise some eyebrows during the comparison.
strlen = 0
is_utf16 = match.group("is_utf16") == "1"
return StringConstInfo(len=strlen, is_utf16=is_utf16)
def demangle_vtable(symbol: str) -> str:
"""Get the class name referenced in the vtable symbol."""
# Seek ahead 4 chars to strip off "??_7" prefix
t = symbol[4:].split("@")
# "?$" indicates a template class
if t[0].startswith("?$"):
class_name = t[0][2:]
# PA = Pointer/reference
# V or U = class or struct
if t[1].startswith("PA"):
generic = f"{t[1][3:]} *"
else:
generic = t[1][1:]
return f"{class_name}<{generic}>"
return t[0]

View file

@ -1,5 +1,5 @@
import re
from typing import Iterable
from typing import Iterable, Tuple
from collections import namedtuple
# e.g. `*** PUBLICS`
@ -40,17 +40,28 @@
)
LinesEntry = namedtuple("LinesEntry", "filename line_no addr")
# User functions only
LinesEntry = namedtuple("LinesEntry", "filename line_no section offset")
# Strings, vtables, functions
# superset of everything else
# only place you can find the C symbols (library functions, smacker, etc)
PublicsEntry = namedtuple("PublicsEntry", "type section offset flags name")
# S_GPROC32 = functions
SymbolsEntry = namedtuple("SymbolsEntry", "type section offset size name")
# (Estimated) size of any symbol
SizeRefEntry = namedtuple("SizeRefEntry", "section offset size")
# global variables
GdataEntry = namedtuple("GdataEntry", "section offset type name")
class CvdumpParser:
def __init__(self) -> None:
self._section: str = ""
self._lines_filename: str = ""
self._lines_function: Tuple[str, int] = ("", 0)
self.lines = []
self.publics = []
@ -64,19 +75,24 @@ def _lines_section(self, line: str):
we are in."""
# Subheader indicates a new function and possibly a new code filename.
# Save the section here because it is not given on the lines that follow.
if (match := _lines_subsection_header.match(line)) is not None:
self._lines_filename = match.group(1)
self._lines_function = (
match.group("filename"),
int(match.group("section"), 16),
)
return
if (matches := _line_addr_pairs_findall.findall(line)) is not None:
for line_no, addr in matches:
self.lines.append(
LinesEntry(
filename=self._lines_filename,
line_no=int(line_no),
addr=int(addr, 16),
)
# Match any pairs as we find them
for line_no, offset in _line_addr_pairs_findall.findall(line):
self.lines.append(
LinesEntry(
filename=self._lines_function[0],
line_no=int(line_no),
section=self._lines_function[1],
offset=int(offset, 16),
)
)
def _publics_section(self, line: str):
"""Match each line from PUBLICS and pull out the symbol information.

View file

@ -1,2 +1,3 @@
from .codebase import DecompCodebase
from .parser import DecompParser
from .linter import DecompLinter

View file

@ -0,0 +1,44 @@
"""For aggregating decomp markers read from an entire directory and for a single module."""
from typing import Iterable, Iterator, List
from .parser import DecompParser
from .node import (
ParserSymbol,
ParserFunction,
ParserVtable,
ParserVariable,
)
class DecompCodebase:
def __init__(self, filenames: Iterable[str], module: str) -> None:
self._symbols: List[ParserSymbol] = []
parser = DecompParser()
for filename in filenames:
parser.reset()
with open(filename, "r", encoding="utf-8") as f:
parser.read_lines(f)
for sym in parser.iter_symbols(module):
sym.filename = filename
self._symbols.append(sym)
def iter_line_functions(self) -> Iterator[ParserFunction]:
"""Return lineref functions separately from nameref. Assuming the PDB matches
the state of the source code, a line reference is a guaranteed match, even if
multiple functions share the same name. (i.e. polymorphism)"""
return filter(
lambda s: isinstance(s, ParserFunction) and not s.is_nameref(),
self._symbols,
)
def iter_name_functions(self) -> Iterator[ParserFunction]:
return filter(
lambda s: isinstance(s, ParserFunction) and s.is_nameref(), self._symbols
)
def iter_vtables(self) -> Iterator[ParserVtable]:
return filter(lambda s: isinstance(s, ParserVtable), self._symbols)
def iter_variables(self) -> Iterator[ParserVariable]:
return filter(lambda s: isinstance(s, ParserVariable), self._symbols)

View file

@ -1,6 +1,6 @@
# C++ file parser
from typing import List, Iterable, Iterator
from typing import List, Iterable, Iterator, Optional
from enum import Enum
from .util import (
is_blank_or_comment,
@ -122,6 +122,11 @@ def vtables(self) -> List[ParserSymbol]:
def variables(self) -> List[ParserSymbol]:
return [s for s in self._symbols if isinstance(s, ParserVariable)]
def iter_symbols(self, module: Optional[str] = None) -> Iterator[ParserSymbol]:
for s in self._symbols:
if module is None or s.module == module:
yield s
def _recover(self):
"""We hit a syntax error and need to reset temp structures"""
self.state = ReaderState.SEARCH

View file

@ -1,107 +0,0 @@
import os
from isledecomp.dir import PathResolver
from isledecomp.cvdump import Cvdump
class RecompiledInfo:
addr = None
size = None
name = None
start = None
# Declare a class that parses the output of cvdump for fast access later
class SymInfo:
funcs = {}
lines = {}
names = {}
def __init__(self, pdb, sym_recompfile, sym_logger, base_dir):
self.logger = sym_logger
path_resolver = PathResolver(base_dir)
self.logger.info("Parsing %s ...", pdb)
self.logger.debug("Parsing output of cvdump.exe ...")
cv = Cvdump(pdb).lines().symbols().publics().section_contributions().run()
self.logger.debug("... Parsing output of cvdump.exe finished")
contrib_dict = {(s.section, s.offset): s.size for s in cv.sizerefs}
for pub in cv.publics:
if pub.type == "S_PUB32" and (pub.section, pub.offset) in contrib_dict:
size = contrib_dict[(pub.section, pub.offset)]
info = RecompiledInfo()
info.addr = sym_recompfile.get_abs_addr(pub.section, pub.offset)
info.start = 0
info.size = size
info.name = pub.name
self.names[pub.name] = info
self.funcs[pub.offset] = info
for proc in cv.symbols:
if proc.type != "S_GPROC32":
continue
info = RecompiledInfo()
info.addr = sym_recompfile.get_abs_addr(proc.section, proc.offset)
info.start = 0
info.size = proc.size
info.name = proc.name
self.names[proc.name] = info
self.funcs[proc.offset] = info
for sourcepath, line_no, offset in cv.lines:
sourcepath = path_resolver.resolve_cvdump(sourcepath)
if sourcepath not in self.lines:
self.lines[sourcepath] = {}
if line_no not in self.lines[sourcepath]:
self.lines[sourcepath][line_no] = offset
def get_recompiled_address(self, filename, line):
recompiled_addr = None
self.logger.debug("Looking for %s:%s", filename, line)
filename_basename = os.path.basename(filename).lower()
for fn in self.lines:
# Sometimes a PDB is compiled with a relative path while we always have
# an absolute path. Therefore we must
try:
if os.path.basename(
fn
).lower() == filename_basename and os.path.samefile(fn, filename):
filename = fn
break
except FileNotFoundError:
continue
if filename in self.lines and line in self.lines[filename]:
recompiled_addr = self.lines[filename][line]
if recompiled_addr in self.funcs:
return self.funcs[recompiled_addr]
self.logger.error(
"Failed to find function symbol with address: %x", recompiled_addr
)
return None
self.logger.error(
"Failed to find function symbol with filename and line: %s:%s",
filename,
line,
)
return None
def get_recompiled_address_from_name(self, name):
self.logger.debug("Looking for %s", name)
if name in self.names:
return self.names[name]
self.logger.error("Failed to find function symbol with name: %s", name)
return None

View file

@ -0,0 +1,12 @@
"""Types shared by other modules"""
from enum import Enum
class SymbolType(Enum):
"""Broadly tells us what kind of comparison is required for this symbol."""
FUNCTION = 1
DATA = 2
POINTER = 3
STRING = 4
VTABLE = 5

View file

@ -0,0 +1,39 @@
import pytest
from isledecomp.cvdump.analysis import data_type_info
# fmt: off
type_check_cases = [
("T_32PINT4", 4, True),
("T_32PLONG", 4, True),
("T_32PRCHAR", 4, True),
("T_32PREAL32", 4, True),
("T_32PUCHAR", 4, True),
("T_32PUINT4", 4, True),
("T_32PULONG", 4, True),
("T_32PUSHORT", 4, True),
("T_32PVOID", 4, True),
("T_CHAR", 1, False),
("T_INT4", 4, False),
("T_LONG", 4, False),
("T_NOTYPE", 0, False), # ?
("T_QUAD", 8, False),
("T_RCHAR", 1, False),
("T_REAL32", 4, False),
("T_REAL64", 8, False),
("T_SHORT", 2, False),
("T_UCHAR", 1, False),
("T_UINT4", 4, False),
("T_ULONG", 4, False),
("T_UQUAD", 8, False),
("T_USHORT", 2, False),
("T_VOID", 0, False), # ?
("T_WCHAR", 2, False),
]
# fmt: on
@pytest.mark.parametrize("type_name, size, is_pointer", type_check_cases)
def test_type_check(type_name: str, size: int, is_pointer: bool):
assert (info := data_type_info(type_name)) is not None
assert info[0] == size
assert info[1] == is_pointer

View file

@ -0,0 +1,55 @@
import pytest
from isledecomp.cvdump.demangler import (
demangle_string_const,
demangle_vtable,
parse_encoded_number,
InvalidEncodedNumberError,
)
string_demangle_cases = [
("??_C@_08LIDF@December?$AA@", 8, False),
("??_C@_0L@EGPP@english?9nz?$AA@", 11, False),
(
"??_C@_1O@POHA@?$AA?$CI?$AAn?$AAu?$AAl?$AAl?$AA?$CJ?$AA?$AA?$AA?$AA?$AA?$AH?$AA?$AA?$AA?$AA?$AA?$AA?$AA?$9A?$AE?$;I@",
14,
True,
),
]
@pytest.mark.parametrize("symbol, strlen, is_utf16", string_demangle_cases)
def test_strings(symbol, is_utf16, strlen):
s = demangle_string_const(symbol)
assert s.len == strlen
assert s.is_utf16 == is_utf16
encoded_numbers = [
("A@", 0),
("AA@", 0), # would never happen?
("P@", 15),
("BA@", 16),
("BCD@", 291),
]
@pytest.mark.parametrize("string, value", encoded_numbers)
def test_encoded_numbers(string, value):
assert parse_encoded_number(string) == value
def test_invalid_encoded_number():
with pytest.raises(InvalidEncodedNumberError):
parse_encoded_number("Hello")
vtable_cases = [
("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter"),
("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection<LegoWorld *>"),
("??_7?$MxPtrList@VLegoPathController@@@@6B@", "MxPtrList<LegoPathController>"),
]
@pytest.mark.parametrize("symbol, class_name", vtable_cases)
def test_vtable(symbol, class_name):
assert demangle_vtable(symbol) == class_name

View file

@ -10,13 +10,11 @@
from isledecomp import (
Bin,
DecompParser,
get_file_in_script_dir,
OffsetPlaceholderGenerator,
print_diff,
SymInfo,
walk_source_dir,
)
from isledecomp.compare import Compare as IsleCompare
from capstone import Cs, CS_ARCH_X86, CS_MODE_32
import colorama
@ -261,7 +259,6 @@ def main():
args = parser.parse_args()
logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)
colorama.init()
@ -294,8 +291,13 @@ def main():
svg = args.svg
with Bin(original, logger) as origfile, Bin(recomp, logger) as recompfile:
syminfo = SymInfo(syms, recompfile, logger, source)
with Bin(original, find_str=True) as origfile, Bin(recomp) as recompfile:
if verbose is not None:
# Mute logger events from compare engine
logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL)
isle_compare = IsleCompare(origfile, recompfile, syms, source)
print()
@ -306,151 +308,120 @@ def main():
total_effective_accuracy = 0
htmlinsert = []
# Generate basename of original file, used in locating OFFSET lines
basename = os.path.basename(os.path.splitext(original)[0])
matches = []
if verbose is not None:
match = isle_compare.get_one_function(verbose)
if match is not None:
found_verbose_target = True
matches = [match]
else:
matches = isle_compare.get_functions()
parser = DecompParser()
for srcfilename in walk_source_dir(source):
parser.reset()
with open(srcfilename, "r", encoding="utf-8") as srcfile:
parser.read_lines(srcfile)
for match in matches:
# The effective_ratio is the ratio when ignoring differing register
# allocation vs the ratio is the true ratio.
ratio = 0.0
effective_ratio = 0.0
if match.size:
origasm = parse_asm(
capstone_disassembler,
origfile,
match.orig_addr,
match.size,
)
recompasm = parse_asm(
capstone_disassembler,
recompfile,
match.recomp_addr,
match.size,
)
for fun in parser.functions:
if fun.should_skip():
continue
diff = difflib.SequenceMatcher(None, origasm, recompasm)
ratio = diff.ratio()
effective_ratio = ratio
if fun.module != basename:
continue
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
if can_resolve_register_differences(origasm, recompasm):
effective_ratio = 1.0
else:
ratio = 0
addr = fun.offset
# Verbose flag handling
percenttext = f"{(effective_ratio * 100):.2f}%"
if not plain:
if effective_ratio == 1.0:
percenttext = (
colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL
)
elif effective_ratio > 0.8:
percenttext = (
colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL
)
else:
percenttext = (
colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL
)
if effective_ratio == 1.0 and ratio != 1.0:
if plain:
percenttext += "*"
else:
percenttext += colorama.Fore.RED + "*" + colorama.Style.RESET_ALL
if args.print_rec_addr:
addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
else:
addrs = hex(match.orig_addr)
if not verbose:
print(
f" {match.name} ({addrs}) is {percenttext} similar to the original"
)
function_count += 1
total_accuracy += ratio
total_effective_accuracy += effective_ratio
if match.size:
udiff = difflib.unified_diff(origasm, recompasm, n=10)
# If verbose, print the diff for that function to the output
if verbose:
if addr == verbose:
found_verbose_target = True
else:
continue
if fun.is_nameref():
recinfo = syminfo.get_recompiled_address_from_name(fun.name)
if not recinfo:
continue
else:
recinfo = syminfo.get_recompiled_address(
srcfilename, fun.line_number
)
if not recinfo:
continue
# The effective_ratio is the ratio when ignoring differing register
# allocation vs the ratio is the true ratio.
ratio = 0.0
effective_ratio = 0.0
if recinfo.size:
origasm = parse_asm(
capstone_disassembler,
origfile,
addr + recinfo.start,
recinfo.size,
)
recompasm = parse_asm(
capstone_disassembler,
recompfile,
recinfo.addr + recinfo.start,
recinfo.size,
)
diff = difflib.SequenceMatcher(None, origasm, recompasm)
ratio = diff.ratio()
effective_ratio = ratio
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
if can_resolve_register_differences(origasm, recompasm):
effective_ratio = 1.0
else:
ratio = 0
percenttext = f"{(effective_ratio * 100):.2f}%"
if not plain:
if effective_ratio == 1.0:
percenttext = (
colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL
)
elif effective_ratio > 0.8:
percenttext = (
colorama.Fore.YELLOW
+ percenttext
+ colorama.Style.RESET_ALL
)
else:
percenttext = (
colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL
)
if effective_ratio == 1.0 and ratio != 1.0:
if plain:
percenttext += "*"
else:
percenttext += (
colorama.Fore.RED + "*" + colorama.Style.RESET_ALL
)
if args.print_rec_addr:
addrs = f"0x{addr:x} / 0x{recinfo.addr:x}"
else:
addrs = hex(addr)
if not verbose:
print(
f" {recinfo.name} ({addrs}) is {percenttext} similar to the original"
)
function_count += 1
total_accuracy += ratio
total_effective_accuracy += effective_ratio
if recinfo.size:
udiff = difflib.unified_diff(origasm, recompasm, n=10)
# If verbose, print the diff for that function to the output
if verbose:
if effective_ratio == 1.0:
ok_text = (
"OK!"
if plain
else (
colorama.Fore.GREEN
+ "✨ OK! ✨"
+ colorama.Style.RESET_ALL
)
ok_text = (
"OK!"
if plain
else (
colorama.Fore.GREEN
+ "✨ OK! ✨"
+ colorama.Style.RESET_ALL
)
if ratio == 1.0:
print(
f"{addrs}: {recinfo.name} 100% match.\n\n{ok_text}\n\n"
)
else:
print(
f"{addrs}: {recinfo.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n"
)
)
if ratio == 1.0:
print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n")
else:
print_diff(udiff, plain)
print(
f"\n{recinfo.name} is only {percenttext} similar to the original, diff above"
f"{addrs}: {match.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n"
)
else:
print_diff(udiff, plain)
# If html, record the diffs to an HTML file
if html_path:
htmlinsert.append(
{
"address": f"0x{addr:x}",
"name": recinfo.name,
"matching": effective_ratio,
"diff": "\n".join(udiff),
}
print(
f"\n{match.name} is only {percenttext} similar to the original, diff above"
)
# If html, record the diffs to an HTML file
if html_path:
htmlinsert.append(
{
"address": f"0x{match.orig_addr:x}",
"name": match.name,
"matching": effective_ratio,
"diff": "\n".join(udiff),
}
)
if html_path:
gen_html(html_path, json.dumps(htmlinsert))