Add Ghidra function import script (#909)

* Add draft for Ghidra function import script

* feature: Basic PDB analysis [skip ci]

This is a draft with a lot of open questions left. Please do not merge

* Refactor: Introduce submodules and reload remedy

* refactor types and make them Python 3.9 compatible

* run black

* WIP: save progress

* fix types and small type safety violations

* fix another Python 3.9 syntax incompatibility

* Implement struct imports [skip ci]

- This code is still in dire need of refactoring and tests
- There are only single-digit issues left, and 2600 functions can be imported
- The biggest remaining error is mismatched stacks

* Refactor, implement enums, fix lots of bugs

* fix Python 3.9 issue

* refactor: address review comments

Not sure why VS Code suddenly decides to remove some empty spaces, but they don't make sense anyway

* add unit tests for new type parsers, fix linter issue

* refactor: db access from pdb_extraction.py

* Fix stack layout offset error

* fix: Undo incorrect reference change

* Fix CI issue

* Improve READMEs (fix typos, add information)

---------

Co-authored-by: jonschz <jonschz@users.noreply.github.com>
This commit is contained in:
jonschz 2024-06-09 14:41:24 +02:00 committed by GitHub
parent 88805f9fcb
commit f26c30974a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 1824 additions and 114 deletions

1
.gitignore vendored
View file

@ -19,3 +19,4 @@ LEGO1.DLL
LEGO1PROGRESS.*
ISLEPROGRESS.*
*.pyc
tools/ghidra_scripts/import.log

View file

@ -63,11 +63,11 @@ ignore-patterns=^\.#
# (useful for modules/projects where namespaces are manipulated during runtime
# and thus existing member attributes cannot be deduced by static analysis). It
# supports qualified module names, as well as Unix pattern matching.
ignored-modules=
ignored-modules=ghidra
# Python code to execute, usually for sys.path manipulation such as
# pygtk.require().
#init-hook=
init-hook='import sys; sys.path.append("tools/isledecomp"); sys.path.append("tools/ghidra_scripts")'
# Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the
# number of processors available to use, and will cap the count on Windows to

View file

@ -174,7 +174,7 @@ pip install -r tools/requirements.txt
## Testing
`isledecomp` comes with a suite of tests. Install `pylint` and run it, passing in the directory:
`isledecomp` comes with a suite of tests. Install `pytest` and run it, passing in the directory:
```
pip install pytest
@ -189,7 +189,7 @@ In order to keep the code clean and consistent, we use `pylint` and `black`:
### Run pylint (ignores build and virtualenv)
`pylint tools/ --ignore=build,bin,lib`
`pylint tools/ --ignore=build,ncc`
### Check code formatting without rewriting files

View file

@ -0,0 +1,25 @@
# Ghidra Scripts
The scripts in this directory provide additional functionality in Ghidra, e.g. imports of symbols and types from the PDB debug symbol file.
## Setup
### Ghidrathon
Since these scripts and its dependencies are written in Python 3, [Ghidrathon](https://github.com/mandiant/Ghidrathon) must be installed first. Follow the instructions and install a recent build (these scripts were tested with Python 3.12 and Ghidrathon v4.0.0).
### Script Directory
- In Ghidra, _Open Window -> Script Manager_.
- Click the _Manage Script Directories_ button on the top right.
- Click the _Add_ (Plus icon) button and select this file's parent directory.
- Close the window and click the _Refresh_ button.
- This script should now be available under the folder _LEGO1_.
### Virtual environment
As of now, there must be a Python virtual environment set up under `$REPOSITORY_ROOT/.venv`, and the dependencies of `isledecomp` must be installed there, see [here](../README.md#tooling).
## Development
- Type hints for Ghidra (optional): Download a recent release from https://github.com/VDOO-Connected-Trust/ghidra-pyi-generator,
unpack it somewhere, and `pip install` that directory in this virtual environment. This provides types and headers for Python.
Be aware that some of these files contain errors - in particular, `from typing import overload` seems to be missing everywhere, leading to spurious type errors.
- Note that the imported modules persist across multiple runs of the script (see [here](https://github.com/mandiant/Ghidrathon/issues/103)).
If you indend to modify an imported library, you have to use `import importlib; importlib.reload(${library})` or restart Ghidra for your changes to have any effect. Unfortunately, even that is not perfectly reliable, so you may still have to restart Ghidra for some changes in `isledecomp` to be applied.

View file

@ -0,0 +1,283 @@
# Imports types and function signatures from debug symbols (PDB file) of the recompilation.
#
# This script uses Python 3 and therefore requires Ghidrathon to be installed in Ghidra (see https://github.com/mandiant/Ghidrathon).
# Furthermore, the virtual environment must be set up beforehand under $REPOSITORY_ROOT/.venv, and all required packages must be installed
# (see $REPOSITORY_ROOT/tools/README.md).
# Also, the Python version of the virtual environment must probably match the Python version used for Ghidrathon.
# @author J. Schulz
# @category LEGO1
# @keybinding
# @menupath
# @toolbar
# In order to make this code run both within and outside of Ghidra, the import order is rather unorthodox in this file.
# That is why some of the lints below are disabled.
# pylint: disable=wrong-import-position,ungrouped-imports
# pylint: disable=undefined-variable # need to disable this one globally because pylint does not understand e.g. `askYesNo()``
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
import importlib
from dataclasses import dataclass, field
import logging.handlers
import sys
import logging
from pathlib import Path
import traceback
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import ghidra
from lego_util.headers import * # pylint: disable=wildcard-import # these are just for headers
logger = logging.getLogger(__name__)
def reload_module(module: str):
"""
Due to a a quirk in Jep (used by Ghidrathon), imported modules persist for the lifetime of the Ghidra process
and are not reloaded when relaunching the script. Therefore, in order to facilitate development
we force reload all our own modules at startup. See also https://github.com/mandiant/Ghidrathon/issues/103.
Note that as of 2024-05-30, this remedy does not work perfectly (yet): Some changes in isledecomp are
still not detected correctly and require a Ghidra restart to be applied.
"""
importlib.reload(importlib.import_module(module))
reload_module("lego_util.statistics")
from lego_util.statistics import Statistics
@dataclass
class Globals:
verbose: bool
loglevel: int
running_from_ghidra: bool = False
# statistics
statistics: Statistics = field(default_factory=Statistics)
# hard-coded settings that we don't want to prompt in Ghidra every time
GLOBALS = Globals(
verbose=False,
# loglevel=logging.INFO,
loglevel=logging.DEBUG,
)
def setup_logging():
logging.root.handlers.clear()
formatter = logging.Formatter("%(levelname)-8s %(message)s")
# formatter = logging.Formatter("%(name)s %(levelname)-8s %(message)s") # use this to identify loggers
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(formatter)
file_handler = logging.FileHandler(
Path(__file__).absolute().parent.joinpath("import.log"), mode="w"
)
file_handler.setFormatter(formatter)
logging.root.setLevel(GLOBALS.loglevel)
logging.root.addHandler(stdout_handler)
logging.root.addHandler(file_handler)
logger.info("Starting import...")
# This script can be run both from Ghidra and as a standalone.
# In the latter case, only the PDB parser will be used.
setup_logging()
try:
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.util.exception import CancelledException
GLOBALS.running_from_ghidra = True
except ImportError as importError:
logger.error(
"Failed to import Ghidra functions, doing a dry run for the source code parser. "
"Has this script been launched from Ghidra?"
)
logger.debug("Precise import error:", exc_info=importError)
GLOBALS.running_from_ghidra = False
CancelledException = None
def get_repository_root():
return Path(__file__).absolute().parent.parent.parent
def add_python_path(path: str):
"""
Scripts in Ghidra are executed from the tools/ghidra_scripts directory. We need to add
a few more paths to the Python path so we can import the other libraries.
"""
venv_path = get_repository_root().joinpath(path)
logger.info("Adding %s to Python Path", venv_path)
assert venv_path.exists()
sys.path.insert(1, str(venv_path))
# We need to quote the types here because they might not exist when running without Ghidra
def import_function_into_ghidra(
api: "FlatProgramAPI",
match_info: "MatchInfo",
signature: "FunctionSignature",
type_importer: "PdbTypeImporter",
):
hex_original_address = f"{match_info.orig_addr:x}"
# Find the Ghidra function at that address
ghidra_address = getAddressFactory().getAddress(hex_original_address)
# pylint: disable=possibly-used-before-assignment
function_importer = PdbFunctionImporter(api, match_info, signature, type_importer)
ghidra_function = getFunctionAt(ghidra_address)
if ghidra_function is None:
ghidra_function = createFunction(ghidra_address, "temp")
assert (
ghidra_function is not None
), f"Failed to create function at {ghidra_address}"
logger.info("Created new function at %s", ghidra_address)
logger.debug("Start handling function '%s'", function_importer.get_full_name())
if function_importer.matches_ghidra_function(ghidra_function):
logger.info(
"Skipping function '%s', matches already",
function_importer.get_full_name(),
)
return
logger.debug(
"Modifying function %s at 0x%s",
function_importer.get_full_name(),
hex_original_address,
)
function_importer.overwrite_ghidra_function(ghidra_function)
GLOBALS.statistics.functions_changed += 1
def process_functions(extraction: "PdbFunctionExtractor"):
func_signatures = extraction.get_function_list()
if not GLOBALS.running_from_ghidra:
logger.info("Completed the dry run outside Ghidra.")
return
api = FlatProgramAPI(currentProgram())
# pylint: disable=possibly-used-before-assignment
type_importer = PdbTypeImporter(api, extraction)
for match_info, signature in func_signatures:
try:
import_function_into_ghidra(api, match_info, signature, type_importer)
GLOBALS.statistics.successes += 1
except Lego1Exception as e:
log_and_track_failure(match_info.name, e)
except RuntimeError as e:
cause = e.args[0]
if CancelledException is not None and isinstance(cause, CancelledException):
# let Ghidra's CancelledException pass through
logging.critical("Import aborted by the user.")
return
log_and_track_failure(match_info.name, cause, unexpected=True)
logger.error(traceback.format_exc())
except Exception as e: # pylint: disable=broad-exception-caught
log_and_track_failure(match_info.name, e, unexpected=True)
logger.error(traceback.format_exc())
def log_and_track_failure(
function_name: Optional[str], error: Exception, unexpected: bool = False
):
if GLOBALS.statistics.track_failure_and_tell_if_new(error):
logger.error(
"%s(): %s%s",
function_name,
"Unexpected error: " if unexpected else "",
error,
)
def main():
repo_root = get_repository_root()
origfile_path = repo_root.joinpath("LEGO1.DLL")
build_path = repo_root.joinpath("build")
recompiledfile_path = build_path.joinpath("LEGO1.DLL")
pdb_path = build_path.joinpath("LEGO1.pdb")
if not GLOBALS.verbose:
logging.getLogger("isledecomp.bin").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.core").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.db").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.WARNING)
logging.getLogger("isledecomp.cvdump.symbols").setLevel(logging.WARNING)
logger.info("Starting comparison")
with Bin(str(origfile_path), find_str=True) as origfile, Bin(
str(recompiledfile_path)
) as recompfile:
isle_compare = IsleCompare(origfile, recompfile, str(pdb_path), str(repo_root))
logger.info("Comparison complete.")
# try to acquire matched functions
migration = PdbFunctionExtractor(isle_compare)
try:
process_functions(migration)
finally:
if GLOBALS.running_from_ghidra:
GLOBALS.statistics.log()
logger.info("Done")
# sys.path is not reset after running the script, so we should restore it
sys_path_backup = sys.path.copy()
try:
# make modules installed in the venv available in Ghidra
add_python_path(".venv/Lib/site-packages")
# This one is needed when isledecomp is installed in editable mode in the venv
add_python_path("tools/isledecomp")
import setuptools # pylint: disable=unused-import # required to fix a distutils issue in Python 3.12
reload_module("isledecomp")
from isledecomp import Bin
reload_module("isledecomp.compare")
from isledecomp.compare import Compare as IsleCompare
reload_module("isledecomp.compare.db")
from isledecomp.compare.db import MatchInfo
reload_module("lego_util.exceptions")
from lego_util.exceptions import Lego1Exception
reload_module("lego_util.pdb_extraction")
from lego_util.pdb_extraction import (
PdbFunctionExtractor,
FunctionSignature,
)
if GLOBALS.running_from_ghidra:
reload_module("lego_util.ghidra_helper")
reload_module("lego_util.function_importer")
from lego_util.function_importer import PdbFunctionImporter
reload_module("lego_util.type_importer")
from lego_util.type_importer import PdbTypeImporter
if __name__ == "__main__":
main()
finally:
sys.path = sys_path_backup

