From f26c30974a727e9162a8f3154ce4a8446102484e Mon Sep 17 00:00:00 2001 From: jonschz <17198703+jonschz@users.noreply.github.com> Date: Sun, 9 Jun 2024 14:41:24 +0200 Subject: [PATCH] Add Ghidra function import script (#909) * Add draft for Ghidra function import script * feature: Basic PDB analysis [skip ci] This is a draft with a lot of open questions left. Please do not merge * Refactor: Introduce submodules and reload remedy * refactor types and make them Python 3.9 compatible * run black * WIP: save progress * fix types and small type safety violations * fix another Python 3.9 syntax incompatibility * Implement struct imports [skip ci] - This code is still in dire need of refactoring and tests - There are only single-digit issues left, and 2600 functions can be imported - The biggest remaining error is mismatched stacks * Refactor, implement enums, fix lots of bugs * fix Python 3.9 issue * refactor: address review comments Not sure why VS Code suddenly decides to remove some empty spaces, but they don't make sense anyway * add unit tests for new type parsers, fix linter issue * refactor: db access from pdb_extraction.py * Fix stack layout offset error * fix: Undo incorrect reference change * Fix CI issue * Improve READMEs (fix typos, add information) --------- Co-authored-by: jonschz --- .gitignore | 1 + .pylintrc | 4 +- tools/README.md | 4 +- tools/ghidra_scripts/README.md | 25 ++ .../import_functions_and_types_from_pdb.py | 283 ++++++++++++++++ tools/ghidra_scripts/lego_util/__init__.py | 0 tools/ghidra_scripts/lego_util/exceptions.py | 47 +++ .../lego_util/function_importer.py | 241 ++++++++++++++ .../ghidra_scripts/lego_util/ghidra_helper.py | 100 ++++++ tools/ghidra_scripts/lego_util/headers.pyi | 19 ++ .../lego_util/pdb_extraction.py | 166 ++++++++++ tools/ghidra_scripts/lego_util/statistics.py | 68 ++++ .../ghidra_scripts/lego_util/type_importer.py | 313 ++++++++++++++++++ tools/isledecomp/isledecomp/compare/core.py | 17 +- 
tools/isledecomp/isledecomp/compare/db.py | 4 +- .../isledecomp/isledecomp/cvdump/__init__.py | 1 + .../isledecomp/isledecomp/cvdump/analysis.py | 14 +- tools/isledecomp/isledecomp/cvdump/parser.py | 31 +- tools/isledecomp/isledecomp/cvdump/symbols.py | 153 +++++++++ tools/isledecomp/isledecomp/cvdump/types.py | 283 +++++++++++++--- tools/isledecomp/tests/test_cvdump_types.py | 164 ++++++--- 21 files changed, 1824 insertions(+), 114 deletions(-) create mode 100644 tools/ghidra_scripts/README.md create mode 100644 tools/ghidra_scripts/import_functions_and_types_from_pdb.py create mode 100644 tools/ghidra_scripts/lego_util/__init__.py create mode 100644 tools/ghidra_scripts/lego_util/exceptions.py create mode 100644 tools/ghidra_scripts/lego_util/function_importer.py create mode 100644 tools/ghidra_scripts/lego_util/ghidra_helper.py create mode 100644 tools/ghidra_scripts/lego_util/headers.pyi create mode 100644 tools/ghidra_scripts/lego_util/pdb_extraction.py create mode 100644 tools/ghidra_scripts/lego_util/statistics.py create mode 100644 tools/ghidra_scripts/lego_util/type_importer.py create mode 100644 tools/isledecomp/isledecomp/cvdump/symbols.py diff --git a/.gitignore b/.gitignore index 7e16a6ce..d335e177 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ LEGO1.DLL LEGO1PROGRESS.* ISLEPROGRESS.* *.pyc +tools/ghidra_scripts/import.log diff --git a/.pylintrc b/.pylintrc index ab83fceb..976b3764 100644 --- a/.pylintrc +++ b/.pylintrc @@ -63,11 +63,11 @@ ignore-patterns=^\.# # (useful for modules/projects where namespaces are manipulated during runtime # and thus existing member attributes cannot be deduced by static analysis). It # supports qualified module names, as well as Unix pattern matching. -ignored-modules= +ignored-modules=ghidra # Python code to execute, usually for sys.path manipulation such as # pygtk.require(). 
-#init-hook= +init-hook='import sys; sys.path.append("tools/isledecomp"); sys.path.append("tools/ghidra_scripts")' # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the # number of processors available to use, and will cap the count on Windows to diff --git a/tools/README.md b/tools/README.md index 0c6b4112..0a998f2b 100644 --- a/tools/README.md +++ b/tools/README.md @@ -174,7 +174,7 @@ pip install -r tools/requirements.txt ## Testing -`isledecomp` comes with a suite of tests. Install `pylint` and run it, passing in the directory: +`isledecomp` comes with a suite of tests. Install `pytest` and run it, passing in the directory: ``` pip install pytest @@ -189,7 +189,7 @@ In order to keep the code clean and consistent, we use `pylint` and `black`: ### Run pylint (ignores build and virtualenv) -`pylint tools/ --ignore=build,bin,lib` +`pylint tools/ --ignore=build,ncc` ### Check code formatting without rewriting files diff --git a/tools/ghidra_scripts/README.md b/tools/ghidra_scripts/README.md new file mode 100644 index 00000000..1e5082d7 --- /dev/null +++ b/tools/ghidra_scripts/README.md @@ -0,0 +1,25 @@ +# Ghidra Scripts + +The scripts in this directory provide additional functionality in Ghidra, e.g. imports of symbols and types from the PDB debug symbol file. + +## Setup + +### Ghidrathon +Since these scripts and its dependencies are written in Python 3, [Ghidrathon](https://github.com/mandiant/Ghidrathon) must be installed first. Follow the instructions and install a recent build (these scripts were tested with Python 3.12 and Ghidrathon v4.0.0). + +### Script Directory +- In Ghidra, _Open Window -> Script Manager_. +- Click the _Manage Script Directories_ button on the top right. +- Click the _Add_ (Plus icon) button and select this file's parent directory. +- Close the window and click the _Refresh_ button. +- This script should now be available under the folder _LEGO1_. 
+ +### Virtual environment +As of now, there must be a Python virtual environment set up under `$REPOSITORY_ROOT/.venv`, and the dependencies of `isledecomp` must be installed there, see [here](../README.md#tooling). + +## Development +- Type hints for Ghidra (optional): Download a recent release from https://github.com/VDOO-Connected-Trust/ghidra-pyi-generator, + unpack it somewhere, and `pip install` that directory in this virtual environment. This provides types and headers for Python. + Be aware that some of these files contain errors - in particular, `from typing import overload` seems to be missing everywhere, leading to spurious type errors. +- Note that the imported modules persist across multiple runs of the script (see [here](https://github.com/mandiant/Ghidrathon/issues/103)). + If you intend to modify an imported library, you have to use `import importlib; importlib.reload(${library})` or restart Ghidra for your changes to have any effect. Unfortunately, even that is not perfectly reliable, so you may still have to restart Ghidra for some changes in `isledecomp` to be applied. diff --git a/tools/ghidra_scripts/import_functions_and_types_from_pdb.py b/tools/ghidra_scripts/import_functions_and_types_from_pdb.py new file mode 100644 index 00000000..fcf5a7d3 --- /dev/null +++ b/tools/ghidra_scripts/import_functions_and_types_from_pdb.py @@ -0,0 +1,283 @@ +# Imports types and function signatures from debug symbols (PDB file) of the recompilation. +# +# This script uses Python 3 and therefore requires Ghidrathon to be installed in Ghidra (see https://github.com/mandiant/Ghidrathon). +# Furthermore, the virtual environment must be set up beforehand under $REPOSITORY_ROOT/.venv, and all required packages must be installed +# (see $REPOSITORY_ROOT/tools/README.md). +# Also, the Python version of the virtual environment must probably match the Python version used for Ghidrathon. + +# @author J.
def reload_module(module: str):
    """
    Import the module with the dotted name `module` and force a reload of it.

    Due to a quirk in Jep (used by Ghidrathon), imported modules persist for the lifetime
    of the Ghidra process and are not reloaded when the script is launched again.
    To make development easier, all of our own modules are therefore reloaded at startup.
    See also https://github.com/mandiant/Ghidrathon/issues/103.

    Note that as of 2024-05-30, this remedy does not work perfectly (yet): Some changes in
    isledecomp are still not detected correctly and require a Ghidra restart to be applied.
    """
    # import_module() returns the cached module if it was imported before
    imported_module = importlib.import_module(module)
    importlib.reload(imported_module)
# We need to quote the types here because they might not exist when running without Ghidra
def import_function_into_ghidra(
    api: "FlatProgramAPI",
    match_info: "MatchInfo",
    signature: "FunctionSignature",
    type_importer: "PdbTypeImporter",
):
    """
    Apply one function signature from the PDB to the Ghidra function at the original address.

    If Ghidra has no function at `match_info.orig_addr`, one is created first.
    If the existing declaration already matches, nothing is modified; otherwise
    the declaration is overwritten and `GLOBALS.statistics.functions_changed` is incremented.
    """
    hex_original_address = f"{match_info.orig_addr:x}"

    # Find the Ghidra function at that address
    ghidra_address = getAddressFactory().getAddress(hex_original_address)
    # pylint: disable=possibly-used-before-assignment
    function_importer = PdbFunctionImporter(api, match_info, signature, type_importer)

    ghidra_function = getFunctionAt(ghidra_address)
    if ghidra_function is None:
        # "temp" is only a placeholder name, it is replaced when the function is overwritten below
        ghidra_function = createFunction(ghidra_address, "temp")
        assert (
            ghidra_function is not None
        ), f"Failed to create function at {ghidra_address}"
        logger.info("Created new function at %s", ghidra_address)

    logger.debug("Start handling function '%s'", function_importer.get_full_name())

    if function_importer.matches_ghidra_function(ghidra_function):
        logger.info(
            "Skipping function '%s', matches already",
            function_importer.get_full_name(),
        )
        return

    logger.debug(
        "Modifying function %s at 0x%s",
        function_importer.get_full_name(),
        hex_original_address,
    )

    function_importer.overwrite_ghidra_function(ghidra_function)

    GLOBALS.statistics.functions_changed += 1


def process_functions(extraction: "PdbFunctionExtractor"):
    """
    Extract all function signatures and, when running inside Ghidra, import each of them.

    Outside of Ghidra only the extraction is executed (dry run). Inside Ghidra, a failure
    in one function is logged and tracked in the statistics, and the loop continues with
    the next function. The loop stops early if the user cancels the import.
    """
    func_signatures = extraction.get_function_list()

    if not GLOBALS.running_from_ghidra:
        logger.info("Completed the dry run outside Ghidra.")
        return

    api = FlatProgramAPI(currentProgram())
    # pylint: disable=possibly-used-before-assignment
    type_importer = PdbTypeImporter(api, extraction)

    for match_info, signature in func_signatures:
        try:
            import_function_into_ghidra(api, match_info, signature, type_importer)
            GLOBALS.statistics.successes += 1
        except Lego1Exception as e:
            log_and_track_failure(match_info.name, e)
        except RuntimeError as e:
            # The first argument is checked for Ghidra's CancelledException
            # (presumably this is how Ghidrathon surfaces Java exceptions - TODO confirm).
            # A RuntimeError without arguments must not crash the error handling itself.
            cause = e.args[0] if e.args else e
            if CancelledException is not None and isinstance(cause, CancelledException):
                # let Ghidra's CancelledException pass through
                logger.critical("Import aborted by the user.")
                return

            log_and_track_failure(match_info.name, cause, unexpected=True)
            logger.error(traceback.format_exc())
        except Exception as e:  # pylint: disable=broad-exception-caught
            log_and_track_failure(match_info.name, e, unexpected=True)
            logger.error(traceback.format_exc())


def log_and_track_failure(
    function_name: Optional[str], error: Exception, unexpected: bool = False
):
    """
    Record `error` in the statistics and log it if the statistics report it as new.

    `unexpected` only adds the prefix "Unexpected error: " to the log message.
    """
    if GLOBALS.statistics.track_failure_and_tell_if_new(error):
        logger.error(
            "%s(): %s%s",
            function_name,
            "Unexpected error: " if unexpected else "",
            error,
        )
class Lego1Exception(Exception):
    """
    Common base class of all errors raised deliberately by these scripts.
    Catching it separates expected failures from unexpected ones.
    """


class TypeNotFoundError(Lego1Exception):
    """A type is missing in the PDB. The first argument is shown as the type."""

    def __str__(self):
        missing_type = self.args[0]
        return f"Type not found in PDB: {missing_type}"


class TypeNotFoundInGhidraError(Lego1Exception):
    """A type is missing in Ghidra. The first argument is shown as the type."""

    def __str__(self):
        missing_type = self.args[0]
        return f"Type not found in Ghidra: {missing_type}"


class TypeNotImplementedError(Lego1Exception):
    """The import is not implemented for a type. The first argument is shown as the type."""

    def __str__(self):
        unsupported_type = self.args[0]
        return f"Import not implemented for type: {unsupported_type}"


class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
    """A class or namespace is missing in Ghidra. Carries the namespace hierarchy as a list."""

    def __init__(self, namespaceHierachy: list[str]):
        super().__init__(namespaceHierachy)

    def get_namespace_str(self) -> str:
        """Returns the hierarchy joined by `::`."""
        hierarchy = self.args[0]
        return "::".join(hierarchy)

    def __str__(self):
        qualified_name = self.get_namespace_str()
        return f"Class or namespace not found in Ghidra: {qualified_name}"


class MultipleTypesFoundInGhidraError(Lego1Exception):
    """More than one type matched in Ghidra. Arguments: the searched name and the matches."""

    def __str__(self):
        searched_name = self.args[0]
        matches = self.args[1]
        return f"Found multiple types matching '{searched_name}' in Ghidra: {matches}"


class StackOffsetMismatchError(Lego1Exception):
    """Uses the default message of `Exception`."""


class StructModificationError(Lego1Exception):
    """A struct could not be modified in Ghidra. The message includes `__cause__`."""

    def __str__(self):
        struct_name = self.args[0]
        underlying_error = self.__cause__
        return f"Failed to modify struct in Ghidra: '{struct_name}'\nDetailed error: {underlying_error}"
# pylint: disable=too-many-instance-attributes
class PdbFunctionImporter:
    """
    A representation of a function from the PDB with each type replaced by a Ghidra type instance.

    On construction, the class type (if any), the return type and all argument types are
    imported into Ghidra through `type_importer`, and the namespace of the function is
    looked up in Ghidra.
    """

    def __init__(
        self,
        api: FlatProgramAPI,
        match_info: MatchInfo,
        signature: FunctionSignature,
        type_importer: "PdbTypeImporter",
    ):
        self.api = api
        self.match_info = match_info
        self.signature = signature
        self.type_importer = type_importer

        if signature.class_type is not None:
            # Import the base class so the namespace exists
            self.type_importer.import_pdb_type_into_ghidra(signature.class_type)

        assert match_info.name is not None

        # The last `::` separated part is the function name, everything before is the namespace
        colon_split = sanitize_name(match_info.name).split("::")
        self.name = colon_split.pop()
        namespace_hierachy = colon_split
        self.namespace = get_ghidra_namespace(api, namespace_hierachy)

        self.return_type = type_importer.import_pdb_type_into_ghidra(
            signature.return_type
        )
        # The names are placeholders; the real names are set from the stack symbols
        # in `overwrite_ghidra_function()`
        self.arguments = [
            ParameterImpl(
                f"param{index}",
                type_importer.import_pdb_type_into_ghidra(type_name),
                api.getCurrentProgram(),
            )
            for (index, type_name) in enumerate(signature.arglist)
        ]

    @property
    def call_type(self):
        """The calling convention name taken from the signature."""
        return self.signature.call_type

    @property
    def stack_symbols(self):
        """The stack and register symbols taken from the signature."""
        return self.signature.stack_symbols

    def get_full_name(self) -> str:
        """Returns the name of the direct parent namespace and the function name, joined by `::`."""
        return f"{self.namespace.getName()}::{self.name}"

    def matches_ghidra_function(self, ghidra_function: Function) -> bool:
        """Checks whether this function declaration already matches the description in Ghidra"""
        name_match = self.name == ghidra_function.getName(False)
        namespace_match = self.namespace == ghidra_function.getParentNamespace()
        return_type_match = self.return_type == ghidra_function.getReturnType()
        # match arguments: decide if thiscall or not
        thiscall_matches = (
            self.signature.call_type == ghidra_function.getCallingConventionName()
        )

        if thiscall_matches:
            if self.signature.call_type == "__thiscall":
                args_match = self._matches_thiscall_parameters(ghidra_function)
            else:
                args_match = self._matches_non_thiscall_parameters(ghidra_function)
        else:
            # The parameters are not compared if the calling convention differs
            args_match = False

        logger.debug(
            "Matches: namespace=%s name=%s return_type=%s thiscall=%s args=%s",
            namespace_match,
            name_match,
            return_type_match,
            thiscall_matches,
            args_match,
        )

        return (
            name_match
            and namespace_match
            and return_type_match
            and thiscall_matches
            and args_match
        )

    def _matches_non_thiscall_parameters(self, ghidra_function: Function) -> bool:
        return self._parameter_lists_match(ghidra_function.getParameters())

    def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool:
        # Skip the `this` argument which we don't generate ourselves.
        # Slicing (instead of `pop(0)`) does not raise if Ghidra reports no parameters at all.
        ghidra_params = list(ghidra_function.getParameters())[1:]

        return self._parameter_lists_match(ghidra_params)

    def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
        """Compares count, types and names of our arguments against `ghidra_params`."""
        if len(self.arguments) != len(ghidra_params):
            logger.info("Mismatching argument count")
            return False

        for this_arg, ghidra_arg in zip(self.arguments, ghidra_params):
            # compare argument types
            if this_arg.getDataType() != ghidra_arg.getDataType():
                logger.debug(
                    "Mismatching arg type: expected %s, found %s",
                    this_arg.getDataType(),
                    ghidra_arg.getDataType(),
                )
                return False
            # compare argument names
            stack_match = self.get_matching_stack_symbol(ghidra_arg.getStackOffset())
            if stack_match is None:
                logger.debug("Not found on stack: %s", ghidra_arg)
                return False
            # "__formal" is the placeholder for arguments without a name
            if (
                stack_match.name != ghidra_arg.getName()
                and not stack_match.name.startswith("__formal")
            ):
                logger.debug(
                    "Argument name mismatch: expected %s, found %s",
                    stack_match.name,
                    ghidra_arg.getName(),
                )
                return False
        return True

    def overwrite_ghidra_function(self, ghidra_function: Function):
        """Replace the function declaration in Ghidra by the one derived from C++."""
        ghidra_function.setName(self.name, SourceType.USER_DEFINED)
        ghidra_function.setParentNamespace(self.namespace)
        ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED)
        ghidra_function.setCallingConvention(self.call_type)

        ghidra_function.replaceParameters(
            Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS,
            True,
            SourceType.USER_DEFINED,
            self.arguments,
        )

        # When we set the parameters, Ghidra will generate the layout.
        # Now we read them again and match them against the stack layout in the PDB,
        # both to verify and to set the parameter names.
        ghidra_parameters: list[Parameter] = ghidra_function.getParameters()

        # Try to add Ghidra function names
        for index, param in enumerate(ghidra_parameters):
            if param.isStackVariable():
                self._rename_stack_parameter(index, param)
            elif param.getName() != "this":
                # 'this' parameters are auto-generated and cannot be changed, so they are skipped.
                # Anything else appears to never happen - could in theory be relevant to
                # __fastcall__ functions, which we haven't seen yet
                logger.warning(
                    "Unhandled register variable in %s", self.get_full_name()
                )

    def _rename_stack_parameter(self, index: int, param: Parameter):
        """
        Sets the name of `param` to the name of the PDB stack symbol at the same stack offset.

        Raises:
        - StackOffsetMismatchError if the PDB has no stack symbol at that offset
        """
        match = self.get_matching_stack_symbol(param.getStackOffset())
        if match is None:
            raise StackOffsetMismatchError(
                f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}"
            )

        if match.data_type == "T_NOTYPE(0000)":
            logger.warning("Skipping stack parameter of type NOTYPE")
            return

        if param.getDataType() != self.type_importer.import_pdb_type_into_ghidra(
            match.data_type
        ):
            logger.error(
                "Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match
            )
            return

        name = match.name
        if name == "__formal":
            # these can cause name collisions if multiple ones are present
            name = f"__formal_{index}"

        param.setName(name, SourceType.USER_DEFINED)

    def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
        """Returns the first stack symbol with the given `stack_offset`, or `None`."""
        return next(
            (
                symbol
                for symbol in self.stack_symbols
                if isinstance(symbol, CppStackSymbol)
                and symbol.stack_offset == stack_offset
            ),
            None,
        )

    def get_matching_register_symbol(
        self, register: str
    ) -> Optional[CppRegisterSymbol]:
        """Returns the first register symbol stored in `register`, or `None`."""
        return next(
            (
                symbol
                for symbol in self.stack_symbols
                if isinstance(symbol, CppRegisterSymbol) and symbol.register == register
            ),
            None,
        )