View file

@ -0,0 +1,47 @@
class Lego1Exception(Exception):
"""
Our own base class for exceptions.
Makes it easier to distinguish expected and unexpected errors.
"""
class TypeNotFoundError(Lego1Exception):
def __str__(self):
return f"Type not found in PDB: {self.args[0]}"
class TypeNotFoundInGhidraError(Lego1Exception):
def __str__(self):
return f"Type not found in Ghidra: {self.args[0]}"
class TypeNotImplementedError(Lego1Exception):
def __str__(self):
return f"Import not implemented for type: {self.args[0]}"
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy)
def get_namespace_str(self) -> str:
return "::".join(self.args[0])
def __str__(self):
return f"Class or namespace not found in Ghidra: {self.get_namespace_str()}"
class MultipleTypesFoundInGhidraError(Lego1Exception):
def __str__(self):
return (
f"Found multiple types matching '{self.args[0]}' in Ghidra: {self.args[1]}"
)
class StackOffsetMismatchError(Lego1Exception):
pass
class StructModificationError(Lego1Exception):
def __str__(self):
return f"Failed to modify struct in Ghidra: '{self.args[0]}'\nDetailed error: {self.__cause__}"

View file

@ -0,0 +1,241 @@
# This file can only be imported successfully when run from Ghidra using Ghidrathon.
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
import logging
from typing import Optional
from ghidra.program.model.listing import Function, Parameter
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.listing import ParameterImpl
from ghidra.program.model.symbol import SourceType
from isledecomp.compare.db import MatchInfo
from lego_util.pdb_extraction import (
FunctionSignature,
CppRegisterSymbol,
CppStackSymbol,
)
from lego_util.ghidra_helper import (
get_ghidra_namespace,
sanitize_name,
)
from lego_util.exceptions import StackOffsetMismatchError
from lego_util.type_importer import PdbTypeImporter
logger = logging.getLogger(__name__)
# pylint: disable=too-many-instance-attributes
class PdbFunctionImporter:
"""A representation of a function from the PDB with each type replaced by a Ghidra type instance."""
def __init__(
self,
api: FlatProgramAPI,
match_info: MatchInfo,
signature: FunctionSignature,
type_importer: "PdbTypeImporter",
):
self.api = api
self.match_info = match_info
self.signature = signature
self.type_importer = type_importer
if signature.class_type is not None:
# Import the base class so the namespace exists
self.type_importer.import_pdb_type_into_ghidra(signature.class_type)
assert match_info.name is not None
colon_split = sanitize_name(match_info.name).split("::")
self.name = colon_split.pop()
namespace_hierachy = colon_split
self.namespace = get_ghidra_namespace(api, namespace_hierachy)
self.return_type = type_importer.import_pdb_type_into_ghidra(
signature.return_type
)
self.arguments = [
ParameterImpl(
f"param{index}",
type_importer.import_pdb_type_into_ghidra(type_name),
api.getCurrentProgram(),
)
for (index, type_name) in enumerate(signature.arglist)
]
@property
def call_type(self):
return self.signature.call_type
@property
def stack_symbols(self):
return self.signature.stack_symbols
def get_full_name(self) -> str:
return f"{self.namespace.getName()}::{self.name}"
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
"""Checks whether this function declaration already matches the description in Ghidra"""
name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace()
return_type_match = self.return_type == ghidra_function.getReturnType()
# match arguments: decide if thiscall or not
thiscall_matches = (
self.signature.call_type == ghidra_function.getCallingConventionName()
)
if thiscall_matches:
if self.signature.call_type == "__thiscall":
args_match = self._matches_thiscall_parameters(ghidra_function)
else:
args_match = self._matches_non_thiscall_parameters(ghidra_function)
else:
args_match = False
logger.debug(
"Matches: namespace=%s name=%s return_type=%s thiscall=%s args=%s",
namespace_match,
name_match,
return_type_match,
thiscall_matches,
args_match,
)
return (
name_match
and namespace_match
and return_type_match
and thiscall_matches
and args_match
)
def _matches_non_thiscall_parameters(self, ghidra_function: Function) -> bool:
return self._parameter_lists_match(ghidra_function.getParameters())
def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool:
ghidra_params = list(ghidra_function.getParameters())
# remove the `this` argument which we don't generate ourselves
ghidra_params.pop(0)
return self._parameter_lists_match(ghidra_params)
def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
if len(self.arguments) != len(ghidra_params):
logger.info("Mismatching argument count")
return False
for this_arg, ghidra_arg in zip(self.arguments, ghidra_params):
# compare argument types
if this_arg.getDataType() != ghidra_arg.getDataType():
logger.debug(
"Mismatching arg type: expected %s, found %s",
this_arg.getDataType(),
ghidra_arg.getDataType(),
)
return False
# compare argument names
stack_match = self.get_matching_stack_symbol(ghidra_arg.getStackOffset())
if stack_match is None:
logger.debug("Not found on stack: %s", ghidra_arg)
return False
# "__formal" is the placeholder for arguments without a name
if (
stack_match.name != ghidra_arg.getName()
and not stack_match.name.startswith("__formal")
):
logger.debug(
"Argument name mismatch: expected %s, found %s",
stack_match.name,
ghidra_arg.getName(),
)
return False
return True
def overwrite_ghidra_function(self, ghidra_function: Function):
"""Replace the function declaration in Ghidra by the one derived from C++."""
ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace)
ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED)
ghidra_function.setCallingConvention(self.call_type)
ghidra_function.replaceParameters(
Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS,
True,
SourceType.USER_DEFINED,
self.arguments,
)
# When we set the parameters, Ghidra will generate the layout.
# Now we read them again and match them against the stack layout in the PDB,
# both to verify and to set the parameter names.
ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names
for index, param in enumerate(ghidra_parameters):
if param.isStackVariable():
self._rename_stack_parameter(index, param)
else:
if param.getName() == "this":
# 'this' parameters are auto-generated and cannot be changed
continue
# Appears to never happen - could in theory be relevant to __fastcall__ functions,
# which we haven't seen yet
logger.warning("Unhandled register variable in %s", self.get_full_name)
continue
def _rename_stack_parameter(self, index: int, param: Parameter):
match = self.get_matching_stack_symbol(param.getStackOffset())
if match is None:
raise StackOffsetMismatchError(
f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}"
)
if match.data_type == "T_NOTYPE(0000)":
logger.warning("Skipping stack parameter of type NOTYPE")
return
if param.getDataType() != self.type_importer.import_pdb_type_into_ghidra(
match.data_type
):
logger.error(
"Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match
)
return
name = match.name
if name == "__formal":
# these can cause name collisions if multiple ones are present
name = f"__formal_{index}"
param.setName(name, SourceType.USER_DEFINED)
def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
return next(
(
symbol
for symbol in self.stack_symbols
if isinstance(symbol, CppStackSymbol)
and symbol.stack_offset == stack_offset
),
None,
)
def get_matching_register_symbol(
self, register: str
) -> Optional[CppRegisterSymbol]:
return next(
(
symbol
for symbol in self.stack_symbols
if isinstance(symbol, CppRegisterSymbol) and symbol.register == register
),
None,
)

View file

@ -0,0 +1,100 @@
"""A collection of helper functions for the interaction with Ghidra."""
import logging
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundInGhidraError,
MultipleTypesFoundInGhidraError,
)
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
from ghidra.program.model.data import PointerDataType
from ghidra.program.model.data import DataTypeConflictHandler
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import DataType
from ghidra.program.model.symbol import Namespace
logger = logging.getLogger(__name__)
def get_ghidra_type(api: FlatProgramAPI, type_name: str):
"""
Searches for the type named `typeName` in Ghidra.
Raises:
- NotFoundInGhidraError
- MultipleTypesFoundInGhidraError
"""
result = api.getDataTypes(type_name)
if len(result) == 0:
raise TypeNotFoundInGhidraError(type_name)
if len(result) == 1:
return result[0]
raise MultipleTypesFoundInGhidraError(type_name, result)
def add_pointer_type(api: FlatProgramAPI, pointee: DataType) -> DataType:
new_data_type = PointerDataType(pointee)
new_data_type.setCategoryPath(pointee.getCategoryPath())
result_data_type = (
api.getCurrentProgram()
.getDataTypeManager()
.addDataType(new_data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
if result_data_type is not new_data_type:
logger.debug(
"New pointer replaced by existing one. Fresh pointer: %s (class: %s)",
result_data_type,
result_data_type.__class__,
)
return result_data_type
def get_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy)
return namespace
def create_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
namespace = api.createNamespace(namespace, part)
return namespace
def sanitize_name(name: str) -> str:
"""
Takes a full class or function name and replaces characters not accepted by Ghidra.
Applies mostly to templates and names like `vbase destructor`.
"""
new_class_name = (
name.replace("<", "[")
.replace(">", "]")
.replace("*", "#")
.replace(" ", "_")
.replace("`", "'")
)
if "<" in name:
new_class_name = "_template_" + new_class_name
if new_class_name != name:
logger.warning(
"Class or function name contains characters forbidden by Ghidra, changing from '%s' to '%s'",
name,
new_class_name,
)
return new_class_name

View file

@ -0,0 +1,19 @@
from typing import TypeVar
import ghidra
# pylint: disable=invalid-name,unused-argument
T = TypeVar("T")
# from ghidra.app.script.GhidraScript
def currentProgram() -> "ghidra.program.model.listing.Program": ...
def getAddressFactory() -> " ghidra.program.model.address.AddressFactory": ...
def state() -> "ghidra.app.script.GhidraState": ...
def askChoice(title: str, message: str, choices: list[T], defaultValue: T) -> T: ...
def askYesNo(title: str, question: str) -> bool: ...
def getFunctionAt(
entryPoint: ghidra.program.model.address.Address,
) -> ghidra.program.model.listing.Function: ...
def createFunction(
entryPoint: ghidra.program.model.address.Address, name: str
) -> ghidra.program.model.listing.Function: ...

View file

@ -0,0 +1,166 @@
from dataclasses import dataclass
import re
from typing import Any, Optional
import logging
from isledecomp.cvdump.symbols import SymbolsEntry
from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.db import MatchInfo
logger = logging.getLogger(__file__)
@dataclass
class CppStackOrRegisterSymbol:
name: str
data_type: str
@dataclass
class CppStackSymbol(CppStackOrRegisterSymbol):
stack_offset: int
"""Should have a value iff `symbol_type=='S_BPREL32'."""
@dataclass
class CppRegisterSymbol(CppStackOrRegisterSymbol):
register: str
"""Should have a value iff `symbol_type=='S_REGISTER'.` Should always be set/converted to lowercase."""
@dataclass
class FunctionSignature:
original_function_symbol: SymbolsEntry
call_type: str
arglist: list[str]
return_type: str
class_type: Optional[str]
stack_symbols: list[CppStackOrRegisterSymbol]
class PdbFunctionExtractor:
"""
Extracts all information on a given function from the parsed PDB
and prepares the data for the import in Ghidra.
"""
def __init__(self, compare: IsleCompare):
self.compare = compare
scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?")
_call_type_map = {
"ThisCall": "__thiscall",
"C Near": "__thiscall",
"STD Near": "__stdcall",
}
def _get_cvdump_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return (
None
if type_name is None
else self.compare.cv.types.keys.get(type_name.lower())
)
def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
function_type_str = fn.func_type
if function_type_str == "T_NOTYPE(0000)":
logger.debug(
"Skipping a NOTYPE (synthetic or template + synthetic): %s", fn.name
)
return None
# get corresponding function type
function_type = self.compare.cv.types.keys.get(function_type_str.lower())
if function_type is None:
logger.error(
"Could not find function type %s for function %s", fn.func_type, fn.name
)
return None
class_type = function_type.get("class_type")
arg_list_type = self._get_cvdump_type(function_type.get("arg_list_type"))
assert arg_list_type is not None
arg_list_pdb_types = arg_list_type.get("args", [])
assert arg_list_type["argcount"] == len(arg_list_pdb_types)
stack_symbols: list[CppStackOrRegisterSymbol] = []
# for some unexplained reason, the reported stack is offset by 4 when this flag is set
stack_offset_delta = -4 if fn.frame_pointer_present else 0
for symbol in fn.stack_symbols:
if symbol.symbol_type == "S_REGISTER":
stack_symbols.append(
CppRegisterSymbol(
symbol.name,
symbol.data_type,
symbol.location,
)
)
elif symbol.symbol_type == "S_BPREL32":
stack_offset = int(symbol.location[1:-1], 16)
stack_symbols.append(
CppStackSymbol(
symbol.name,
symbol.data_type,
stack_offset + stack_offset_delta,
)
)
call_type = self._call_type_map[function_type["call_type"]]
return FunctionSignature(
original_function_symbol=fn,
call_type=call_type,
arglist=arg_list_pdb_types,
return_type=function_type["return_type"],
class_type=class_type,
stack_symbols=stack_symbols,
)
def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]:
handled = (
self.handle_matched_function(match)
for match in self.compare.get_functions()
)
return [signature for signature in handled if signature is not None]
def handle_matched_function(
self, match_info: MatchInfo
) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None
match_options = self.compare.get_match_options(match_info.orig_addr)
assert match_options is not None
if match_options.get("skip", False) or match_options.get("stub", False):
return None
function_data = next(
(
y
for y in self.compare.cvdump_analysis.nodes
if y.addr == match_info.recomp_addr
),
None,
)
if not function_data:
logger.error(
"Did not find function in nodes, skipping: %s", match_info.name
)
return None
function_symbol = function_data.symbol_entry
if function_symbol is None:
logger.debug(
"Could not find function symbol (likely a PUBLICS entry): %s",
match_info.name,
)
return None
function_signature = self.get_func_signature(function_symbol)
if function_signature is None:
return None
return match_info, function_signature

View file