logger = logging.getLogger(__name__)

# Characters that `sanitize_name()` substitutes, each by exactly one other character
_GHIDRA_NAME_TRANSLATION = str.maketrans(
    {
        "<": "[",
        ">": "]",
        "*": "#",
        " ": "_",
        "`": "'",
    }
)


def get_ghidra_type(api: "FlatProgramAPI", type_name: str):
    """
    Searches for the type named `type_name` in Ghidra.

    Returns the type if exactly one type of that name exists.

    Raises:
    - TypeNotFoundInGhidraError
    - MultipleTypesFoundInGhidraError
    """
    result = api.getDataTypes(type_name)
    if len(result) == 0:
        raise TypeNotFoundInGhidraError(type_name)
    if len(result) == 1:
        return result[0]

    raise MultipleTypesFoundInGhidraError(type_name, result)


def add_pointer_type(api: "FlatProgramAPI", pointee: "DataType") -> "DataType":
    """
    Adds a pointer to `pointee` to the data type manager of the current program.

    The pointer gets the category path of `pointee`. The returned data type is the one
    handed back by the data type manager, which may differ from the freshly created one.
    """
    new_data_type = PointerDataType(pointee)
    new_data_type.setCategoryPath(pointee.getCategoryPath())
    result_data_type = (
        api.getCurrentProgram()
        .getDataTypeManager()
        .addDataType(new_data_type, DataTypeConflictHandler.KEEP_HANDLER)
    )
    if result_data_type is not new_data_type:
        logger.debug(
            "New pointer replaced by existing one. Fresh pointer: %s (class: %s)",
            result_data_type,
            result_data_type.__class__,
        )
    return result_data_type