@ -0,0 +1,68 @@
from dataclasses import dataclass, field
import logging
from lego_util.exceptions import (
TypeNotFoundInGhidraError,
ClassOrNamespaceNotFoundInGhidraError,
)
logger = logging.getLogger(__name__)
@dataclass
class Statistics:
functions_changed: int = 0
successes: int = 0
failures: dict[str, int] = field(default_factory=dict)
known_missing_types: dict[str, int] = field(default_factory=dict)
known_missing_namespaces: dict[str, int] = field(default_factory=dict)
def track_failure_and_tell_if_new(self, error: Exception) -> bool:
"""
Adds the error to the statistics. Returns `False` if logging the error would be redundant
(e.g. because it is a `TypeNotFoundInGhidraError` with a type that has been logged before).
"""
error_type_name = error.__class__.__name__
self.failures[error_type_name] = (
self.failures.setdefault(error_type_name, 0) + 1
)
if isinstance(error, TypeNotFoundInGhidraError):
return self._add_occurence_and_check_if_new(
self.known_missing_types, error.args[0]
)
if isinstance(error, ClassOrNamespaceNotFoundInGhidraError):
return self._add_occurence_and_check_if_new(
self.known_missing_namespaces, error.get_namespace_str()
)
# We do not have detailed tracking for other errors, so we want to log them every time
return True
def _add_occurence_and_check_if_new(self, target: dict[str, int], key: str) -> bool:
old_count = target.setdefault(key, 0)
target[key] = old_count + 1
return old_count == 0
def log(self):
logger.info("Statistics:\n~~~~~")
logger.info(
"Missing types (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_types),
)
logger.info(
"Missing classes/namespaces (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_namespaces),
)
logger.info("Successes: %d", self.successes)
logger.info("Failures: %s", self.failures)
logger.info("Functions changed: %d", self.functions_changed)
def format_statistics(self, stats: dict[str, int]) -> str:
if len(stats) == 0:
return "<none>"
return ", ".join(
f"{entry[0]} ({entry[1]})"
for entry in sorted(stats.items(), key=lambda x: x[1], reverse=True)
)

View file

@ -0,0 +1,313 @@
import logging
from typing import Any
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
# pylint: disable=too-many-return-statements # a `match` would be better, but for now we are stuck with Python 3.9
# pylint: disable=no-else-return # Not sure why this rule even is a thing, this is great for checking exhaustiveness
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundError,
TypeNotFoundInGhidraError,
TypeNotImplementedError,
StructModificationError,
)
from lego_util.ghidra_helper import (
add_pointer_type,
create_ghidra_namespace,
get_ghidra_namespace,
get_ghidra_type,
sanitize_name,
)
from lego_util.pdb_extraction import PdbFunctionExtractor
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import (
ArrayDataType,
CategoryPath,
DataType,
DataTypeConflictHandler,
EnumDataType,
StructureDataType,
StructureInternal,
)
from ghidra.util.task import ConsoleTaskMonitor
logger = logging.getLogger(__name__)
class PdbTypeImporter:
"""Allows PDB types to be imported into Ghidra."""
def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor):
self.api = api
self.extraction = extraction
# tracks the structs/classes we have already started to import, otherwise we run into infinite recursion
self.handled_structs: set[str] = set()
self.struct_call_stack: list[str] = []
@property
def types(self):
return self.extraction.compare.cv.types
def import_pdb_type_into_ghidra(self, type_index: str) -> DataType:
"""
Recursively imports a type from the PDB into Ghidra.
@param type_index Either a scalar type like `T_INT4(...)` or a PDB reference like `0x10ba`
"""
type_index_lower = type_index.lower()
if type_index_lower.startswith("t_"):
return self._import_scalar_type(type_index_lower)
try:
type_pdb = self.extraction.compare.cv.types.keys[type_index_lower]
except KeyError as e:
raise TypeNotFoundError(
f"Failed to find referenced type '{type_index_lower}'"
) from e
type_category = type_pdb["type"]
# follow forward reference (class, struct, union)
if type_pdb.get("is_forward_ref", False):
return self._import_forward_ref_type(type_index_lower, type_pdb)
if type_category == "LF_POINTER":
return add_pointer_type(
self.api, self.import_pdb_type_into_ghidra(type_pdb["element_type"])
)
elif type_category in ["LF_CLASS", "LF_STRUCTURE"]:
return self._import_class_or_struct(type_pdb)
elif type_category == "LF_ARRAY":
return self._import_array(type_pdb)
elif type_category == "LF_ENUM":
return self._import_enum(type_pdb)
elif type_category == "LF_PROCEDURE":
logger.warning(
"Not implemented: Function-valued argument or return type will be replaced by void pointer: %s",
type_pdb,
)
return get_ghidra_type(self.api, "void")
elif type_category == "LF_UNION":
return self._import_union(type_pdb)
else:
raise TypeNotImplementedError(type_pdb)
_scalar_type_map = {
"rchar": "char",
"int4": "int",
"uint4": "uint",
"real32": "float",
"real64": "double",
}
def _scalar_type_to_cpp(self, scalar_type: str) -> str:
if scalar_type.startswith("32p"):
return f"{self._scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type)
def _import_scalar_type(self, type_index_lower: str) -> DataType:
if (match := self.extraction.scalar_type_regex.match(type_index_lower)) is None:
raise TypeNotFoundError(f"Type has unexpected format: {type_index_lower}")
scalar_cpp_type = self._scalar_type_to_cpp(match.group("typename"))
return get_ghidra_type(self.api, scalar_cpp_type)
def _import_forward_ref_type(
self, type_index, type_pdb: dict[str, Any]
) -> DataType:
referenced_type = type_pdb.get("udt") or type_pdb.get("modifies")
if referenced_type is None:
try:
# Example: HWND__, needs to be created manually
return get_ghidra_type(self.api, type_pdb["name"])
except TypeNotFoundInGhidraError as e:
raise TypeNotImplementedError(
f"{type_index}: forward ref without target, needs to be created manually: {type_pdb}"
) from e
logger.debug(
"Following forward reference from %s to %s",
type_index,
referenced_type,
)
return self.import_pdb_type_into_ghidra(referenced_type)
def _import_array(self, type_pdb: dict[str, Any]) -> DataType:
inner_type = self.import_pdb_type_into_ghidra(type_pdb["array_type"])
array_total_bytes: int = type_pdb["size"]
data_type_size = inner_type.getLength()
array_length, modulus = divmod(array_total_bytes, data_type_size)
assert (
modulus == 0
), f"Data type size {data_type_size} does not divide array size {array_total_bytes}"
return ArrayDataType(inner_type, array_length, 0)
def _import_union(self, type_pdb: dict[str, Any]) -> DataType:
try:
logger.debug("Dereferencing union %s", type_pdb)
union_type = get_ghidra_type(self.api, type_pdb["name"])
assert (
union_type.getLength() == type_pdb["size"]
), f"Wrong size of existing union type '{type_pdb['name']}': expected {type_pdb['size']}, got {union_type.getLength()}"
return union_type
except TypeNotFoundInGhidraError as e:
# We have so few instances, it is not worth implementing this
raise TypeNotImplementedError(
f"Writing union types is not supported. Please add by hand: {type_pdb}"
) from e
def _import_enum(self, type_pdb: dict[str, Any]) -> DataType:
underlying_type = self.import_pdb_type_into_ghidra(type_pdb["underlying_type"])
field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"])
assert field_list is not None, f"Failed to find field list for enum {type_pdb}"
result = EnumDataType(
CategoryPath("/imported"), type_pdb["name"], underlying_type.getLength()
)
variants: list[dict[str, Any]] = field_list["variants"]
for variant in variants:
result.add(variant["name"], variant["value"])
return result
def _import_class_or_struct(self, type_in_pdb: dict[str, Any]) -> DataType:
field_list_type: str = type_in_pdb["field_list_type"]
field_list = self.types.keys[field_list_type.lower()]
class_size: int = type_in_pdb["size"]
class_name_with_namespace: str = sanitize_name(type_in_pdb["name"])
if class_name_with_namespace in self.handled_structs:
logger.debug(
"Class has been handled or is being handled: %s",
class_name_with_namespace,
)
return get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"--- Beginning to import class/struct '%s'", class_name_with_namespace
)
# Add as soon as we start to avoid infinite recursion
self.handled_structs.add(class_name_with_namespace)
self._get_or_create_namespace(class_name_with_namespace)
data_type = self._get_or_create_struct_data_type(
class_name_with_namespace, class_size
)
if (old_size := data_type.getLength()) != class_size:
logger.warning(
"Existing class %s had incorrect size %d. Setting to %d...",
class_name_with_namespace,
old_size,
class_size,
)
logger.info("Adding class data type %s", class_name_with_namespace)
logger.debug("Class information: %s", type_in_pdb)
data_type.deleteAll()
data_type.growStructure(class_size)
# this case happened e.g. for IUnknown, which linked to an (incorrect) existing library, and some other types as well.
# Unfortunately, we don't get proper error handling for read-only types.
# However, we really do NOT want to do this every time because the type might be self-referential and partially imported.
if data_type.getLength() != class_size:
data_type = self._delete_and_recreate_struct_data_type(
class_name_with_namespace, class_size, data_type
)
# can be missing when no new fields are declared
components: list[dict[str, Any]] = field_list.get("members") or []
super_type = field_list.get("super")
if super_type is not None:
components.insert(0, {"type": super_type, "offset": 0, "name": "base"})
for component in components:
ghidra_type = self.import_pdb_type_into_ghidra(component["type"])
logger.debug("Adding component to class: %s", component)
try:
# for better logs
data_type.replaceAtOffset(
component["offset"], ghidra_type, -1, component["name"], None
)
except Exception as e:
raise StructModificationError(type_in_pdb) from e
logger.info("Finished importing class %s", class_name_with_namespace)
return data_type
def _get_or_create_namespace(self, class_name_with_namespace: str):
colon_split = class_name_with_namespace.split("::")
class_name = colon_split[-1]
try:
get_ghidra_namespace(self.api, colon_split)
logger.debug("Found existing class/namespace %s", class_name_with_namespace)
except ClassOrNamespaceNotFoundInGhidraError:
logger.info("Creating class/namespace %s", class_name_with_namespace)
class_name = colon_split.pop()
parent_namespace = create_ghidra_namespace(self.api, colon_split)
self.api.createClass(parent_namespace, class_name)
def _get_or_create_struct_data_type(
self, class_name_with_namespace: str, class_size: int
) -> StructureInternal:
try:
data_type = get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"Found existing data type %s under category path %s",
class_name_with_namespace,
data_type.getCategoryPath(),
)
except TypeNotFoundInGhidraError:
# Create a new struct data type
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
logger.info("Created new data type %s", class_name_with_namespace)
assert isinstance(
data_type, StructureInternal
), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}"
return data_type
def _delete_and_recreate_struct_data_type(
self,
class_name_with_namespace: str,
class_size: int,
existing_data_type: DataType,
) -> StructureInternal:
logger.warning(
"Failed to modify data type %s. Will try to delete the existing one and re-create the imported one.",
class_name_with_namespace,
)
assert (
self.api.getCurrentProgram()
.getDataTypeManager()
.remove(existing_data_type, ConsoleTaskMonitor())
), f"Failed to delete and re-create data type {class_name_with_namespace}"
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
assert isinstance(data_type, StructureInternal) # for type checking
return data_type

View file

@ -4,7 +4,7 @@
import struct
import uuid
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional
from typing import Any, Callable, Iterable, List, Optional
from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
@ -90,7 +90,7 @@ def __init__(
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
cv = (
self.cv = (
Cvdump(self.pdb_file)
.lines()
.globals()
@ -100,9 +100,9 @@ def _load_cvdump(self):
.types()
.run()
)
res = CvdumpAnalysis(cv)
self.cvdump_analysis = CvdumpAnalysis(self.cv)
for sym in res.nodes:
for sym in self.cvdump_analysis.nodes:
# Skip nodes where we have almost no information.
# These probably came from SECTION CONTRIBUTIONS.
if sym.name() is None and sym.node_type is None:
@ -116,6 +116,7 @@ def _load_cvdump(self):
continue
addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)
sym.addr = addr
# If this symbol is the final one in its section, we were not able to
# estimate its size because we didn't have the total size of that section.
@ -165,7 +166,10 @@ def _load_cvdump(self):
addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
)
for (section, offset), (filename, line_no) in res.verified_lines.items():
for (section, offset), (
filename,
line_no,
) in self.cvdump_analysis.verified_lines.items():
addr = self.recomp_bin.get_abs_addr(section, offset)
self._lines_db.add_line(filename, line_no, addr)
@ -736,6 +740,9 @@ def get_vtables(self) -> List[MatchInfo]:
def get_variables(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.DATA)
def get_match_options(self, addr: int) -> Optional[dict[str, Any]]:
return self._db.get_match_options(addr)
def compare_address(self, addr: int) -> Optional[DiffReport]:
match = self._db.get_one_match(addr)
if match is None:

View file