def get_ghidra_namespace(
    api: "FlatProgramAPI", namespace_hierachy: list[str]
) -> "Namespace":
    """
    Walks down `namespace_hierachy` starting at the global namespace.

    Raises:
    - ClassOrNamespaceNotFoundInGhidraError if any part does not exist
    """
    namespace = api.getCurrentProgram().getGlobalNamespace()
    for part in namespace_hierachy:
        namespace = api.getNamespace(namespace, part)
        if namespace is None:
            raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy)
    return namespace


def create_ghidra_namespace(
    api: "FlatProgramAPI", namespace_hierachy: list[str]
) -> "Namespace":
    """
    Walks down `namespace_hierachy` starting at the global namespace
    and creates every part that does not exist yet.

    Returns the innermost namespace (the global namespace for an empty hierarchy).
    """
    namespace = api.getCurrentProgram().getGlobalNamespace()
    for part in namespace_hierachy:
        # Keep the parent in its own variable: it is still needed
        # to create the child if the lookup returns `None`.
        child = api.getNamespace(namespace, part)
        if child is None:
            child = api.createNamespace(namespace, part)
        namespace = child
    return namespace


def sanitize_name(name: str) -> str:
    """
    Takes a full class or function name and replaces characters not accepted by Ghidra.
    Applies mostly to templates and names like `vbase destructor`.

    Names containing `<` are additionally prefixed with `_template_`.
    A warning is logged whenever the name had to be changed.
    """
    new_class_name = name.translate(_GHIDRA_NAME_TRANSLATION)
    if "<" in name:
        new_class_name = "_template_" + new_class_name

    if new_class_name != name:
        logger.warning(
            "Class or function name contains characters forbidden by Ghidra, changing from '%s' to '%s'",
            name,
            new_class_name,
        )
    return new_class_name
logger = logging.getLogger(__name__)


@dataclass
class CppStackOrRegisterSymbol:
    """Name and PDB data type shared by stack and register symbols."""

    name: str
    data_type: str


@dataclass
class CppStackSymbol(CppStackOrRegisterSymbol):
    """A symbol created from an `S_BPREL32` entry."""

    # Should have a value iff `symbol_type=='S_BPREL32'`.
    stack_offset: int


@dataclass
class CppRegisterSymbol(CppStackOrRegisterSymbol):
    """A symbol created from an `S_REGISTER` entry."""

    # Should have a value iff `symbol_type=='S_REGISTER'`. Should always be set/converted to lowercase.
    register: str


@dataclass
class FunctionSignature:
    """The data of one function as extracted from the PDB; all types are PDB type names."""

    original_function_symbol: "SymbolsEntry"
    call_type: str
    arglist: list[str]
    return_type: str
    class_type: Optional[str]
    stack_symbols: list[CppStackOrRegisterSymbol]


class PdbFunctionExtractor:
    """
    Extracts all information on a given function from the parsed PDB
    and prepares the data for the import in Ghidra.
    """

    def __init__(self, compare: "IsleCompare"):
        self.compare = compare

    # Matches a lowercase scalar type name with an optional type id in parentheses, e.g. `t_int4(0074)`
    scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?")

    # Maps the `call_type` of a PDB function type to the calling convention name set in Ghidra.
    # NOTE(review): "C Near" is mapped to "__thiscall" as well - confirm that this is intended.
    _call_type_map = {
        "ThisCall": "__thiscall",
        "C Near": "__thiscall",
        "STD Near": "__stdcall",
    }

    def _get_cvdump_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
        """Looks up `type_name` (case-insensitive) in the parsed types. `None` yields `None`."""
        return (
            None
            if type_name is None
            else self.compare.cv.types.keys.get(type_name.lower())
        )

    def get_func_signature(self, fn: "SymbolsEntry") -> Optional[FunctionSignature]:
        """
        Builds the `FunctionSignature` for the function symbol `fn`.

        Returns `None` (after logging the reason) if the function has no type,
        if its type is unknown, or if its calling convention is not supported.
        """
        function_type_str = fn.func_type
        if function_type_str == "T_NOTYPE(0000)":
            logger.debug(
                "Skipping a NOTYPE (synthetic or template + synthetic): %s", fn.name
            )
            return None

        # get corresponding function type

        function_type = self.compare.cv.types.keys.get(function_type_str.lower())
        if function_type is None:
            logger.error(
                "Could not find function type %s for function %s", fn.func_type, fn.name
            )
            return None

        class_type = function_type.get("class_type")

        arg_list_type = self._get_cvdump_type(function_type.get("arg_list_type"))
        assert arg_list_type is not None
        arg_list_pdb_types = arg_list_type.get("args", [])
        assert arg_list_type["argcount"] == len(arg_list_pdb_types)

        stack_symbols: list[CppStackOrRegisterSymbol] = []

        # for some unexplained reason, the reported stack is offset by 4 when this flag is set
        stack_offset_delta = -4 if fn.frame_pointer_present else 0

        for symbol in fn.stack_symbols:
            if symbol.symbol_type == "S_REGISTER":
                stack_symbols.append(
                    CppRegisterSymbol(
                        symbol.name,
                        symbol.data_type,
                        symbol.location,
                    )
                )
            elif symbol.symbol_type == "S_BPREL32":
                # the location is a hexadecimal offset enclosed by one character on each side
                stack_offset = int(symbol.location[1:-1], 16)
                stack_symbols.append(
                    CppStackSymbol(
                        symbol.name,
                        symbol.data_type,
                        stack_offset + stack_offset_delta,
                    )
                )

        # An unknown calling convention only skips this function
        # instead of aborting the extraction of all functions
        call_type = self._call_type_map.get(function_type["call_type"])
        if call_type is None:
            logger.error(
                "Unsupported call type %s for function %s",
                function_type["call_type"],
                fn.name,
            )
            return None

        return FunctionSignature(
            original_function_symbol=fn,
            call_type=call_type,
            arglist=arg_list_pdb_types,
            return_type=function_type["return_type"],
            class_type=class_type,
            stack_symbols=stack_symbols,
        )

    def get_function_list(self) -> list[tuple["MatchInfo", FunctionSignature]]:
        """Returns match info and signature of every function that could be handled."""
        handled = (
            self.handle_matched_function(match)
            for match in self.compare.get_functions()
        )
        return [signature for signature in handled if signature is not None]

    def handle_matched_function(
        self, match_info: "MatchInfo"
    ) -> Optional[tuple["MatchInfo", FunctionSignature]]:
        """
        Returns `match_info` together with its signature.

        Returns `None` if the function is marked as `skip` or `stub`, if it is not found
        in the cvdump nodes, if it has no symbol entry, or if no signature could be built.
        """
        assert match_info.orig_addr is not None
        match_options = self.compare.get_match_options(match_info.orig_addr)
        assert match_options is not None
        if match_options.get("skip", False) or match_options.get("stub", False):
            return None

        # the node is identified by the address in the recompiled binary
        function_data = next(
            (
                y
                for y in self.compare.cvdump_analysis.nodes
                if y.addr == match_info.recomp_addr
            ),
            None,
        )
        if not function_data:
            logger.error(
                "Did not find function in nodes, skipping: %s", match_info.name
            )
            return None

        function_symbol = function_data.symbol_entry
        if function_symbol is None:
            logger.debug(
                "Could not find function symbol (likely a PUBLICS entry): %s",
                match_info.name,
            )
            return None

        function_signature = self.get_func_signature(function_symbol)
        if function_signature is None:
            return None

        return match_info, function_signature