@ -2,7 +2,7 @@
addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3
import logging
from typing import List, Optional
from typing import Any, List, Optional
from isledecomp.types import SymbolType
from isledecomp.cvdump.demangler import get_vtordisp_name
@ -335,7 +335,7 @@ def mark_stub(self, orig: int):
def skip_compare(self, orig: int):
self._set_opt_bool(orig, "skip")
def get_match_options(self, addr: int) -> Optional[dict]:
def get_match_options(self, addr: int) -> Optional[dict[str, Any]]:
cur = self._db.execute(
"""SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,)
)

View file

@ -1,3 +1,4 @@
from .symbols import SymbolsEntry
from .analysis import CvdumpAnalysis
from .parser import CvdumpParser
from .runner import Cvdump

View file

@ -1,5 +1,7 @@
"""For collating the results from parsing cvdump.exe into a more directly useful format."""
from typing import Dict, List, Tuple, Optional
from isledecomp.cvdump import SymbolsEntry
from isledecomp.types import SymbolType
from .parser import CvdumpParser
from .demangler import demangle_string_const, demangle_vtable
@ -31,6 +33,8 @@ class CvdumpNode:
# Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be
# accurate.
section_contribution: Optional[int] = None
addr: Optional[int] = None
symbol_entry: Optional[SymbolsEntry] = None
def __init__(self, section: int, offset: int) -> None:
self.section = section
@ -87,13 +91,12 @@ class CvdumpAnalysis:
"""Collects the results from CvdumpParser into a list of nodes (i.e. symbols).
These can then be analyzed by a downstream tool."""
nodes = List[CvdumpNode]
verified_lines = Dict[Tuple[str, str], Tuple[str, str]]
verified_lines: Dict[Tuple[str, str], Tuple[str, str]]
def __init__(self, parser: CvdumpParser):
"""Read in as much information as we have from the parser.
The more sections we have, the better our information will be."""
node_dict = {}
node_dict: Dict[Tuple[int, int], CvdumpNode] = {}
# PUBLICS is our roadmap for everything that follows.
for pub in parser.publics:
@ -158,8 +161,11 @@ def __init__(self, parser: CvdumpParser):
node_dict[key].friendly_name = sym.name
node_dict[key].confirmed_size = sym.size
node_dict[key].node_type = SymbolType.FUNCTION
node_dict[key].symbol_entry = sym
self.nodes = [v for _, v in dict(sorted(node_dict.items())).items()]
self.nodes: List[CvdumpNode] = [
v for _, v in dict(sorted(node_dict.items())).items()
]
self._estimate_size()
def _estimate_size(self):

View file

@ -2,6 +2,7 @@
from typing import Iterable, Tuple
from collections import namedtuple
from .types import CvdumpTypesParser
from .symbols import CvdumpSymbolsParser
# e.g. `*** PUBLICS`
_section_change_regex = re.compile(r"\*\*\* (?P<section>[A-Z/ ]{2,})")
@ -20,11 +21,6 @@
r"^(?P<type>\w+): \[(?P<section>\w{4}):(?P<offset>\w{8})], Flags: (?P<flags>\w{8}), (?P<name>\S+)"
)
# e.g. `(00008C) S_GPROC32: [0001:00034E90], Cb: 00000007, Type: 0x1024, ViewROI::IntrinsicImportance`
_symbol_line_regex = re.compile(
r"\(\w+\) (?P<type>\S+): \[(?P<section>\w{4}):(?P<offset>\w{8})\], Cb: (?P<size>\w+), Type:\s+\S+, (?P<name>.+)"
)
# e.g. ` Debug start: 00000008, Debug end: 0000016E`
_gproc_debug_regex = re.compile(
r"\s*Debug start: (?P<start>\w{8}), Debug end: (?P<end>\w{8})"
@ -52,9 +48,6 @@
# only place you can find the C symbols (library functions, smacker, etc)
PublicsEntry = namedtuple("PublicsEntry", "type section offset flags name")
# S_GPROC32 = functions
SymbolsEntry = namedtuple("SymbolsEntry", "type section offset size name")
# (Estimated) size of any symbol
SizeRefEntry = namedtuple("SizeRefEntry", "module section offset size")
@ -72,12 +65,16 @@ def __init__(self) -> None:
self.lines = {}
self.publics = []
self.symbols = []
self.sizerefs = []
self.globals = []
self.modules = []
self.types = CvdumpTypesParser()
self.symbols_parser = CvdumpSymbolsParser()
@property
def symbols(self):
return self.symbols_parser.symbols
def _lines_section(self, line: str):
"""Parsing entries from the LINES section. We only care about the pairs of
@ -127,20 +124,6 @@ def _globals_section(self, line: str):
)
)
def _symbols_section(self, line: str):
"""We are interested in S_GPROC32 symbols only."""
if (match := _symbol_line_regex.match(line)) is not None:
if match.group("type") == "S_GPROC32":
self.symbols.append(
SymbolsEntry(
type=match.group("type"),
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
size=int(match.group("size"), 16),
name=match.group("name"),
)
)
def _section_contributions(self, line: str):
"""Gives the size of elements across all sections of the binary.
This is the easiest way to get the data size for .data and .rdata
@ -177,7 +160,7 @@ def read_line(self, line: str):
self.types.read_line(line)
elif self._section == "SYMBOLS":
self._symbols_section(line)
self.symbols_parser.read_line(line)
elif self._section == "LINES":
self._lines_section(line)

View file

@ -0,0 +1,153 @@
from dataclasses import dataclass, field
import logging
import re
from re import Match
from typing import NamedTuple, Optional
logger = logging.getLogger(__name__)
class StackOrRegisterSymbol(NamedTuple):
symbol_type: str
location: str
"""Should always be set/converted to lowercase."""
data_type: str
name: str
# S_GPROC32 = functions
@dataclass
class SymbolsEntry:
# pylint: disable=too-many-instance-attributes
type: str
section: int
offset: int
size: int
func_type: str
name: str
stack_symbols: list[StackOrRegisterSymbol] = field(default_factory=list)
frame_pointer_present: bool = False
addr: Optional[int] = None # Absolute address. Will be set later, if at all
class CvdumpSymbolsParser:
_symbol_line_generic_regex = re.compile(
r"\(\w+\)\s+(?P<symbol_type>[^\s:]+)(?::\s+(?P<second_part>\S.*))?|(?::)$"
)
"""
Parses the first part, e.g. `(00008C) S_GPROC32`, and splits off the second part after the colon (if it exists).
There are three cases:
- no colon, e.g. `(000350) S_END`
- colon but no data, e.g. `(000370) S_COMPILE:`
- colon and data, e.g. `(000304) S_REGISTER: esi, Type: 0x1E14, this``
"""
_symbol_line_function_regex = re.compile(
r"\[(?P<section>\w{4}):(?P<offset>\w{8})\], Cb: (?P<size>\w+), Type:\s+(?P<func_type>[^\s,]+), (?P<name>.+)"
)
"""
Parses the second part of a function symbol, e.g.
`[0001:00034E90], Cb: 00000007, Type: 0x1024, ViewROI::IntrinsicImportance`
"""
# the second part of e.g.
_stack_register_symbol_regex = re.compile(
r"(?P<location>\S+), Type:\s+(?P<data_type>[\w()]+), (?P<name>.+)$"
)
"""
Parses the second part of a stack or register symbol, e.g.
`esi, Type: 0x1E14, this`
"""
_debug_start_end_regex = re.compile(
r"^\s*Debug start: (?P<debug_start>\w+), Debug end: (?P<debug_end>\w+)$"
)
_parent_end_next_regex = re.compile(
r"\s*Parent: (?P<parent_addr>\w+), End: (?P<end_addr>\w+), Next: (?P<next_addr>\w+)$"
)
_flags_frame_pointer_regex = re.compile(r"\s*Flags: Frame Ptr Present$")
_register_stack_symbols = ["S_BPREL32", "S_REGISTER"]
# List the unhandled types so we can check exhaustiveness
_unhandled_symbols = [
"S_COMPILE",
"S_OBJNAME",
"S_THUNK32",
"S_LABEL32",
"S_LDATA32",
"S_LPROC32",
"S_UDT",
]
"""Parser for cvdump output, SYMBOLS section."""
def __init__(self):
self.symbols: list[SymbolsEntry] = []
self.current_function: Optional[SymbolsEntry] = None
def read_line(self, line: str):
if (match := self._symbol_line_generic_regex.match(line)) is not None:
self._parse_generic_case(line, match)
elif (match := self._parent_end_next_regex.match(line)) is not None:
# We do not need this info at the moment, might be useful in the future
pass
elif (match := self._debug_start_end_regex.match(line)) is not None:
# We do not need this info at the moment, might be useful in the future
pass
elif (match := self._flags_frame_pointer_regex.match(line)) is not None:
if self.current_function is None:
logger.error(
"Found a `Flags: Frame Ptr Present` but self.current_function is None"
)
return
self.current_function.frame_pointer_present = True
else:
# Most of these are either `** Module: [...]` or data we do not care about
logger.debug("Unhandled line: %s", line[:-1])
def _parse_generic_case(self, line, line_match: Match[str]):
symbol_type: str = line_match.group("symbol_type")
second_part: Optional[str] = line_match.group("second_part")
if symbol_type == "S_GPROC32":
assert second_part is not None
if (match := self._symbol_line_function_regex.match(second_part)) is None:
logger.error("Invalid function symbol: %s", line[:-1])
return
self.current_function = SymbolsEntry(
type=symbol_type,
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
size=int(match.group("size"), 16),
func_type=match.group("func_type"),
name=match.group("name"),
)
self.symbols.append(self.current_function)
elif symbol_type in self._register_stack_symbols:
assert second_part is not None
if self.current_function is None:
logger.error("Found stack/register outside of function: %s", line[:-1])
return
if (match := self._stack_register_symbol_regex.match(second_part)) is None:
logger.error("Invalid stack/register symbol: %s", line[:-1])
return
new_symbol = StackOrRegisterSymbol(
symbol_type=symbol_type,
location=match.group("location").lower(),
data_type=match.group("data_type"),
name=match.group("name"),
)
self.current_function.stack_symbols.append(new_symbol)
elif symbol_type == "S_END":
self.current_function = None
elif symbol_type in self._unhandled_symbols:
return
else:
logger.error("Unhandled symbol type: %s", line)

View file

@ -1,5 +1,9 @@
import re
from typing import Dict, List, NamedTuple, Optional
import logging
from typing import Any, Dict, List, NamedTuple, Optional
logger = logging.getLogger(__name__)
class CvdumpTypeError(Exception):
@ -42,7 +46,7 @@ def is_pointer(self) -> bool:
class TypeInfo(NamedTuple):
key: str
size: int
size: Optional[int]
name: Optional[str] = None
members: Optional[List[FieldListItem]] = None
@ -156,6 +160,10 @@ class CvdumpTypesParser:
# LF_FIELDLIST member name (2/2)
MEMBER_RE = re.compile(r"^\s+member name = '(?P<name>.*)'$")
LF_FIELDLIST_ENUMERATE = re.compile(
r"^\s+list\[\d+\] = LF_ENUMERATE,.*value = (?P<value>\d+), name = '(?P<name>[^']+)'$"
)
# LF_ARRAY element type
ARRAY_ELEMENT_RE = re.compile(r"^\s+Element type = (?P<type>.*)")
@ -169,12 +177,53 @@ class CvdumpTypesParser:
# LF_CLASS/LF_STRUCTURE name and other info
CLASS_NAME_RE = re.compile(
r"^\s+Size = (?P<size>\d+), class name = (?P<name>.+), UDT\((?P<udt>0x\w+)\)"
r"^\s+Size = (?P<size>\d+), class name = (?P<name>(?:[^,]|,\S)+)(?:, UDT\((?P<udt>0x\w+)\))?"
)
# LF_MODIFIER, type being modified
MODIFIES_RE = re.compile(r".*modifies type (?P<type>.*)$")
# LF_ARGLIST number of entries
LF_ARGLIST_ARGCOUNT = re.compile(r".*argument count = (?P<argcount>\d+)$")
# LF_ARGLIST list entry
LF_ARGLIST_ENTRY = re.compile(
r"^\s+list\[(?P<index>\d+)\] = (?P<arg_type>[\w()]+)$"
)
# LF_POINTER element
LF_POINTER_ELEMENT = re.compile(r"^\s+Element type : (?P<element_type>.+)$")
# LF_MFUNCTION attribute key-value pairs
LF_MFUNCTION_ATTRIBUTES = [
re.compile(r"\s*Return type = (?P<return_type>[\w()]+)$"),
re.compile(r"\s*Class type = (?P<class_type>[\w()]+)$"),
re.compile(r"\s*This type = (?P<this_type>[\w()]+)$"),
# Call type may contain whitespace
re.compile(r"\s*Call type = (?P<call_type>[\w()\s]+)$"),
re.compile(r"\s*Parms = (?P<num_params>[\w()]+)$"), # LF_MFUNCTION only
re.compile(r"\s*# Parms = (?P<num_params>[\w()]+)$"), # LF_PROCEDURE only
re.compile(r"\s*Arg list type = (?P<arg_list_type>[\w()]+)$"),
re.compile(
r"\s*This adjust = (?P<this_adjust>[\w()]+)$"
), # TODO: figure out the meaning
re.compile(
r"\s*Func attr = (?P<func_attr>[\w()]+)$"
), # Only for completeness, is always `none`
]
LF_ENUM_ATTRIBUTES = [
re.compile(r"^\s*# members = (?P<num_members>\d+)$"),
re.compile(r"^\s*enum name = (?P<name>.+)$"),
]
LF_ENUM_TYPES = re.compile(
r"^\s*type = (?P<underlying_type>\S+) field list type (?P<field_type>0x\w{4})$"
)
LF_ENUM_UDT = re.compile(r"^\s*UDT\((?P<udt>0x\w+)\)$")
LF_UNION_LINE = re.compile(
r"^.*field list type (?P<field_type>0x\w+),.*Size = (?P<size>\d+)\s*,class name = (?P<name>(?:[^,]|,\S)+),\s.*UDT\((?P<udt>0x\w+)\)$"
)
MODES_OF_INTEREST = {
"LF_ARRAY",
"LF_CLASS",
@ -183,12 +232,16 @@ class CvdumpTypesParser:
"LF_MODIFIER",
"LF_POINTER",
"LF_STRUCTURE",
"LF_ARGLIST",
"LF_MFUNCTION",
"LF_PROCEDURE",
"LF_UNION",
}
def __init__(self) -> None:
self.mode: Optional[str] = None
self.last_key = ""
self.keys = {}
self.keys: Dict[str, Dict[str, Any]] = {}
def _new_type(self):
"""Prepare a new dict for the type we just parsed.
@ -211,13 +264,20 @@ def _set_member_name(self, name: str):
obj = self.keys[self.last_key]
obj["members"][-1]["name"] = name
def _get_field_list(self, type_obj: Dict) -> List[FieldListItem]:
def _add_variant(self, name: str, value: int):
obj = self.keys[self.last_key]
if "variants" not in obj:
obj["variants"] = []
variants: list[dict[str, Any]] = obj["variants"]
variants.append({"name": name, "value": value})
def _get_field_list(self, type_obj: Dict[str, Any]) -> List[FieldListItem]:
"""Return the field list for the given LF_CLASS/LF_STRUCTURE reference"""
if type_obj.get("type") == "LF_FIELDLIST":
field_obj = type_obj
else:
field_list_type = type_obj.get("field_list_type")
field_list_type = type_obj["field_list_type"]
field_obj = self.keys[field_list_type]
members: List[FieldListItem] = []
@ -253,6 +313,9 @@ def _mock_array_members(self, type_obj: Dict) -> List[FieldListItem]:
raise CvdumpIntegrityError("No array element type")
array_element_size = self.get(array_type).size
assert (
array_element_size is not None
), "Encountered an array whose type has no size"
n_elements = type_obj["size"] // array_element_size
@ -285,7 +348,10 @@ def get(self, type_key: str) -> TypeInfo:
# These type references are just a wrapper around a scalar
if obj.get("type") == "LF_ENUM":
return self.get("T_INT4")
underlying_type = obj.get("underlying_type")
if underlying_type is None:
raise CvdumpKeyError(f"Missing 'underlying_type' in {obj}")
return self.get(underlying_type)
if obj.get("type") == "LF_POINTER":
return self.get("T_32PVOID")
@ -350,6 +416,9 @@ def get_scalars_gapless(self, type_key: str) -> List[ScalarType]:
obj = self.get(type_key)
total_size = obj.size
assert (
total_size is not None
), "Called get_scalar_gapless() on a type without size"
scalars = self.get_scalars(type_key)
@ -383,6 +452,11 @@ def get_format_string(self, type_key: str) -> str:
return member_list_to_struct_string(members)
def read_line(self, line: str):
if line.endswith("\n"):
line = line[:-1]
if len(line) == 0:
return
if (match := self.INDEX_RE.match(line)) is not None:
type_ = match.group(2)
if type_ not in self.MODES_OF_INTEREST:
@ -393,6 +467,12 @@ def read_line(self, line: str):
self.last_key = match.group(1)
self.mode = type_
self._new_type()
if type_ == "LF_ARGLIST":
submatch = self.LF_ARGLIST_ARGCOUNT.match(line)
assert submatch is not None
self.keys[self.last_key]["argcount"] = int(submatch.group("argcount"))
# TODO: This should be validated in another pass
return
if self.mode is None:
@ -413,6 +493,31 @@ def read_line(self, line: str):
self._set("size", int(match.group("length")))
elif self.mode == "LF_FIELDLIST":
self.read_fieldlist_line(line)
elif self.mode == "LF_ARGLIST":
self.read_arglist_line(line)
elif self.mode in ["LF_MFUNCTION", "LF_PROCEDURE"]:
self.read_mfunction_line(line)
elif self.mode in ["LF_CLASS", "LF_STRUCTURE"]:
self.read_class_or_struct_line(line)
elif self.mode == "LF_POINTER":
self.read_pointer_line(line)
elif self.mode == "LF_ENUM":
self.read_enum_line(line)
elif self.mode == "LF_UNION":
self.read_union_line(line)
else:
# Check for exhaustiveness
logger.error("Unhandled data in mode: %s", self.mode)
def read_fieldlist_line(self, line: str):
# If this class has a vtable, create a mock member at offset 0
if (match := self.VTABLE_RE.match(line)) is not None:
# For our purposes, any pointer type will do
@ -433,7 +538,10 @@ def read_line(self, line: str):
elif (match := self.MEMBER_RE.match(line)) is not None:
self._set_member_name(match.group("name"))
else: # LF_CLASS or LF_STRUCTURE
elif (match := self.LF_FIELDLIST_ENUMERATE.match(line)) is not None:
self._add_variant(match.group("name"), int(match.group("value")))
def read_class_or_struct_line(self, line: str):
# Match the reference to the associated LF_FIELDLIST
if (match := self.CLASS_FIELD_RE.match(line)) is not None:
if match.group("field_type") == "0x0000":
@ -444,10 +552,111 @@ def read_line(self, line: str):
field_list_type = normalize_type_id(match.group("field_type"))
self._set("field_list_type", field_list_type)
elif line.lstrip().startswith("Derivation list type"):
# We do not care about the second line, but we still match it so we see an error
# when another line fails to match
pass
elif (match := self.CLASS_NAME_RE.match(line)) is not None:
# Last line has the vital information.
# If this is a FORWARD REF, we need to follow the UDT pointer
# to get the actual class details.
elif (match := self.CLASS_NAME_RE.match(line)) is not None:
self._set("name", match.group("name"))
self._set("udt", normalize_type_id(match.group("udt")))
udt = match.group("udt")
if udt is not None:
self._set("udt", normalize_type_id(udt))
self._set("size", int(match.group("size")))
else:
logger.error("Unmatched line in class: %s", line[:-1])
def read_arglist_line(self, line: str):
if (match := self.LF_ARGLIST_ENTRY.match(line)) is not None:
obj = self.keys[self.last_key]
arglist: list = obj.setdefault("args", [])
assert int(match.group("index")) == len(
arglist
), "Argument list out of sync"
arglist.append(match.group("arg_type"))
else:
logger.error("Unmatched line in arglist: %s", line[:-1])
def read_pointer_line(self, line):
if (match := self.LF_POINTER_ELEMENT.match(line)) is not None:
self._set("element_type", match.group("element_type"))
else:
stripped_line = line.strip()
# We don't parse these lines, but we still want to check for exhaustiveness
# in case we missed some relevant data
if not any(
stripped_line.startswith(prefix)
for prefix in ["Pointer", "const Pointer", "L-value", "volatile"]
):
logger.error("Unrecognized pointer attribute: %s", line[:-1])
def read_mfunction_line(self, line: str):
"""
The layout is not consistent, so we want to be as robust as possible here.
- Example 1:
Return type = T_LONG(0012), Call type = C Near
Func attr = none
- Example 2:
Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B,
Call type = ThisCall, Func attr = none
"""
obj = self.keys[self.last_key]
key_value_pairs = line.split(",")
for pair in key_value_pairs:
if pair.isspace():
continue
obj |= self.parse_function_attribute(pair)
def parse_function_attribute(self, pair: str) -> dict[str, str]:
for attribute_regex in self.LF_MFUNCTION_ATTRIBUTES:
if (match := attribute_regex.match(pair)) is not None:
return match.groupdict()
logger.error("Unknown attribute in function: %s", pair)
return {}
def read_enum_line(self, line: str):
obj = self.keys[self.last_key]
# We need special comma handling because commas may appear in the name.
# Splitting by "," yields the wrong result.
enum_attributes = line.split(", ")
for pair in enum_attributes:
if pair.endswith(","):
pair = pair[:-1]
if pair.isspace():
continue
obj |= self.parse_enum_attribute(pair)
def parse_enum_attribute(self, attribute: str) -> dict[str, Any]:
for attribute_regex in self.LF_ENUM_ATTRIBUTES:
if (match := attribute_regex.match(attribute)) is not None:
return match.groupdict()
if attribute == "NESTED":
return {"is_nested": True}
if attribute == "FORWARD REF":
return {"is_forward_ref": True}
if attribute.startswith("UDT"):
match = self.LF_ENUM_UDT.match(attribute)
assert match is not None
return {"udt": normalize_type_id(match.group("udt"))}
if (match := self.LF_ENUM_TYPES.match(attribute)) is not None:
result = match.groupdict()
result["underlying_type"] = normalize_type_id(result["underlying_type"])
return result
logger.error("Unknown attribute in enum: %s", attribute)
return {}
def read_union_line(self, line: str):
"""This is a rather barebones handler, only parsing the size"""
if (match := self.LF_UNION_LINE.match(line)) is None:
raise AssertionError(f"Unhandled in union: {line}")
self._set("name", match.group("name"))
if match.group("field_type") == "0x0000":
self._set("is_forward_ref", True)
self._set("size", int(match.group("size")))
self._set("udt", normalize_type_id(match.group("udt")))

View file

@ -9,6 +9,21 @@
)
TEST_LINES = """
0x1018 : Length = 18, Leaf = 0x1201 LF_ARGLIST argument count = 3
list[0] = 0x100D
list[1] = 0x1016
list[2] = 0x1017
0x1019 : Length = 14, Leaf = 0x1008 LF_PROCEDURE
Return type = T_LONG(0012), Call type = C Near
Func attr = none
# Parms = 3, Arg list type = 0x1018
0x101e : Length = 26, Leaf = 0x1009 LF_MFUNCTION
Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B,
Call type = ThisCall, Func attr = none
Parms = 2, Arg list type = 0x101d, This adjust = 0
0x1028 : Length = 10, Leaf = 0x1001 LF_MODIFIER
const, modifies type T_REAL32(0040)
@ -123,6 +138,12 @@
length = 440
Name =
0x2339 : Length = 26, Leaf = 0x1506 LF_UNION
# members = 0, field list type 0x0000, FORWARD REF, Size = 0 ,class name = FlagBitfield, UDT(0x00002e85)
0x2e85 : Length = 26, Leaf = 0x1506 LF_UNION
# members = 8, field list type 0x2e84, Size = 1 ,class name = FlagBitfield, UDT(0x00002e85)
0x2a75 : Length = 98, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_MEMBER, public, type = T_32PRCHAR(0470), offset = 0
member name = 'm_name'
@ -160,6 +181,11 @@
Derivation list type 0x0000, VT shape type 0x20fb
Size = 36, class name = MxVariable, UDT(0x00004041)
0x3c45 : Length = 50, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_ENUMERATE, public, value = 1, name = 'c_read'
list[1] = LF_ENUMERATE, public, value = 2, name = 'c_write'
list[2] = LF_ENUMERATE, public, value = 4, name = 'c_text'
0x3cc2 : Length = 38, Leaf = 0x1507 LF_ENUM
# members = 64, type = T_INT4(0074) field list type 0x3cc1
NESTED, enum name = JukeBox::JukeBoxScript, UDT(0x00003cc2)
@ -235,7 +261,7 @@ def types_parser_fixture():
return parser
def test_basic_parsing(parser):
def test_basic_parsing(parser: CvdumpTypesParser):
obj = parser.keys["0x4db6"]
assert obj["type"] == "LF_CLASS"
assert obj["name"] == "MxString"
@ -244,7 +270,7 @@ def test_basic_parsing(parser):
assert len(parser.keys["0x4db5"]["members"]) == 2
def test_scalar_types(parser):
def test_scalar_types(parser: CvdumpTypesParser):
"""Full tests on the scalar_* methods are in another file.
Here we are just testing the passthrough of the "T_" types."""
assert parser.get("T_CHAR").name is None
@ -254,7 +280,7 @@ def test_scalar_types(parser):
assert parser.get("T_32PVOID").size == 4
def test_resolve_forward_ref(parser):
def test_resolve_forward_ref(parser: CvdumpTypesParser):
# Non-forward ref
assert parser.get("0x22d5").name == "MxVariable"
# Forward ref
@ -262,7 +288,7 @@ def test_resolve_forward_ref(parser):
assert parser.get("0x14db").size == 16
def test_members(parser):
def test_members(parser: CvdumpTypesParser):
"""Return the list of items to compare for a given complex type.
If the class has a superclass, add those members too."""
# MxCore field list
@ -284,7 +310,7 @@ def test_members(parser):
]
def test_members_recursive(parser):
def test_members_recursive(parser: CvdumpTypesParser):
"""Make sure that we unwrap the dependency tree correctly."""
# MxVariable field list
assert parser.get_scalars("0x22d4") == [
@ -300,7 +326,7 @@ def test_members_recursive(parser):
]
def test_struct(parser):
def test_struct(parser: CvdumpTypesParser):
"""Basic test for converting type into struct.unpack format string."""
# MxCore: vftable and uint32. The vftable pointer is read as uint32.
assert parser.get_format_string("0x4060") == "<LL"
@ -312,7 +338,7 @@ def test_struct(parser):
assert parser.get_format_string("0x1214") == "<llll"
def test_struct_padding(parser):
def test_struct_padding(parser: CvdumpTypesParser):
"""For data comparison purposes, make sure we have no gaps in the
list of scalar types. Any gap is filled by an unsigned char."""
@ -326,7 +352,7 @@ def test_struct_padding(parser):
assert len(parser.get_scalars_gapless("0x22d5")) == 13
def test_struct_format_string(parser):
def test_struct_format_string(parser: CvdumpTypesParser):
"""Generate the struct.unpack format string using the
list of scalars with padding filled in."""
# MxString, padded to 16 bytes.
@ -336,7 +362,7 @@ def test_struct_format_string(parser):
assert parser.get_format_string("0x22d5") == "<LLLLHBBLLLHBB"
def test_array(parser):
def test_array(parser: CvdumpTypesParser):
"""LF_ARRAY members are created dynamically based on the
total array size and the size of one element."""
# unsigned char[8]
@ -360,7 +386,7 @@ def test_array(parser):
]
def test_2d_array(parser):
def test_2d_array(parser: CvdumpTypesParser):
"""Make sure 2d array elements are named as we expect."""
# float[4][4]
float_array = parser.get_scalars("0x103c")
@ -371,7 +397,7 @@ def test_2d_array(parser):
assert float_array[-1] == (60, "[3][3]", "T_REAL32")
def test_enum(parser):
def test_enum(parser: CvdumpTypesParser):
"""LF_ENUM should equal 4-byte int"""
assert parser.get("0x3cc2").size == 4
assert parser.get_scalars("0x3cc2") == [(0, None, "T_INT4")]
@ -382,7 +408,7 @@ def test_enum(parser):
assert enum_array[0].size == 4
def test_lf_pointer(parser):
def test_lf_pointer(parser: CvdumpTypesParser):
"""LF_POINTER is just a wrapper for scalar pointer type"""
assert parser.get("0x3fab").size == 4
# assert parser.get("0x3fab").is_pointer is True # TODO: ?
@ -390,7 +416,7 @@ def test_lf_pointer(parser):
assert parser.get_scalars("0x3fab") == [(0, None, "T_32PVOID")]
def test_key_not_exist(parser):
def test_key_not_exist(parser: CvdumpTypesParser):
"""Accessing a non-existent type id should raise our exception"""
with pytest.raises(CvdumpKeyError):
parser.get("0xbeef")
@ -399,7 +425,7 @@ def test_key_not_exist(parser):
parser.get_scalars("0xbeef")
def test_broken_forward_ref(parser):
def test_broken_forward_ref(parser: CvdumpTypesParser):
"""Raise an exception if we cannot follow a forward reference"""
# Verify forward reference on MxCore
parser.get("0x1220")
@ -412,7 +438,7 @@ def test_broken_forward_ref(parser):
parser.get("0x1220")
def test_null_forward_ref(parser):
def test_null_forward_ref(parser: CvdumpTypesParser):
"""If the forward ref object is invalid and has no forward ref id,
raise an exception."""
# Test MxString forward reference
@ -426,7 +452,7 @@ def test_null_forward_ref(parser):
parser.get("0x14db")
def test_broken_array_element_ref(parser):
def test_broken_array_element_ref(parser: CvdumpTypesParser):
# Test LF_ARRAY of ROIColorAlias
parser.get("0x19b1")
@ -438,7 +464,7 @@ def test_broken_array_element_ref(parser):
parser.get("0x19b1")
def test_lf_modifier(parser):
def test_lf_modifier(parser: CvdumpTypesParser):
"""Is this an alias for another type?"""
# Modifies float
assert parser.get("0x1028").size == 4
@ -449,7 +475,7 @@ def test_lf_modifier(parser):
assert mxrect == parser.get_scalars("0x11f2")
def test_union_members(parser):
def test_union_members(parser: CvdumpTypesParser):
"""If there is a union somewhere in our dependency list, we can
expect to see duplicated member offsets and names. This is ok for
the TypeInfo tuple, but the list of ScalarType items should have
@ -457,9 +483,71 @@ def test_union_members(parser):
# D3DVector type with duplicated offsets
d3dvector = parser.get("0x10e1")
assert d3dvector.members is not None
assert len(d3dvector.members) == 6
assert len([m for m in d3dvector.members if m.offset == 0]) == 2
# Deduplicated comparison list
vector_items = parser.get_scalars("0x10e1")
assert len(vector_items) == 3
def test_arglist(parser: CvdumpTypesParser):
arglist = parser.keys["0x1018"]
assert arglist["argcount"] == 3
assert arglist["args"] == ["0x100D", "0x1016", "0x1017"]
def test_procedure(parser: CvdumpTypesParser):
procedure = parser.keys["0x1019"]
assert procedure == {
"type": "LF_PROCEDURE",
"return_type": "T_LONG(0012)",
"call_type": "C Near",
"func_attr": "none",
"num_params": "3",
"arg_list_type": "0x1018",
}
def test_mfunction(parser: CvdumpTypesParser):
mfunction = parser.keys["0x101e"]
assert mfunction == {
"type": "LF_MFUNCTION",
"return_type": "T_CHAR(0010)",
"class_type": "0x101A",
"this_type": "0x101B",
"call_type": "ThisCall",
"func_attr": "none",
"num_params": "2",
"arg_list_type": "0x101d",
"this_adjust": "0",
}
def test_union_forward_ref(parser: CvdumpTypesParser):
union = parser.keys["0x2339"]
assert union["is_forward_ref"] is True
assert union["udt"] == "0x2e85"
def test_union(parser: CvdumpTypesParser):
union = parser.keys["0x2e85"]
assert union == {
"type": "LF_UNION",
"name": "FlagBitfield",
"size": 1,
"udt": "0x2e85",
}
def test_fieldlist_enumerate(parser: CvdumpTypesParser):
fieldlist_enum = parser.keys["0x3c45"]
assert fieldlist_enum == {
"type": "LF_FIELDLIST",
"variants": [
{"name": "c_read", "value": 1},
{"name": "c_write", "value": 2},
{"name": "c_text", "value": 4},
],
}