known_missing_namespaces: dict[str, int] = field(default_factory=dict) + + def track_failure_and_tell_if_new(self, error: Exception) -> bool: + """ + Adds the error to the statistics. Returns `False` if logging the error would be redundant + (e.g. because it is a `TypeNotFoundInGhidraError` with a type that has been logged before). + """ + error_type_name = error.__class__.__name__ + self.failures[error_type_name] = ( + self.failures.setdefault(error_type_name, 0) + 1 + ) + + if isinstance(error, TypeNotFoundInGhidraError): + return self._add_occurence_and_check_if_new( + self.known_missing_types, error.args[0] + ) + + if isinstance(error, ClassOrNamespaceNotFoundInGhidraError): + return self._add_occurence_and_check_if_new( + self.known_missing_namespaces, error.get_namespace_str() + ) + + # We do not have detailed tracking for other errors, so we want to log them every time + return True + + def _add_occurence_and_check_if_new(self, target: dict[str, int], key: str) -> bool: + old_count = target.setdefault(key, 0) + target[key] = old_count + 1 + return old_count == 0 + + def log(self): + logger.info("Statistics:\n~~~~~") + logger.info( + "Missing types (with number of occurences): %s\n~~~~~", + self.format_statistics(self.known_missing_types), + ) + logger.info( + "Missing classes/namespaces (with number of occurences): %s\n~~~~~", + self.format_statistics(self.known_missing_namespaces), + ) + logger.info("Successes: %d", self.successes) + logger.info("Failures: %s", self.failures) + logger.info("Functions changed: %d", self.functions_changed) + + def format_statistics(self, stats: dict[str, int]) -> str: + if len(stats) == 0: + return "" + return ", ".join( + f"{entry[0]} ({entry[1]})" + for entry in sorted(stats.items(), key=lambda x: x[1], reverse=True) + ) diff --git a/tools/ghidra_scripts/lego_util/type_importer.py b/tools/ghidra_scripts/lego_util/type_importer.py new file mode 100644 index 00000000..0d3ee5df --- /dev/null +++ 
b/tools/ghidra_scripts/lego_util/type_importer.py @@ -0,0 +1,313 @@ +import logging +from typing import Any + +# Disable spurious warnings in vscode / pylance +# pyright: reportMissingModuleSource=false + +# pylint: disable=too-many-return-statements # a `match` would be better, but for now we are stuck with Python 3.9 +# pylint: disable=no-else-return # Not sure why this rule even is a thing, this is great for checking exhaustiveness + +from lego_util.exceptions import ( + ClassOrNamespaceNotFoundInGhidraError, + TypeNotFoundError, + TypeNotFoundInGhidraError, + TypeNotImplementedError, + StructModificationError, +) +from lego_util.ghidra_helper import ( + add_pointer_type, + create_ghidra_namespace, + get_ghidra_namespace, + get_ghidra_type, + sanitize_name, +) +from lego_util.pdb_extraction import PdbFunctionExtractor + +from ghidra.program.flatapi import FlatProgramAPI +from ghidra.program.model.data import ( + ArrayDataType, + CategoryPath, + DataType, + DataTypeConflictHandler, + EnumDataType, + StructureDataType, + StructureInternal, +) +from ghidra.util.task import ConsoleTaskMonitor + + +logger = logging.getLogger(__name__) + + +class PdbTypeImporter: + """Allows PDB types to be imported into Ghidra.""" + + def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor): + self.api = api + self.extraction = extraction + # tracks the structs/classes we have already started to import, otherwise we run into infinite recursion + self.handled_structs: set[str] = set() + self.struct_call_stack: list[str] = [] + + @property + def types(self): + return self.extraction.compare.cv.types + + def import_pdb_type_into_ghidra(self, type_index: str) -> DataType: + """ + Recursively imports a type from the PDB into Ghidra. 
+ @param type_index Either a scalar type like `T_INT4(...)` or a PDB reference like `0x10ba` + """ + type_index_lower = type_index.lower() + if type_index_lower.startswith("t_"): + return self._import_scalar_type(type_index_lower) + + try: + type_pdb = self.extraction.compare.cv.types.keys[type_index_lower] + except KeyError as e: + raise TypeNotFoundError( + f"Failed to find referenced type '{type_index_lower}'" + ) from e + + type_category = type_pdb["type"] + + # follow forward reference (class, struct, union) + if type_pdb.get("is_forward_ref", False): + return self._import_forward_ref_type(type_index_lower, type_pdb) + + if type_category == "LF_POINTER": + return add_pointer_type( + self.api, self.import_pdb_type_into_ghidra(type_pdb["element_type"]) + ) + elif type_category in ["LF_CLASS", "LF_STRUCTURE"]: + return self._import_class_or_struct(type_pdb) + elif type_category == "LF_ARRAY": + return self._import_array(type_pdb) + elif type_category == "LF_ENUM": + return self._import_enum(type_pdb) + elif type_category == "LF_PROCEDURE": + logger.warning( + "Not implemented: Function-valued argument or return type will be replaced by void pointer: %s", + type_pdb, + ) + return get_ghidra_type(self.api, "void") + elif type_category == "LF_UNION": + return self._import_union(type_pdb) + else: + raise TypeNotImplementedError(type_pdb) + + _scalar_type_map = { + "rchar": "char", + "int4": "int", + "uint4": "uint", + "real32": "float", + "real64": "double", + } + + def _scalar_type_to_cpp(self, scalar_type: str) -> str: + if scalar_type.startswith("32p"): + return f"{self._scalar_type_to_cpp(scalar_type[3:])} *" + return self._scalar_type_map.get(scalar_type, scalar_type) + + def _import_scalar_type(self, type_index_lower: str) -> DataType: + if (match := self.extraction.scalar_type_regex.match(type_index_lower)) is None: + raise TypeNotFoundError(f"Type has unexpected format: {type_index_lower}") + + scalar_cpp_type = 
self._scalar_type_to_cpp(match.group("typename")) + return get_ghidra_type(self.api, scalar_cpp_type) + + def _import_forward_ref_type( + self, type_index, type_pdb: dict[str, Any] + ) -> DataType: + referenced_type = type_pdb.get("udt") or type_pdb.get("modifies") + if referenced_type is None: + try: + # Example: HWND__, needs to be created manually + return get_ghidra_type(self.api, type_pdb["name"]) + except TypeNotFoundInGhidraError as e: + raise TypeNotImplementedError( + f"{type_index}: forward ref without target, needs to be created manually: {type_pdb}" + ) from e + logger.debug( + "Following forward reference from %s to %s", + type_index, + referenced_type, + ) + return self.import_pdb_type_into_ghidra(referenced_type) + + def _import_array(self, type_pdb: dict[str, Any]) -> DataType: + inner_type = self.import_pdb_type_into_ghidra(type_pdb["array_type"]) + + array_total_bytes: int = type_pdb["size"] + data_type_size = inner_type.getLength() + array_length, modulus = divmod(array_total_bytes, data_type_size) + assert ( + modulus == 0 + ), f"Data type size {data_type_size} does not divide array size {array_total_bytes}" + + return ArrayDataType(inner_type, array_length, 0) + + def _import_union(self, type_pdb: dict[str, Any]) -> DataType: + try: + logger.debug("Dereferencing union %s", type_pdb) + union_type = get_ghidra_type(self.api, type_pdb["name"]) + assert ( + union_type.getLength() == type_pdb["size"] + ), f"Wrong size of existing union type '{type_pdb['name']}': expected {type_pdb['size']}, got {union_type.getLength()}" + return union_type + except TypeNotFoundInGhidraError as e: + # We have so few instances, it is not worth implementing this + raise TypeNotImplementedError( + f"Writing union types is not supported. 
Please add by hand: {type_pdb}" + ) from e + + def _import_enum(self, type_pdb: dict[str, Any]) -> DataType: + underlying_type = self.import_pdb_type_into_ghidra(type_pdb["underlying_type"]) + field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"]) + assert field_list is not None, f"Failed to find field list for enum {type_pdb}" + + result = EnumDataType( + CategoryPath("/imported"), type_pdb["name"], underlying_type.getLength() + ) + variants: list[dict[str, Any]] = field_list["variants"] + for variant in variants: + result.add(variant["name"], variant["value"]) + + return result + + def _import_class_or_struct(self, type_in_pdb: dict[str, Any]) -> DataType: + field_list_type: str = type_in_pdb["field_list_type"] + field_list = self.types.keys[field_list_type.lower()] + + class_size: int = type_in_pdb["size"] + class_name_with_namespace: str = sanitize_name(type_in_pdb["name"]) + + if class_name_with_namespace in self.handled_structs: + logger.debug( + "Class has been handled or is being handled: %s", + class_name_with_namespace, + ) + return get_ghidra_type(self.api, class_name_with_namespace) + + logger.debug( + "--- Beginning to import class/struct '%s'", class_name_with_namespace + ) + + # Add as soon as we start to avoid infinite recursion + self.handled_structs.add(class_name_with_namespace) + + self._get_or_create_namespace(class_name_with_namespace) + + data_type = self._get_or_create_struct_data_type( + class_name_with_namespace, class_size + ) + + if (old_size := data_type.getLength()) != class_size: + logger.warning( + "Existing class %s had incorrect size %d. Setting to %d...", + class_name_with_namespace, + old_size, + class_size, + ) + + logger.info("Adding class data type %s", class_name_with_namespace) + logger.debug("Class information: %s", type_in_pdb) + + data_type.deleteAll() + data_type.growStructure(class_size) + + # this case happened e.g. 
for IUnknown, which linked to an (incorrect) existing library, and some other types as well. + # Unfortunately, we don't get proper error handling for read-only types. + # However, we really do NOT want to do this every time because the type might be self-referential and partially imported. + if data_type.getLength() != class_size: + data_type = self._delete_and_recreate_struct_data_type( + class_name_with_namespace, class_size, data_type + ) + + # can be missing when no new fields are declared + components: list[dict[str, Any]] = field_list.get("members") or [] + + super_type = field_list.get("super") + if super_type is not None: + components.insert(0, {"type": super_type, "offset": 0, "name": "base"}) + + for component in components: + ghidra_type = self.import_pdb_type_into_ghidra(component["type"]) + logger.debug("Adding component to class: %s", component) + + try: + # for better logs + data_type.replaceAtOffset( + component["offset"], ghidra_type, -1, component["name"], None + ) + except Exception as e: + raise StructModificationError(type_in_pdb) from e + + logger.info("Finished importing class %s", class_name_with_namespace) + + return data_type + + def _get_or_create_namespace(self, class_name_with_namespace: str): + colon_split = class_name_with_namespace.split("::") + class_name = colon_split[-1] + try: + get_ghidra_namespace(self.api, colon_split) + logger.debug("Found existing class/namespace %s", class_name_with_namespace) + except ClassOrNamespaceNotFoundInGhidraError: + logger.info("Creating class/namespace %s", class_name_with_namespace) + class_name = colon_split.pop() + parent_namespace = create_ghidra_namespace(self.api, colon_split) + self.api.createClass(parent_namespace, class_name) + + def _get_or_create_struct_data_type( + self, class_name_with_namespace: str, class_size: int + ) -> StructureInternal: + try: + data_type = get_ghidra_type(self.api, class_name_with_namespace) + logger.debug( + "Found existing data type %s under category path 
%s", + class_name_with_namespace, + data_type.getCategoryPath(), + ) + except TypeNotFoundInGhidraError: + # Create a new struct data type + data_type = StructureDataType( + CategoryPath("/imported"), class_name_with_namespace, class_size + ) + data_type = ( + self.api.getCurrentProgram() + .getDataTypeManager() + .addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER) + ) + logger.info("Created new data type %s", class_name_with_namespace) + assert isinstance( + data_type, StructureInternal + ), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}" + return data_type + + def _delete_and_recreate_struct_data_type( + self, + class_name_with_namespace: str, + class_size: int, + existing_data_type: DataType, + ) -> StructureInternal: + logger.warning( + "Failed to modify data type %s. Will try to delete the existing one and re-create the imported one.", + class_name_with_namespace, + ) + + assert ( + self.api.getCurrentProgram() + .getDataTypeManager() + .remove(existing_data_type, ConsoleTaskMonitor()) + ), f"Failed to delete and re-create data type {class_name_with_namespace}" + data_type = StructureDataType( + CategoryPath("/imported"), class_name_with_namespace, class_size + ) + data_type = ( + self.api.getCurrentProgram() + .getDataTypeManager() + .addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER) + ) + assert isinstance(data_type, StructureInternal) # for type checking + return data_type diff --git a/tools/isledecomp/isledecomp/compare/core.py b/tools/isledecomp/isledecomp/compare/core.py index b1f1f094..1587ef81 100644 --- a/tools/isledecomp/isledecomp/compare/core.py +++ b/tools/isledecomp/isledecomp/compare/core.py @@ -4,7 +4,7 @@ import struct import uuid from dataclasses import dataclass -from typing import Callable, Iterable, List, Optional +from typing import Any, Callable, Iterable, List, Optional from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError from 
isledecomp.cvdump.demangler import demangle_string_const from isledecomp.cvdump import Cvdump, CvdumpAnalysis @@ -90,7 +90,7 @@ def __init__( def _load_cvdump(self): logger.info("Parsing %s ...", self.pdb_file) - cv = ( + self.cv = ( Cvdump(self.pdb_file) .lines() .globals() @@ -100,9 +100,9 @@ def _load_cvdump(self): .types() .run() ) - res = CvdumpAnalysis(cv) + self.cvdump_analysis = CvdumpAnalysis(self.cv) - for sym in res.nodes: + for sym in self.cvdump_analysis.nodes: # Skip nodes where we have almost no information. # These probably came from SECTION CONTRIBUTIONS. if sym.name() is None and sym.node_type is None: @@ -116,6 +116,7 @@ def _load_cvdump(self): continue addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset) + sym.addr = addr # If this symbol is the final one in its section, we were not able to # estimate its size because we didn't have the total size of that section. @@ -165,7 +166,10 @@ def _load_cvdump(self): addr, sym.node_type, sym.name(), sym.decorated_name, sym.size() ) - for (section, offset), (filename, line_no) in res.verified_lines.items(): + for (section, offset), ( + filename, + line_no, + ) in self.cvdump_analysis.verified_lines.items(): addr = self.recomp_bin.get_abs_addr(section, offset) self._lines_db.add_line(filename, line_no, addr) @@ -736,6 +740,9 @@ def get_vtables(self) -> List[MatchInfo]: def get_variables(self) -> List[MatchInfo]: return self._db.get_matches_by_type(SymbolType.DATA) + def get_match_options(self, addr: int) -> Optional[dict[str, Any]]: + return self._db.get_match_options(addr) + def compare_address(self, addr: int) -> Optional[DiffReport]: match = self._db.get_one_match(addr) if match is None: diff --git a/tools/isledecomp/isledecomp/compare/db.py b/tools/isledecomp/isledecomp/compare/db.py index 634cf455..99deb48e 100644 --- a/tools/isledecomp/isledecomp/compare/db.py +++ b/tools/isledecomp/isledecomp/compare/db.py @@ -2,7 +2,7 @@ addresses/symbols that we want to compare between the original and 
recompiled binaries.""" import sqlite3 import logging -from typing import List, Optional +from typing import Any, List, Optional from isledecomp.types import SymbolType from isledecomp.cvdump.demangler import get_vtordisp_name @@ -335,7 +335,7 @@ def mark_stub(self, orig: int): def skip_compare(self, orig: int): self._set_opt_bool(orig, "skip") - def get_match_options(self, addr: int) -> Optional[dict]: + def get_match_options(self, addr: int) -> Optional[dict[str, Any]]: cur = self._db.execute( """SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,) ) diff --git a/tools/isledecomp/isledecomp/cvdump/__init__.py b/tools/isledecomp/isledecomp/cvdump/__init__.py index 8e1fd78a..334788c0 100644 --- a/tools/isledecomp/isledecomp/cvdump/__init__.py +++ b/tools/isledecomp/isledecomp/cvdump/__init__.py @@ -1,3 +1,4 @@ +from .symbols import SymbolsEntry from .analysis import CvdumpAnalysis from .parser import CvdumpParser from .runner import Cvdump diff --git a/tools/isledecomp/isledecomp/cvdump/analysis.py b/tools/isledecomp/isledecomp/cvdump/analysis.py index bd8734fa..40ef292e 100644 --- a/tools/isledecomp/isledecomp/cvdump/analysis.py +++ b/tools/isledecomp/isledecomp/cvdump/analysis.py @@ -1,5 +1,7 @@ """For collating the results from parsing cvdump.exe into a more directly useful format.""" + from typing import Dict, List, Tuple, Optional +from isledecomp.cvdump import SymbolsEntry from isledecomp.types import SymbolType from .parser import CvdumpParser from .demangler import demangle_string_const, demangle_vtable @@ -31,6 +33,8 @@ class CvdumpNode: # Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be # accurate. section_contribution: Optional[int] = None + addr: Optional[int] = None + symbol_entry: Optional[SymbolsEntry] = None def __init__(self, section: int, offset: int) -> None: self.section = section @@ -87,13 +91,12 @@ class CvdumpAnalysis: """Collects the results from CvdumpParser into a list of nodes (i.e. symbols). 
These can then be analyzed by a downstream tool.""" - nodes = List[CvdumpNode] - verified_lines = Dict[Tuple[str, str], Tuple[str, str]] + verified_lines: Dict[Tuple[str, str], Tuple[str, str]] def __init__(self, parser: CvdumpParser): """Read in as much information as we have from the parser. The more sections we have, the better our information will be.""" - node_dict = {} + node_dict: Dict[Tuple[int, int], CvdumpNode] = {} # PUBLICS is our roadmap for everything that follows. for pub in parser.publics: @@ -158,8 +161,11 @@ def __init__(self, parser: CvdumpParser): node_dict[key].friendly_name = sym.name node_dict[key].confirmed_size = sym.size node_dict[key].node_type = SymbolType.FUNCTION + node_dict[key].symbol_entry = sym - self.nodes = [v for _, v in dict(sorted(node_dict.items())).items()] + self.nodes: List[CvdumpNode] = [ + v for _, v in dict(sorted(node_dict.items())).items() + ] self._estimate_size() def _estimate_size(self): diff --git a/tools/isledecomp/isledecomp/cvdump/parser.py b/tools/isledecomp/isledecomp/cvdump/parser.py index 1b1eb3fd..c8f1d67d 100644 --- a/tools/isledecomp/isledecomp/cvdump/parser.py +++ b/tools/isledecomp/isledecomp/cvdump/parser.py @@ -2,6 +2,7 @@ from typing import Iterable, Tuple from collections import namedtuple from .types import CvdumpTypesParser +from .symbols import CvdumpSymbolsParser # e.g. `*** PUBLICS` _section_change_regex = re.compile(r"\*\*\* (?P
[A-Z/ ]{2,})") @@ -20,11 +21,6 @@ r"^(?P\w+): \[(?P
\w{4}):(?P\w{8})], Flags: (?P\w{8}), (?P\S+)" ) -# e.g. `(00008C) S_GPROC32: [0001:00034E90], Cb: 00000007, Type: 0x1024, ViewROI::IntrinsicImportance` -_symbol_line_regex = re.compile( - r"\(\w+\) (?P\S+): \[(?P
\w{4}):(?P\w{8})\], Cb: (?P\w+), Type:\s+\S+, (?P.+)" -) - # e.g. ` Debug start: 00000008, Debug end: 0000016E` _gproc_debug_regex = re.compile( r"\s*Debug start: (?P\w{8}), Debug end: (?P\w{8})" @@ -52,9 +48,6 @@ # only place you can find the C symbols (library functions, smacker, etc) PublicsEntry = namedtuple("PublicsEntry", "type section offset flags name") -# S_GPROC32 = functions -SymbolsEntry = namedtuple("SymbolsEntry", "type section offset size name") - # (Estimated) size of any symbol SizeRefEntry = namedtuple("SizeRefEntry", "module section offset size") @@ -72,12 +65,16 @@ def __init__(self) -> None: self.lines = {} self.publics = [] - self.symbols = [] self.sizerefs = [] self.globals = [] self.modules = [] self.types = CvdumpTypesParser() + self.symbols_parser = CvdumpSymbolsParser() + + @property + def symbols(self): + return self.symbols_parser.symbols def _lines_section(self, line: str): """Parsing entries from the LINES section. We only care about the pairs of @@ -127,20 +124,6 @@ def _globals_section(self, line: str): ) ) - def _symbols_section(self, line: str): - """We are interested in S_GPROC32 symbols only.""" - if (match := _symbol_line_regex.match(line)) is not None: - if match.group("type") == "S_GPROC32": - self.symbols.append( - SymbolsEntry( - type=match.group("type"), - section=int(match.group("section"), 16), - offset=int(match.group("offset"), 16), - size=int(match.group("size"), 16), - name=match.group("name"), - ) - ) - def _section_contributions(self, line: str): """Gives the size of elements across all sections of the binary. 
This is the easiest way to get the data size for .data and .rdata @@ -177,7 +160,7 @@ def read_line(self, line: str): self.types.read_line(line) elif self._section == "SYMBOLS": - self._symbols_section(line) + self.symbols_parser.read_line(line) elif self._section == "LINES": self._lines_section(line) diff --git a/tools/isledecomp/isledecomp/cvdump/symbols.py b/tools/isledecomp/isledecomp/cvdump/symbols.py new file mode 100644 index 00000000..22c1b32e --- /dev/null +++ b/tools/isledecomp/isledecomp/cvdump/symbols.py @@ -0,0 +1,153 @@ +from dataclasses import dataclass, field +import logging +import re +from re import Match +from typing import NamedTuple, Optional + + +logger = logging.getLogger(__name__) + + +class StackOrRegisterSymbol(NamedTuple): + symbol_type: str + location: str + """Should always be set/converted to lowercase.""" + data_type: str + name: str + + +# S_GPROC32 = functions +@dataclass +class SymbolsEntry: + # pylint: disable=too-many-instance-attributes + type: str + section: int + offset: int + size: int + func_type: str + name: str + stack_symbols: list[StackOrRegisterSymbol] = field(default_factory=list) + frame_pointer_present: bool = False + addr: Optional[int] = None # Absolute address. Will be set later, if at all + + +class CvdumpSymbolsParser: + _symbol_line_generic_regex = re.compile( + r"\(\w+\)\s+(?P[^\s:]+)(?::\s+(?P\S.*))?|(?::)$" + ) + """ + Parses the first part, e.g. `(00008C) S_GPROC32`, and splits off the second part after the colon (if it exists). + There are three cases: + - no colon, e.g. `(000350) S_END` + - colon but no data, e.g. `(000370) S_COMPILE:` + - colon and data, e.g. `(000304) S_REGISTER: esi, Type: 0x1E14, this`` + """ + + _symbol_line_function_regex = re.compile( + r"\[(?P
\w{4}):(?P\w{8})\], Cb: (?P\w+), Type:\s+(?P[^\s,]+), (?P.+)" + ) + """ + Parses the second part of a function symbol, e.g. + `[0001:00034E90], Cb: 00000007, Type: 0x1024, ViewROI::IntrinsicImportance` + """ + + # the second part of e.g. + _stack_register_symbol_regex = re.compile( + r"(?P\S+), Type:\s+(?P[\w()]+), (?P.+)$" + ) + """ + Parses the second part of a stack or register symbol, e.g. + `esi, Type: 0x1E14, this` + """ + + _debug_start_end_regex = re.compile( + r"^\s*Debug start: (?P\w+), Debug end: (?P\w+)$" + ) + + _parent_end_next_regex = re.compile( + r"\s*Parent: (?P\w+), End: (?P\w+), Next: (?P\w+)$" + ) + + _flags_frame_pointer_regex = re.compile(r"\s*Flags: Frame Ptr Present$") + + _register_stack_symbols = ["S_BPREL32", "S_REGISTER"] + + # List the unhandled types so we can check exhaustiveness + _unhandled_symbols = [ + "S_COMPILE", + "S_OBJNAME", + "S_THUNK32", + "S_LABEL32", + "S_LDATA32", + "S_LPROC32", + "S_UDT", + ] + + """Parser for cvdump output, SYMBOLS section.""" + + def __init__(self): + self.symbols: list[SymbolsEntry] = [] + self.current_function: Optional[SymbolsEntry] = None + + def read_line(self, line: str): + if (match := self._symbol_line_generic_regex.match(line)) is not None: + self._parse_generic_case(line, match) + elif (match := self._parent_end_next_regex.match(line)) is not None: + # We do not need this info at the moment, might be useful in the future + pass + elif (match := self._debug_start_end_regex.match(line)) is not None: + # We do not need this info at the moment, might be useful in the future + pass + elif (match := self._flags_frame_pointer_regex.match(line)) is not None: + if self.current_function is None: + logger.error( + "Found a `Flags: Frame Ptr Present` but self.current_function is None" + ) + return + self.current_function.frame_pointer_present = True + else: + # Most of these are either `** Module: [...]` or data we do not care about + logger.debug("Unhandled line: %s", line[:-1]) + + def 
_parse_generic_case(self, line, line_match: Match[str]): + symbol_type: str = line_match.group("symbol_type") + second_part: Optional[str] = line_match.group("second_part") + + if symbol_type == "S_GPROC32": + assert second_part is not None + if (match := self._symbol_line_function_regex.match(second_part)) is None: + logger.error("Invalid function symbol: %s", line[:-1]) + return + self.current_function = SymbolsEntry( + type=symbol_type, + section=int(match.group("section"), 16), + offset=int(match.group("offset"), 16), + size=int(match.group("size"), 16), + func_type=match.group("func_type"), + name=match.group("name"), + ) + self.symbols.append(self.current_function) + + elif symbol_type in self._register_stack_symbols: + assert second_part is not None + if self.current_function is None: + logger.error("Found stack/register outside of function: %s", line[:-1]) + return + if (match := self._stack_register_symbol_regex.match(second_part)) is None: + logger.error("Invalid stack/register symbol: %s", line[:-1]) + return + + new_symbol = StackOrRegisterSymbol( + symbol_type=symbol_type, + location=match.group("location").lower(), + data_type=match.group("data_type"), + name=match.group("name"), + ) + self.current_function.stack_symbols.append(new_symbol) + + elif symbol_type == "S_END": + self.current_function = None + elif symbol_type in self._unhandled_symbols: + return + else: + logger.error("Unhandled symbol type: %s", line) diff --git a/tools/isledecomp/isledecomp/cvdump/types.py b/tools/isledecomp/isledecomp/cvdump/types.py index 547d3ce9..381c27e9 100644 --- a/tools/isledecomp/isledecomp/cvdump/types.py +++ b/tools/isledecomp/isledecomp/cvdump/types.py @@ -1,5 +1,9 @@ import re -from typing import Dict, List, NamedTuple, Optional +import logging +from typing import Any, Dict, List, NamedTuple, Optional + + +logger = logging.getLogger(__name__) class CvdumpTypeError(Exception): @@ -42,7 +46,7 @@ def is_pointer(self) -> bool: class TypeInfo(NamedTuple): key: 
str - size: int + size: Optional[int] name: Optional[str] = None members: Optional[List[FieldListItem]] = None @@ -156,6 +160,10 @@ class CvdumpTypesParser: # LF_FIELDLIST member name (2/2) MEMBER_RE = re.compile(r"^\s+member name = '(?P.*)'$") + LF_FIELDLIST_ENUMERATE = re.compile( + r"^\s+list\[\d+\] = LF_ENUMERATE,.*value = (?P\d+), name = '(?P[^']+)'$" + ) + # LF_ARRAY element type ARRAY_ELEMENT_RE = re.compile(r"^\s+Element type = (?P.*)") @@ -169,12 +177,53 @@ class CvdumpTypesParser: # LF_CLASS/LF_STRUCTURE name and other info CLASS_NAME_RE = re.compile( - r"^\s+Size = (?P\d+), class name = (?P.+), UDT\((?P0x\w+)\)" + r"^\s+Size = (?P\d+), class name = (?P(?:[^,]|,\S)+)(?:, UDT\((?P0x\w+)\))?" ) # LF_MODIFIER, type being modified MODIFIES_RE = re.compile(r".*modifies type (?P.*)$") + # LF_ARGLIST number of entries + LF_ARGLIST_ARGCOUNT = re.compile(r".*argument count = (?P\d+)$") + + # LF_ARGLIST list entry + LF_ARGLIST_ENTRY = re.compile( + r"^\s+list\[(?P\d+)\] = (?P[\w()]+)$" + ) + + # LF_POINTER element + LF_POINTER_ELEMENT = re.compile(r"^\s+Element type : (?P.+)$") + + # LF_MFUNCTION attribute key-value pairs + LF_MFUNCTION_ATTRIBUTES = [ + re.compile(r"\s*Return type = (?P[\w()]+)$"), + re.compile(r"\s*Class type = (?P[\w()]+)$"), + re.compile(r"\s*This type = (?P[\w()]+)$"), + # Call type may contain whitespace + re.compile(r"\s*Call type = (?P[\w()\s]+)$"), + re.compile(r"\s*Parms = (?P[\w()]+)$"), # LF_MFUNCTION only + re.compile(r"\s*# Parms = (?P[\w()]+)$"), # LF_PROCEDURE only + re.compile(r"\s*Arg list type = (?P[\w()]+)$"), + re.compile( + r"\s*This adjust = (?P[\w()]+)$" + ), # TODO: figure out the meaning + re.compile( + r"\s*Func attr = (?P[\w()]+)$" + ), # Only for completeness, is always `none` + ] + + LF_ENUM_ATTRIBUTES = [ + re.compile(r"^\s*# members = (?P\d+)$"), + re.compile(r"^\s*enum name = (?P.+)$"), + ] + LF_ENUM_TYPES = re.compile( + r"^\s*type = (?P\S+) field list type (?P0x\w{4})$" + ) + LF_ENUM_UDT = 
re.compile(r"^\s*UDT\((?P0x\w+)\)$") + LF_UNION_LINE = re.compile( + r"^.*field list type (?P0x\w+),.*Size = (?P\d+)\s*,class name = (?P(?:[^,]|,\S)+),\s.*UDT\((?P0x\w+)\)$" + ) + MODES_OF_INTEREST = { "LF_ARRAY", "LF_CLASS", @@ -183,12 +232,16 @@ class CvdumpTypesParser: "LF_MODIFIER", "LF_POINTER", "LF_STRUCTURE", + "LF_ARGLIST", + "LF_MFUNCTION", + "LF_PROCEDURE", + "LF_UNION", } def __init__(self) -> None: self.mode: Optional[str] = None self.last_key = "" - self.keys = {} + self.keys: Dict[str, Dict[str, Any]] = {} def _new_type(self): """Prepare a new dict for the type we just parsed. @@ -211,13 +264,20 @@ def _set_member_name(self, name: str): obj = self.keys[self.last_key] obj["members"][-1]["name"] = name - def _get_field_list(self, type_obj: Dict) -> List[FieldListItem]: + def _add_variant(self, name: str, value: int): + obj = self.keys[self.last_key] + if "variants" not in obj: + obj["variants"] = [] + variants: list[dict[str, Any]] = obj["variants"] + variants.append({"name": name, "value": value}) + + def _get_field_list(self, type_obj: Dict[str, Any]) -> List[FieldListItem]: """Return the field list for the given LF_CLASS/LF_STRUCTURE reference""" if type_obj.get("type") == "LF_FIELDLIST": field_obj = type_obj else: - field_list_type = type_obj.get("field_list_type") + field_list_type = type_obj["field_list_type"] field_obj = self.keys[field_list_type] members: List[FieldListItem] = [] @@ -253,6 +313,9 @@ def _mock_array_members(self, type_obj: Dict) -> List[FieldListItem]: raise CvdumpIntegrityError("No array element type") array_element_size = self.get(array_type).size + assert ( + array_element_size is not None + ), "Encountered an array whose type has no size" n_elements = type_obj["size"] // array_element_size @@ -285,7 +348,10 @@ def get(self, type_key: str) -> TypeInfo: # These type references are just a wrapper around a scalar if obj.get("type") == "LF_ENUM": - return self.get("T_INT4") + underlying_type = obj.get("underlying_type") + if 
underlying_type is None: + raise CvdumpKeyError(f"Missing 'underlying_type' in {obj}") + return self.get(underlying_type) if obj.get("type") == "LF_POINTER": return self.get("T_32PVOID") @@ -350,6 +416,9 @@ def get_scalars_gapless(self, type_key: str) -> List[ScalarType]: obj = self.get(type_key) total_size = obj.size + assert ( + total_size is not None + ), "Called get_scalar_gapless() on a type without size" scalars = self.get_scalars(type_key) @@ -383,6 +452,11 @@ def get_format_string(self, type_key: str) -> str: return member_list_to_struct_string(members) def read_line(self, line: str): + if line.endswith("\n"): + line = line[:-1] + if len(line) == 0: + return + if (match := self.INDEX_RE.match(line)) is not None: type_ = match.group(2) if type_ not in self.MODES_OF_INTEREST: @@ -393,6 +467,12 @@ def read_line(self, line: str): self.last_key = match.group(1) self.mode = type_ self._new_type() + + if type_ == "LF_ARGLIST": + submatch = self.LF_ARGLIST_ARGCOUNT.match(line) + assert submatch is not None + self.keys[self.last_key]["argcount"] = int(submatch.group("argcount")) + # TODO: This should be validated in another pass return if self.mode is None: @@ -413,41 +493,170 @@ def read_line(self, line: str): self._set("size", int(match.group("length"))) elif self.mode == "LF_FIELDLIST": - # If this class has a vtable, create a mock member at offset 0 - if (match := self.VTABLE_RE.match(line)) is not None: - # For our purposes, any pointer type will do - self._add_member(0, "T_32PVOID") - self._set_member_name("vftable") + self.read_fieldlist_line(line) - # Superclass is set here in the fieldlist rather than in LF_CLASS - elif (match := self.SUPERCLASS_RE.match(line)) is not None: - self._set("super", normalize_type_id(match.group("type"))) + elif self.mode == "LF_ARGLIST": + self.read_arglist_line(line) - # Member offset and type given on the first of two lines. 
- elif (match := self.LIST_RE.match(line)) is not None: - self._add_member( - int(match.group("offset")), normalize_type_id(match.group("type")) - ) + elif self.mode in ["LF_MFUNCTION", "LF_PROCEDURE"]: + self.read_mfunction_line(line) - # Name of the member read on the second of two lines. - elif (match := self.MEMBER_RE.match(line)) is not None: - self._set_member_name(match.group("name")) + elif self.mode in ["LF_CLASS", "LF_STRUCTURE"]: + self.read_class_or_struct_line(line) - else: # LF_CLASS or LF_STRUCTURE - # Match the reference to the associated LF_FIELDLIST - if (match := self.CLASS_FIELD_RE.match(line)) is not None: - if match.group("field_type") == "0x0000": - # Not redundant. UDT might not match the key. - # These cases get reported as UDT mismatch. - self._set("is_forward_ref", True) - else: - field_list_type = normalize_type_id(match.group("field_type")) - self._set("field_list_type", field_list_type) + elif self.mode == "LF_POINTER": + self.read_pointer_line(line) + elif self.mode == "LF_ENUM": + self.read_enum_line(line) + + elif self.mode == "LF_UNION": + self.read_union_line(line) + + else: + # Check for exhaustiveness + logger.error("Unhandled data in mode: %s", self.mode) + + def read_fieldlist_line(self, line: str): + # If this class has a vtable, create a mock member at offset 0 + if (match := self.VTABLE_RE.match(line)) is not None: + # For our purposes, any pointer type will do + self._add_member(0, "T_32PVOID") + self._set_member_name("vftable") + + # Superclass is set here in the fieldlist rather than in LF_CLASS + elif (match := self.SUPERCLASS_RE.match(line)) is not None: + self._set("super", normalize_type_id(match.group("type"))) + + # Member offset and type given on the first of two lines. + elif (match := self.LIST_RE.match(line)) is not None: + self._add_member( + int(match.group("offset")), normalize_type_id(match.group("type")) + ) + + # Name of the member read on the second of two lines. 
+ elif (match := self.MEMBER_RE.match(line)) is not None: + self._set_member_name(match.group("name")) + + elif (match := self.LF_FIELDLIST_ENUMERATE.match(line)) is not None: + self._add_variant(match.group("name"), int(match.group("value"))) + + def read_class_or_struct_line(self, line: str): + # Match the reference to the associated LF_FIELDLIST + if (match := self.CLASS_FIELD_RE.match(line)) is not None: + if match.group("field_type") == "0x0000": + # Not redundant. UDT might not match the key. + # These cases get reported as UDT mismatch. + self._set("is_forward_ref", True) + else: + field_list_type = normalize_type_id(match.group("field_type")) + self._set("field_list_type", field_list_type) + + elif line.lstrip().startswith("Derivation list type"): + # We do not care about the second line, but we still match it so we see an error + # when another line fails to match + pass + elif (match := self.CLASS_NAME_RE.match(line)) is not None: # Last line has the vital information. # If this is a FORWARD REF, we need to follow the UDT pointer # to get the actual class details. 
- elif (match := self.CLASS_NAME_RE.match(line)) is not None: - self._set("name", match.group("name")) - self._set("udt", normalize_type_id(match.group("udt"))) - self._set("size", int(match.group("size"))) + self._set("name", match.group("name")) + udt = match.group("udt") + if udt is not None: + self._set("udt", normalize_type_id(udt)) + self._set("size", int(match.group("size"))) + else: + logger.error("Unmatched line in class: %s", line[:-1]) + + def read_arglist_line(self, line: str): + if (match := self.LF_ARGLIST_ENTRY.match(line)) is not None: + obj = self.keys[self.last_key] + arglist: list = obj.setdefault("args", []) + assert int(match.group("index")) == len( + arglist + ), "Argument list out of sync" + arglist.append(match.group("arg_type")) + else: + logger.error("Unmatched line in arglist: %s", line[:-1]) + + def read_pointer_line(self, line): + if (match := self.LF_POINTER_ELEMENT.match(line)) is not None: + self._set("element_type", match.group("element_type")) + else: + stripped_line = line.strip() + # We don't parse these lines, but we still want to check for exhaustiveness + # in case we missed some relevant data + if not any( + stripped_line.startswith(prefix) + for prefix in ["Pointer", "const Pointer", "L-value", "volatile"] + ): + logger.error("Unrecognized pointer attribute: %s", line[:-1]) + + def read_mfunction_line(self, line: str): + """ + The layout is not consistent, so we want to be as robust as possible here. 
+ - Example 1: + Return type = T_LONG(0012), Call type = C Near + Func attr = none + - Example 2: + Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B, + Call type = ThisCall, Func attr = none + """ + + obj = self.keys[self.last_key] + + key_value_pairs = line.split(",") + for pair in key_value_pairs: + if pair.isspace(): + continue + obj |= self.parse_function_attribute(pair) + + def parse_function_attribute(self, pair: str) -> dict[str, str]: + for attribute_regex in self.LF_MFUNCTION_ATTRIBUTES: + if (match := attribute_regex.match(pair)) is not None: + return match.groupdict() + logger.error("Unknown attribute in function: %s", pair) + return {} + + def read_enum_line(self, line: str): + obj = self.keys[self.last_key] + + # We need special comma handling because commas may appear in the name. + # Splitting by "," yields the wrong result. + enum_attributes = line.split(", ") + for pair in enum_attributes: + if pair.endswith(","): + pair = pair[:-1] + if pair.isspace(): + continue + obj |= self.parse_enum_attribute(pair) + + def parse_enum_attribute(self, attribute: str) -> dict[str, Any]: + for attribute_regex in self.LF_ENUM_ATTRIBUTES: + if (match := attribute_regex.match(attribute)) is not None: + return match.groupdict() + if attribute == "NESTED": + return {"is_nested": True} + if attribute == "FORWARD REF": + return {"is_forward_ref": True} + if attribute.startswith("UDT"): + match = self.LF_ENUM_UDT.match(attribute) + assert match is not None + return {"udt": normalize_type_id(match.group("udt"))} + if (match := self.LF_ENUM_TYPES.match(attribute)) is not None: + result = match.groupdict() + result["underlying_type"] = normalize_type_id(result["underlying_type"]) + return result + logger.error("Unknown attribute in enum: %s", attribute) + return {} + + def read_union_line(self, line: str): + """This is a rather barebones handler, only parsing the size""" + if (match := self.LF_UNION_LINE.match(line)) is None: + raise 
AssertionError(f"Unhandled in union: {line}") + self._set("name", match.group("name")) + if match.group("field_type") == "0x0000": + self._set("is_forward_ref", True) + + self._set("size", int(match.group("size"))) + self._set("udt", normalize_type_id(match.group("udt"))) diff --git a/tools/isledecomp/tests/test_cvdump_types.py b/tools/isledecomp/tests/test_cvdump_types.py index e90cff0f..e271040c 100644 --- a/tools/isledecomp/tests/test_cvdump_types.py +++ b/tools/isledecomp/tests/test_cvdump_types.py @@ -9,6 +9,21 @@ ) TEST_LINES = """ +0x1018 : Length = 18, Leaf = 0x1201 LF_ARGLIST argument count = 3 + list[0] = 0x100D + list[1] = 0x1016 + list[2] = 0x1017 + +0x1019 : Length = 14, Leaf = 0x1008 LF_PROCEDURE + Return type = T_LONG(0012), Call type = C Near + Func attr = none + # Parms = 3, Arg list type = 0x1018 + +0x101e : Length = 26, Leaf = 0x1009 LF_MFUNCTION + Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B, + Call type = ThisCall, Func attr = none + Parms = 2, Arg list type = 0x101d, This adjust = 0 + 0x1028 : Length = 10, Leaf = 0x1001 LF_MODIFIER const, modifies type T_REAL32(0040) @@ -47,16 +62,16 @@ Element type = T_UCHAR(0020) Index type = T_SHORT(0011) length = 8 - Name = + Name = 0x10ea : Length = 14, Leaf = 0x1503 LF_ARRAY Element type = 0x1028 Index type = T_SHORT(0011) length = 12 - Name = + Name = 0x11f0 : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 0, field list type 0x0000, FORWARD REF, + # members = 0, field list type 0x0000, FORWARD REF, Derivation list type 0x0000, VT shape type 0x0000 Size = 0, class name = MxRect32, UDT(0x00001214) @@ -98,22 +113,22 @@ member name = 'm_bottom' 0x1214 : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 34, field list type 0x1213, CONSTRUCTOR, OVERLOAD, + # members = 34, field list type 0x1213, CONSTRUCTOR, OVERLOAD, Derivation list type 0x0000, VT shape type 0x0000 Size = 16, class name = MxRect32, UDT(0x00001214) 0x1220 : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 0, field 
list type 0x0000, FORWARD REF, + # members = 0, field list type 0x0000, FORWARD REF, Derivation list type 0x0000, VT shape type 0x0000 Size = 0, class name = MxCore, UDT(0x00004060) 0x14db : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 0, field list type 0x0000, FORWARD REF, + # members = 0, field list type 0x0000, FORWARD REF, Derivation list type 0x0000, VT shape type 0x0000 Size = 0, class name = MxString, UDT(0x00004db6) 0x19b0 : Length = 34, Leaf = 0x1505 LF_STRUCTURE - # members = 0, field list type 0x0000, FORWARD REF, + # members = 0, field list type 0x0000, FORWARD REF, Derivation list type 0x0000, VT shape type 0x0000 Size = 0, class name = ROIColorAlias, UDT(0x00002a76) @@ -123,6 +138,12 @@ length = 440 Name = +0x2339 : Length = 26, Leaf = 0x1506 LF_UNION + # members = 0, field list type 0x0000, FORWARD REF, Size = 0 ,class name = FlagBitfield, UDT(0x00002e85) + +0x2e85 : Length = 26, Leaf = 0x1506 LF_UNION + # members = 8, field list type 0x2e84, Size = 1 ,class name = FlagBitfield, UDT(0x00002e85) + 0x2a75 : Length = 98, Leaf = 0x1203 LF_FIELDLIST list[0] = LF_MEMBER, public, type = T_32PRCHAR(0470), offset = 0 member name = 'm_name' @@ -136,18 +157,18 @@ member name = 'm_unk0x10' 0x2a76 : Length = 34, Leaf = 0x1505 LF_STRUCTURE - # members = 5, field list type 0x2a75, + # members = 5, field list type 0x2a75, Derivation list type 0x0000, VT shape type 0x0000 Size = 20, class name = ROIColorAlias, UDT(0x00002a76) 0x22d4 : Length = 154, Leaf = 0x1203 LF_FIELDLIST list[0] = LF_VFUNCTAB, type = 0x20FC list[1] = LF_METHOD, count = 3, list = 0x22D0, name = 'MxVariable' - list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F0F, + list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F0F, vfptr offset = 0, name = 'GetValue' - list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F10, + list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F10, vfptr offset = 4, name = 'SetValue' - list[4] = LF_ONEMETHOD, public, 
INTRODUCING VIRTUAL, index = 0x1F11, + list[4] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F11, vfptr offset = 8, name = '~MxVariable' list[5] = LF_ONEMETHOD, public, VANILLA, index = 0x22D3, name = 'GetKey' list[6] = LF_MEMBER, protected, type = 0x14DB, offset = 4 @@ -156,10 +177,15 @@ member name = 'm_value' 0x22d5 : Length = 34, Leaf = 0x1504 LF_CLASS - # members = 10, field list type 0x22d4, CONSTRUCTOR, + # members = 10, field list type 0x22d4, CONSTRUCTOR, Derivation list type 0x0000, VT shape type 0x20fb Size = 36, class name = MxVariable, UDT(0x00004041) +0x3c45 : Length = 50, Leaf = 0x1203 LF_FIELDLIST + list[0] = LF_ENUMERATE, public, value = 1, name = 'c_read' + list[1] = LF_ENUMERATE, public, value = 2, name = 'c_write' + list[2] = LF_ENUMERATE, public, value = 4, name = 'c_text' + 0x3cc2 : Length = 38, Leaf = 0x1507 LF_ENUM # members = 64, type = T_INT4(0074) field list type 0x3cc1 NESTED, enum name = JukeBox::JukeBoxScript, UDT(0x00003cc2) @@ -171,22 +197,22 @@ 0x405f : Length = 158, Leaf = 0x1203 LF_FIELDLIST list[0] = LF_VFUNCTAB, type = 0x2090 list[1] = LF_ONEMETHOD, public, VANILLA, index = 0x176A, name = 'MxCore' - list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176A, + list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176A, vfptr offset = 0, name = '~MxCore' - list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176B, + list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176B, vfptr offset = 4, name = 'Notify' - list[4] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x2087, + list[4] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x2087, vfptr offset = 8, name = 'Tickle' - list[5] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x202F, + list[5] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x202F, vfptr offset = 12, name = 'ClassName' - list[6] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x2030, + list[6] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, 
index = 0x2030, vfptr offset = 16, name = 'IsA' list[7] = LF_ONEMETHOD, public, VANILLA, index = 0x2091, name = 'GetId' list[8] = LF_MEMBER, private, type = T_UINT4(0075), offset = 4 member name = 'm_id' 0x4060 : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 9, field list type 0x405f, CONSTRUCTOR, + # members = 9, field list type 0x405f, CONSTRUCTOR, Derivation list type 0x0000, VT shape type 0x1266 Size = 8, class name = MxCore, UDT(0x00004060) @@ -194,7 +220,7 @@ Element type = 0x3CC2 Index type = T_SHORT(0011) length = 24 - Name = + Name = 0x432f : Length = 14, Leaf = 0x1503 LF_ARRAY Element type = T_INT4(0074) @@ -220,7 +246,7 @@ member name = 'm_length' 0x4db6 : Length = 30, Leaf = 0x1504 LF_CLASS - # members = 16, field list type 0x4db5, CONSTRUCTOR, OVERLOAD, + # members = 16, field list type 0x4db5, CONSTRUCTOR, OVERLOAD, Derivation list type 0x0000, VT shape type 0x1266 Size = 16, class name = MxString, UDT(0x00004db6) """ @@ -235,7 +261,7 @@ def types_parser_fixture(): return parser -def test_basic_parsing(parser): +def test_basic_parsing(parser: CvdumpTypesParser): obj = parser.keys["0x4db6"] assert obj["type"] == "LF_CLASS" assert obj["name"] == "MxString" @@ -244,7 +270,7 @@ def test_basic_parsing(parser): assert len(parser.keys["0x4db5"]["members"]) == 2 -def test_scalar_types(parser): +def test_scalar_types(parser: CvdumpTypesParser): """Full tests on the scalar_* methods are in another file. 
Here we are just testing the passthrough of the "T_" types.""" assert parser.get("T_CHAR").name is None @@ -254,7 +280,7 @@ def test_scalar_types(parser): assert parser.get("T_32PVOID").size == 4 -def test_resolve_forward_ref(parser): +def test_resolve_forward_ref(parser: CvdumpTypesParser): # Non-forward ref assert parser.get("0x22d5").name == "MxVariable" # Forward ref @@ -262,7 +288,7 @@ def test_resolve_forward_ref(parser): assert parser.get("0x14db").size == 16 -def test_members(parser): +def test_members(parser: CvdumpTypesParser): """Return the list of items to compare for a given complex type. If the class has a superclass, add those members too.""" # MxCore field list @@ -284,7 +310,7 @@ def test_members(parser): ] -def test_members_recursive(parser): +def test_members_recursive(parser: CvdumpTypesParser): """Make sure that we unwrap the dependency tree correctly.""" # MxVariable field list assert parser.get_scalars("0x22d4") == [ @@ -300,7 +326,7 @@ def test_members_recursive(parser): ] -def test_struct(parser): +def test_struct(parser: CvdumpTypesParser): """Basic test for converting type into struct.unpack format string.""" # MxCore: vftable and uint32. The vftable pointer is read as uint32. assert parser.get_format_string("0x4060") == "