Use reccmp as a python requirement (#1116)

* Use reccmp as a python requirement

* Add BETA10 to reccmp-project.yml
Anonymous Maarten 2024-10-26 14:57:47 +02:00 committed by GitHub
parent c38e157fdb
commit 0cb753e523
101 changed files with 143 additions and 14791 deletions

View file

@ -17,10 +17,14 @@ jobs:
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install python libraries
run: |
python -m pip install -r tools/requirements.txt
pip install -r tools/requirements.txt
- name: Run decomplint.py
run: |
tools/decomplint/decomplint.py ${{ matrix.who }} --module ${{ matrix.who }} --warnfail
reccmp-decomplint ${{ matrix.who }} --module ${{ matrix.who }} --warnfail
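With the tooling moved into the `reccmp` package, `tools/requirements.txt` is now what pulls in the console entry points (`reccmp-decomplint` and friends) that the workflows call. A minimal sketch of the kind of line it would carry; the exact pin is not shown in this diff and is assumed here:

```
# Hypothetical tools/requirements.txt entry installing reccmp from its repository
reccmp @ git+https://github.com/isledecomp/reccmp.git
```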

View file

@ -107,6 +107,10 @@ jobs:
steps:
- uses: actions/checkout@master
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- uses: actions/download-artifact@master
with:
name: Win32
@ -125,12 +129,17 @@ jobs:
run: |
pip install -r tools/requirements.txt
- name: Detect binaries
run: |
reccmp-project detect --what original --search-path legobin
reccmp-project detect --what recompiled --search-path build
- name: Summarize Accuracy
shell: bash
run: |
python3 tools/reccmp/reccmp.py -S CONFIGPROGRESS.SVG --svg-icon tools/reccmp/config.png -H CONFIGPROGRESS.HTML legobin/CONFIG.EXE build/CONFIG.EXE build/CONFIG.PDB . | tee CONFIGPROGRESS.TXT
python3 tools/reccmp/reccmp.py -S ISLEPROGRESS.SVG --svg-icon tools/reccmp/isle.png -H ISLEPROGRESS.HTML legobin/ISLE.EXE build/ISLE.EXE build/ISLE.PDB . | tee ISLEPROGRESS.TXT
python3 tools/reccmp/reccmp.py -S LEGO1PROGRESS.SVG -T 4252 --svg-icon tools/reccmp/lego1.png -H LEGO1PROGRESS.HTML legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB . | tee LEGO1PROGRESS.TXT
reccmp-reccmp -S CONFIGPROGRESS.SVG --svg-icon assets/config.png --target CONFIG | tee CONFIGPROGRESS.TXT
reccmp-reccmp -S ISLEPROGRESS.SVG --svg-icon assets/isle.png --target ISLE | tee ISLEPROGRESS.TXT
reccmp-reccmp -S LEGO1PROGRESS.SVG -T 4252 --svg-icon assets/lego1.png --target LEGO1 | tee LEGO1PROGRESS.TXT
- name: Compare Accuracy With Current Master
shell: bash
@ -147,21 +156,21 @@ jobs:
- name: Test Exports
shell: bash
run: |
tools/verexp/verexp.py legobin/LEGO1.DLL build/LEGO1.DLL
reccmp-verexp --target LEGO1
- name: Check Vtables
shell: bash
run: |
python3 tools/vtable/vtable.py legobin/CONFIG.EXE build/CONFIG.EXE build/CONFIG.PDB .
python3 tools/vtable/vtable.py legobin/ISLE.EXE build/ISLE.EXE build/ISLE.PDB .
python3 tools/vtable/vtable.py legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .
reccmp-vtable --target CONFIG
reccmp-vtable --target ISLE
reccmp-vtable --target LEGO1
- name: Check Variables
shell: bash
run: |
python3 tools/datacmp.py legobin/CONFIG.EXE build/CONFIG.EXE build/CONFIG.PDB .
python3 tools/datacmp.py legobin/ISLE.EXE build/ISLE.EXE build/ISLE.PDB .
python3 tools/datacmp.py legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .
reccmp-datacmp --target CONFIG
reccmp-datacmp --target ISLE
reccmp-datacmp --target LEGO1
- name: Upload Artifact
uses: actions/upload-artifact@master
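The two `reccmp-project detect` steps are what replace the explicit `legobin/... build/...` path arguments: they presumably match binaries on the search path against the filenames and hashes in `reccmp-project.yml` and record the result, which is why `reccmp-user.yml` shows up in the new `.gitignore` entries below and the later steps only need `--target`. A rough sketch of the recorded mapping, with field names assumed rather than taken from this diff:

```yaml
# Hypothetical reccmp-user.yml written by the detect steps
targets:
  LEGO1:
    path: legobin/LEGO1.DLL   # found via --search-path legobin
  ISLE:
    path: legobin/ISLE.EXE
```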

View file

@ -1,37 +0,0 @@
name: Format
on: [push, pull_request]
jobs:
clang-format:
name: 'C++'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Run clang-format
run: |
find CONFIG LEGO1 ISLE -iname '*.h' -o -iname '*.cpp' | xargs \
pipx run "clang-format>=17,<18" \
--style=file \
-i
git diff --exit-code
python-format:
name: 'Python'
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install python libraries
shell: bash
run: |
pip install black==23.* pylint==3.2.7 pytest==7.* -r tools/requirements.txt
- name: Run pylint and black
shell: bash
run: |
pylint tools --ignore=build,ncc
black --check tools --exclude=ncc

View file

@ -15,6 +15,10 @@ jobs:
with:
version: "16"
- uses: actions/setup-python@v5
with:
python-version: '3.12'
- name: Install python libraries
run: |
pip install -r tools/requirements.txt

View file

@ -1,60 +0,0 @@
name: Test
on: [push, pull_request]
jobs:
fetch-deps:
name: Download original binaries
uses: ./.github/workflows/legobin.yml
pytest-win:
name: 'Python Windows'
runs-on: windows-latest
needs: fetch-deps
steps:
- uses: actions/checkout@v4
- name: Restore cached original binaries
id: cache-original-binaries
uses: actions/cache/restore@v3
with:
enableCrossOsArchive: true
path: legobin
key: legobin
- name: Install python libraries
shell: bash
run: |
pip install pytest -r tools/requirements.txt
- name: Run python unit tests (Windows)
shell: bash
run: |
pytest tools/isledecomp --lego1=legobin/LEGO1.DLL
pytest-ubuntu:
name: 'Python Linux'
runs-on: ubuntu-latest
needs: fetch-deps
steps:
- uses: actions/checkout@v4
- name: Restore cached original binaries
id: cache-original-binaries
uses: actions/cache/restore@v3
with:
enableCrossOsArchive: true
path: legobin
key: legobin
- name: Install python libraries
shell: bash
run: |
pip install pytest -r tools/requirements.txt
- name: Run python unit tests (Ubuntu)
shell: bash
run: |
pytest tools/isledecomp --lego1=legobin/LEGO1.DLL

.gitignore (2 changes)
View file

@ -1,3 +1,5 @@
reccmp-user.yml
reccmp-build.yml
Debug/
Release/
*.ncb

View file

@ -8,6 +8,7 @@ project(isle CXX)
include(CheckCXXSourceCompiles)
include(CMakeDependentOption)
include(CMakePushCheckState)
include("${CMAKE_CURRENT_LIST_DIR}/cmake/reccmp.cmake")
set(CMAKE_EXPORT_COMPILE_COMMANDS TRUE)
option(ENABLE_CLANG_TIDY "Enable clang-tidy")
@ -405,6 +406,7 @@ add_library(lego1 SHARED
LEGO1/main.cpp
LEGO1/modeldb/modeldb.cpp
)
reccmp_add_target(lego1 ID LEGO1)
register_lego1_target(lego1)
if (MINGW)
@ -447,6 +449,7 @@ if (ISLE_BUILD_APP)
ISLE/res/isle.rc
ISLE/isleapp.cpp
)
reccmp_add_target(isle ID ISLE)
target_compile_definitions(isle PRIVATE ISLE_APP)
@ -477,6 +480,7 @@ if (ISLE_BUILD_CONFIG)
CONFIG/StdAfx.cpp
CONFIG/res/config.rc
)
reccmp_add_target(config ID CONFIG)
target_compile_definitions(config PRIVATE _AFXDLL MXDIRECTX_FOR_CONFIG)
target_include_directories(config PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/util" "${CMAKE_CURRENT_SOURCE_DIR}/LEGO1")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS 14)
@ -603,3 +607,5 @@ if(EXISTS "${CLANGFORMAT_BIN}")
endif()
endif()
endif()
reccmp_configure()

View file

@ -62,3 +62,4 @@ Right click on `LEGO1.DLL`, select `Properties`, and switch to the `Details` tab
* ISLE.EXE `md5: f6da12249e03eed1c74810cd23beb9f5`
* LEGO1.DLL `md5: 4e2f6d969ea2ef8655ba3fc221a0c8fe`
* CONFIG.DLL `md5: 92d958a64a273662c591c88b09100f4a`

(Three binary image diffs follow: icons of 1.4 KiB, 5.3 KiB, and 5.5 KiB, each unchanged in dimensions and size; consistent with the progress-badge icons moving from tools/reccmp/ to assets/, as referenced in the workflow changes above.)

cmake/reccmp.cmake (new file, 58 lines)
View file

@ -0,0 +1,58 @@
function(reccmp_find_project RESULT)
set(curdir "${CMAKE_CURRENT_SOURCE_DIR}")
while(1)
if(EXISTS "${curdir}/reccmp-project.yml")
break()
endif()
get_filename_component(nextdir "${curdir}" DIRECTORY)
if(nextdir STREQUAL curdir)
set(curdir "${RESULT}-NOTFOUND")
break()
endif()
set(curdir "${nextdir}")
endwhile()
set("${RESULT}" "${curdir}" PARENT_SCOPE)
endfunction()
function(reccmp_add_target TARGET)
cmake_parse_arguments(ARGS "" "ID" "" ${ARGN})
if(NOT ARGS_ID)
message(FATAL_ERROR "Missing ID argument")
endif()
set_property(TARGET ${TARGET} PROPERTY INTERFACE_RECCMP_ID "${ARGS_ID}")
set_property(GLOBAL APPEND PROPERTY RECCMP_TARGETS ${TARGET})
endfunction()
function(reccmp_configure)
cmake_parse_arguments(ARGS "COPY_TO_SOURCE_FOLDER" "DIR" "" ${ARGN})
set(binary_dir "${CMAKE_BINARY_DIR}")
if(ARGS_DIR)
set(binary_dir "${ARGS_DIR}")
endif()
reccmp_find_project(reccmp_project_dir)
if(NOT reccmp_project_dir)
message(FATAL_ERROR "Cannot find reccmp-project.yml")
endif()
if(CMAKE_CONFIGURATION_TYPES)
set(outputdir "${binary_dir}/$<CONFIG>")
else()
set(outputdir "${binary_dir}")
endif()
set(build_yml_txt "project: '${reccmp_project_dir}'\ntargets:\n")
get_property(RECCMP_TARGETS GLOBAL PROPERTY RECCMP_TARGETS)
foreach(target ${RECCMP_TARGETS})
get_property(id TARGET "${target}" PROPERTY INTERFACE_RECCMP_ID)
string(APPEND build_yml_txt " ${id}:\n")
string(APPEND build_yml_txt " path: '$<TARGET_FILE:${target}>'\n")
if(WIN32 AND MSVC)
string(APPEND build_yml_txt " pdb: '$<TARGET_PDB_FILE:${target}>'\n")
endif()
endforeach()
file(GENERATE OUTPUT "${outputdir}/reccmp-build.yml" CONTENT "${build_yml_txt}")
if(ARGS_COPY_TO_SOURCE_FOLDER)
file(GENERATE OUTPUT "${CMAKE_SOURCE_DIR}/reccmp-build.yml" CONTENT "${build_yml_txt}" CONDITION $<CONFIG:Release>)
endif()
endfunction()
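Tracing the `file(GENERATE)` call above, the `reccmp-build.yml` emitted at configure time would look roughly like this for the targets registered in `CMakeLists.txt` (paths illustrative; the `pdb:` lines appear only when building with MSVC, and `isle`/`config` only when those build options are enabled):

```yaml
project: '/path/to/isle'
targets:
  LEGO1:
    path: '/path/to/isle/build/LEGO1.DLL'
    pdb: '/path/to/isle/build/LEGO1.PDB'
  ISLE:
    path: '/path/to/isle/build/ISLE.EXE'
    pdb: '/path/to/isle/build/ISLE.PDB'
  CONFIG:
    path: '/path/to/isle/build/CONFIG.EXE'
    pdb: '/path/to/isle/build/CONFIG.PDB'
```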

reccmp-project.yml (new file, 21 lines)
View file

@ -0,0 +1,21 @@
targets:
ISLE:
filename: ISLE.EXE
source-root: ISLE
hash:
sha256: 5cf57c284973fce9d14f5677a2e4435fd989c5e938970764d00c8932ed5128ca
LEGO1:
filename: LEGO1.DLL
source-root: LEGO1
hash:
sha256: 14645225bbe81212e9bc1919cd8a692b81b8622abb6561280d99b0fc4151ce17
CONFIG:
filename: CONFIG.EXE
source-root: CONFIG
hash:
sha256: 864766d024d78330fed5e1f6efb2faf815f1b1c3405713a9718059dc9a54e52c
BETA10:
filename: BETA10.DLL
source-root: LEGO1
hash:
sha256: d91435a40fa31f405fba33b03bd3bd40dcd4ca36ccf8ef6162c6c5ca0d7190e7
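Since each target pins its original binary by SHA-256, a local copy in `legobin/` can be checked directly against these values, e.g.:

```sh
sha256sum legobin/LEGO1.DLL
# expected: 14645225bbe81212e9bc1919cd8a692b81b8622abb6561280d99b0fc4151ce17
```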

View file

@ -160,58 +160,42 @@ inline virtual const char* ClassName() const override // vtable+0x0c
Use `pip` to install the required packages to be able to use the Python tools found in this folder:
```
```sh
pip install -r tools/requirements.txt
```
Run the following command to allow reccmp to detect the original LEGO binaries:
```sh
reccmp-project detect --what original --search-path <paths-to-directories-containing-lego-binaries>
```
After building recompiled binaries, run the following command in this repository's root:
```sh
reccmp-project detect --what recompiled --search-path <paths-to-build-directories>
```
The example usages below assume that the current working directory is this repository's root and that the retail binaries have been copied to `./legobin`.
* [`decomplint`](/tools/decomplint): Checks the decompilation annotations (see above)
* e.g. `py -m tools.decomplint.decomplint --module LEGO1 LEGO1`
* [`isledecomp`](/tools/isledecomp): A library that implements a parser to identify the decompilation annotations (see above)
* `reccmp-decomplint`: Checks the decompilation annotations (see above)
* e.g. `reccmp-decomplint --module LEGO1 LEGO1`
* [`ncc`](/tools/ncc): Checks naming conventions based on a set of rules
* [`reccmp`](/tools/reccmp): Compares an original binary with a recompiled binary, provided a PDB file. For example:
* `reccmp-reccmp`: Compares an original binary with a recompiled binary, provided a PDB file. For example:
* Display the diff for a single function: `py -m tools.reccmp.reccmp --verbose 0x100ae1a0 legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* Generate an HTML report: `py -m tools.reccmp.reccmp --html output.html legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* Create a base file for diffs: `py -m tools.reccmp.reccmp --json base.json --silent legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* Diff against a base file: `py -m tools.reccmp.reccmp --diff base.json legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* [`stackcmp`](/tools/stackcmp): Compares the stack layout for a given function that almost matches.
* e.g. `py -m tools.stackcmp.stackcmp legobin/BETA10.DLL build_debug/LEGO1.DLL build_debug/LEGO1.pdb . 0x1007165d`
* [`roadmap`](/tools/roadmap): Compares symbol locations in an original binary with the same symbol locations of a recompiled binary
* [`verexp`](/tools/verexp): Verifies exports by comparing the exports of the original DLL and the recompiled DLL
* [`vtable`](/tools/vtable): Asserts virtual table correctness by comparing a recompiled binary with the original
* e.g. `py -m tools.vtable.vtable legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* [`datacmp.py`](/tools/datacmp.py): Compares global data found in the original with the recompiled version
* e.g. `py -m tools.datacmp legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* `reccmp-stackcmp`: Compares the stack layout for a given function that almost matches.
* e.g. `reccmp-stackcmp legobin/BETA10.DLL build_debug/LEGO1.DLL build_debug/LEGO1.pdb . 0x1007165d`
* `reccmp-roadmap`: Compares symbol locations in an original binary with the same symbol locations of a recompiled binary
* `reccmp-verexp`: Verifies exports by comparing the exports of the original DLL and the recompiled DLL
* `reccmp-vtable`: Asserts virtual table correctness by comparing a recompiled binary with the original
* e.g. `reccmp-vtable legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* `reccmp-datacmp`: Compares global data found in the original with the recompiled version
* e.g. `reccmp-datacmp legobin/LEGO1.DLL build/LEGO1.DLL build/LEGO1.PDB .`
* [`patch_c2.py`](/tools/patch_c2.py): Patches `C2.EXE` (part of MSVC 4.20) to get rid of a bugged warning
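With targets detected (see `reccmp-project detect` above), the path-based `reccmp-reccmp` examples in this list collapse to the `--target` form used in CI; assuming the old flags carried over unchanged:

```sh
# Assumed new-style equivalents of the path-based examples above
reccmp-reccmp --verbose 0x100ae1a0 --target LEGO1
reccmp-reccmp --html output.html --target LEGO1
```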
## Testing
`isledecomp` comes with a suite of tests. Install `pytest` and run it, passing in the directory:
```
pip install pytest
pytest tools/isledecomp/tests/
```
## Tool Development
In order to keep the Python code clean and consistent, we use `pylint` and `black`:
`pip install black pylint`
### Run pylint (ignores build and virtualenv)
`pylint tools/ --ignore=build,ncc`
### Check Python code formatting without rewriting files
`black --check tools/`
### Apply Python code formatting
`black tools/`
# Modules
The following is a list of all the modules found in the annotations (e.g. `// FUNCTION: [module] [address]`) and which binaries they refer to. See [this list of all known versions of the game](https://www.legoisland.org/wiki/LEGO_Island#Download).
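As a concrete illustration, an annotated function in the source looks like this (the address is the real `LEGO1` example used earlier; the class and method are made up):

```cpp
// FUNCTION: LEGO1 0x100ae1a0
// "LEGO1" is the module; 0x100ae1a0 is this function's address in the original binary.
void LegoExampleClass::ExampleMethod()
{
    // ... decompiled body ...
}
```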
@ -243,7 +227,7 @@ cmake <path-to-source> -G "NMake Makefiles" -DCMAKE_BUILD_TYPE=RelWithDebInfo -D
```
**TODO**: If you can figure out how to make a debug build with SmartHeap enabled, please add it here.
If you want to run scripts to compare your debug build to `BETA10` (e.g. `reccmp`), it is advisable to add a copy of `LEGO1D.DLL` to `/legobin` and rename it to `BETA10.DLL`.
If you want to run scripts to compare your debug build to `BETA10` (e.g. `reccmp-reccmp`), it is advisable to add a copy of `LEGO1D.DLL` to `/legobin` and rename it to `BETA10.DLL`.
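Assuming the beta's debug DLL is at hand, that amounts to:

```sh
# Copy the beta debug DLL under the name the BETA10 target expects
cp <path-to-beta10>/LEGO1D.DLL legobin/BETA10.DLL
```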
### Finding matching functions

View file

@ -1,371 +0,0 @@
# (New) Data comparison.
import os
import argparse
import logging
from enum import Enum
from typing import Iterable, List, NamedTuple, Optional, Tuple
from struct import unpack
from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.db import MatchInfo
from isledecomp.cvdump import Cvdump
from isledecomp.cvdump.types import (
CvdumpKeyError,
CvdumpIntegrityError,
)
from isledecomp.bin import Bin as IsleBin
import colorama
colorama.just_fix_windows_console()
# Ignore all compare-db messages.
logging.getLogger("isledecomp.compare").addHandler(logging.NullHandler())
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Comparing data values.")
parser.add_argument(
"original", metavar="original-binary", help="The original binary"
)
parser.add_argument(
"recompiled", metavar="recompiled-binary", help="The recompiled binary"
)
parser.add_argument(
"pdb", metavar="recompiled-pdb", help="The PDB of the recompiled binary"
)
parser.add_argument(
"decomp_dir", metavar="decomp-dir", help="The decompiled source tree"
)
parser.add_argument(
"-v",
"--verbose",
action=argparse.BooleanOptionalAction,
default=False,
help="",
)
parser.add_argument(
"--no-color", "-n", action="store_true", help="Do not color the output"
)
parser.add_argument(
"--all",
"-a",
dest="show_all",
action="store_true",
help="Only show variables with a problem",
)
parser.add_argument(
"--print-rec-addr",
action="store_true",
help="Print addresses of recompiled functions too",
)
(args, _) = parser.parse_known_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
if not os.path.isfile(args.pdb):
parser.error(f"Symbols PDB {args.pdb} does not exist")
if not os.path.isdir(args.decomp_dir):
parser.error(f"Source directory {args.decomp_dir} does not exist")
return args
class CompareResult(Enum):
MATCH = 1
DIFF = 2
ERROR = 3
WARN = 4
class ComparedOffset(NamedTuple):
offset: int
# name is None for scalar types
name: Optional[str]
match: bool
values: Tuple[str, str]
class ComparisonItem(NamedTuple):
"""Each variable that was compared"""
orig_addr: int
recomp_addr: int
name: str
# The list of items that were compared.
# For a complex type, these are the members.
# For a scalar type, this is a list of size one.
# If we could not retrieve type information, this is
# a list of size one but without any specific type.
compared: List[ComparedOffset]
# If present, the error message from the types parser.
error: Optional[str] = None
# If true, there is no type specified for this variable. (i.e. non-public)
# In this case, we can only compare the raw bytes.
# This is different from the situation where a type id _is_ given, but
# we could not retrieve it for some reason. (This is an error.)
raw_only: bool = False
@property
def result(self) -> CompareResult:
if self.error is not None:
return CompareResult.ERROR
if all(c.match for c in self.compared):
return CompareResult.MATCH
# Prefer WARN for a diff without complete type information.
return CompareResult.WARN if self.raw_only else CompareResult.DIFF
def create_comparison_item(
var: MatchInfo,
compared: Optional[List[ComparedOffset]] = None,
error: Optional[str] = None,
raw_only: bool = False,
) -> ComparisonItem:
"""Helper to create the ComparisonItem from the fields in MatchInfo."""
if compared is None:
compared = []
return ComparisonItem(
orig_addr=var.orig_addr,
recomp_addr=var.recomp_addr,
name=var.name,
compared=compared,
error=error,
raw_only=raw_only,
)
def do_the_comparison(args: argparse.Namespace) -> Iterable[ComparisonItem]:
"""Run through each variable in our compare DB, then do the comparison
according to the variable's type. Emit the result."""
with IsleBin(args.original, find_str=True) as origfile, IsleBin(
args.recompiled
) as recompfile:
isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
# TODO: We don't currently retain the type information of each variable
# in our compare DB. To get those, we build this mini-lookup table that
# maps recomp addresses to their type.
# We still need to build the full compare DB though, because we may
# need the matched symbols to compare pointers (e.g. on strings)
mini_cvdump = Cvdump(args.pdb).globals().types().run()
recomp_type_reference = {
recompfile.get_abs_addr(g.section, g.offset): g.type
for g in mini_cvdump.globals
if recompfile.is_valid_section(g.section)
}
for var in isle_compare.get_variables():
type_name = recomp_type_reference.get(var.recomp_addr)
# Start by assuming we can only compare the raw bytes
data_size = var.size
is_type_aware = type_name is not None
if is_type_aware:
try:
# If we are type-aware, we can get the precise
# data size for the variable.
data_type = mini_cvdump.types.get(type_name)
data_size = data_type.size
except (CvdumpKeyError, CvdumpIntegrityError) as ex:
yield create_comparison_item(var, error=repr(ex))
continue
orig_raw = origfile.read(var.orig_addr, data_size)
recomp_raw = recompfile.read(var.recomp_addr, data_size)
# The IMAGE_SECTION_HEADER defines the SizeOfRawData and VirtualSize for the section.
# If VirtualSize > SizeOfRawData, the section is comprised of the initialized data
# corresponding to bytes in the file, and the rest is padded with zeroes when
# Windows loads the image.
# The linker might place variables initialized to zero on the threshold between
# physical data and the virtual (uninitialized) data.
# If this happens (i.e. we get an incomplete read) we just do the same padding
# to prepare for the comparison.
if orig_raw is not None and len(orig_raw) < data_size:
orig_raw = orig_raw.ljust(data_size, b"\x00")
if recomp_raw is not None and len(recomp_raw) < data_size:
recomp_raw = recomp_raw.ljust(data_size, b"\x00")
# If one or both variables are entirely uninitialized
if orig_raw is None or recomp_raw is None:
# If both variables are uninitialized, we consider them equal.
match = orig_raw is None and recomp_raw is None
# We can match a variable initialized to all zeroes with
# an uninitialized variable, but this may or may not actually
# be correct, so we flag it for the user.
uninit_force_match = not match and (
(orig_raw is None and all(b == 0 for b in recomp_raw))
or (recomp_raw is None and all(b == 0 for b in orig_raw))
)
orig_value = "(uninitialized)" if orig_raw is None else "(initialized)"
recomp_value = (
"(uninitialized)" if recomp_raw is None else "(initialized)"
)
yield create_comparison_item(
var,
compared=[
ComparedOffset(
offset=0,
name=None,
match=match,
values=(orig_value, recomp_value),
)
],
raw_only=uninit_force_match,
)
continue
if not is_type_aware:
# If there is no specific type information available
# (i.e. if this is a static or non-public variable)
# then we can only compare the raw bytes.
yield create_comparison_item(
var,
compared=[
ComparedOffset(
offset=0,
name="(raw)",
match=orig_raw == recomp_raw,
values=(orig_raw, recomp_raw),
)
],
raw_only=True,
)
continue
# If we are here, we can do the type-aware comparison.
compared = []
compare_items = mini_cvdump.types.get_scalars_gapless(type_name)
format_str = mini_cvdump.types.get_format_string(type_name)
orig_data = unpack(format_str, orig_raw)
recomp_data = unpack(format_str, recomp_raw)
def pointer_display(addr: int, is_orig: bool) -> str:
"""Helper to streamline pointer textual display."""
if addr == 0:
return "nullptr"
ptr_match = (
isle_compare.get_by_orig(addr)
if is_orig
else isle_compare.get_by_recomp(addr)
)
if ptr_match is not None:
return f"Pointer to {ptr_match.match_name()}"
# This variable did not match if we do not have
# the pointer target in our DB.
return f"Unknown pointer 0x{addr:x}"
# Could zip here
for i, member in enumerate(compare_items):
if member.is_pointer:
match = isle_compare.is_pointer_match(orig_data[i], recomp_data[i])
value_a = pointer_display(orig_data[i], True)
value_b = pointer_display(recomp_data[i], False)
values = (value_a, value_b)
else:
match = orig_data[i] == recomp_data[i]
values = (orig_data[i], recomp_data[i])
compared.append(
ComparedOffset(
offset=member.offset,
name=member.name,
match=match,
values=values,
)
)
yield create_comparison_item(var, compared=compared)
def value_get(value: Optional[str], default: str):
return value if value is not None else default
def main():
args = parse_args()
def display_match(result: CompareResult) -> str:
"""Helper to return color string or not, depending on user preference"""
if args.no_color:
return result.name
match_color = (
colorama.Fore.GREEN
if result == CompareResult.MATCH
else (
colorama.Fore.YELLOW
if result == CompareResult.WARN
else colorama.Fore.RED
)
)
return f"{match_color}{result.name}{colorama.Style.RESET_ALL}"
var_count = 0
problems = 0
for item in do_the_comparison(args):
var_count += 1
if item.result in (CompareResult.DIFF, CompareResult.ERROR):
problems += 1
if not args.show_all and item.result == CompareResult.MATCH:
continue
address_display = (
f"0x{item.orig_addr:x} / 0x{item.recomp_addr:x}"
if args.print_rec_addr
else f"0x{item.orig_addr:x}"
)
print(f"{item.name[:80]} ({address_display}) ... {display_match(item.result)} ")
if item.error is not None:
print(f" {item.error}")
for c in item.compared:
if not args.verbose and c.match:
continue
(value_a, value_b) = c.values
if c.match:
print(f" {c.offset:5} {value_get(c.name, '(value)'):30} {value_a}")
else:
print(
f" {c.offset:5} {value_get(c.name, '(value)'):30} {value_a} : {value_b}"
)
if args.verbose:
print()
print(
f"{os.path.basename(args.original)} - Variables: {var_count}. Issues: {problems}"
)
return 0 if problems == 0 else 1
if __name__ == "__main__":
raise SystemExit(main())
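The `ljust` padding above reproduces what the Windows loader does with the zero-filled tail of a section; as a worked example of the incomplete-read case:

```python
# A 4-byte variable whose last 2 bytes fall past SizeOfRawData:
data_size = 4
orig_raw = b"\x2a\x00"                         # short read from the file
orig_raw = orig_raw.ljust(data_size, b"\x00")  # pad as the loader would
assert orig_raw == b"\x2a\x00\x00\x00"
```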

View file

@ -1,103 +0,0 @@
#!/usr/bin/env python3
import os
import sys
import argparse
import colorama
from isledecomp.dir import walk_source_dir, is_file_cpp
from isledecomp.parser import DecompLinter
colorama.just_fix_windows_console()
def display_errors(alerts, filename):
sorted_alerts = sorted(alerts, key=lambda a: a.line_number)
for alert in sorted_alerts:
error_type = (
f"{colorama.Fore.RED}error: "
if alert.is_error()
else f"{colorama.Fore.YELLOW}warning: "
)
components = [
colorama.Fore.LIGHTWHITE_EX,
filename,
":",
str(alert.line_number),
" : ",
error_type,
colorama.Fore.LIGHTWHITE_EX,
alert.code.name.lower(),
]
print("".join(components))
if alert.line is not None:
print(f"{colorama.Fore.WHITE} {alert.line}")
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Syntax checking and linting for decomp annotation markers."
)
p.add_argument("target", help="The file or directory to check.")
p.add_argument(
"--module",
required=False,
type=str,
help="If present, run targeted checks for markers from the given module.",
)
p.add_argument(
"--warnfail",
action=argparse.BooleanOptionalAction,
default=False,
help="Fail if syntax warnings are found.",
)
(args, _) = p.parse_known_args()
return args
def process_files(files, module=None):
warning_count = 0
error_count = 0
linter = DecompLinter()
for filename in files:
success = linter.check_file(filename, module)
warnings = [a for a in linter.alerts if a.is_warning()]
errors = [a for a in linter.alerts if a.is_error()]
error_count += len(errors)
warning_count += len(warnings)
if not success:
display_errors(linter.alerts, filename)
print()
return (warning_count, error_count)
def main():
args = parse_args()
files_to_check = []
if os.path.isdir(args.target):
files_to_check = list(walk_source_dir(args.target))
elif os.path.isfile(args.target) and is_file_cpp(args.target):
files_to_check = [args.target]
else:
sys.exit("Invalid target")
(warning_count, error_count) = process_files(files_to_check, module=args.module)
print(colorama.Style.RESET_ALL, end="")
would_fail = error_count > 0 or (warning_count > 0 and args.warnfail)
if would_fail:
return 1
return 0
if __name__ == "__main__":
raise SystemExit(main())

View file

@ -1,25 +0,0 @@
# Ghidra Scripts
The scripts in this directory provide additional functionality in Ghidra, e.g. imports of symbols and types from the PDB debug symbol file.
## Setup
### Ghidrathon
Since these scripts and their dependencies are written in Python 3, [Ghidrathon](https://github.com/mandiant/Ghidrathon) must be installed first. Follow the instructions and install a recent build (these scripts were tested with Python 3.12 and Ghidrathon v4.0.0).
### Script Directory
- In Ghidra, _Open Window -> Script Manager_.
- Click the _Manage Script Directories_ button on the top right.
- Click the _Add_ (Plus icon) button and select this file's parent directory.
- Close the window and click the _Refresh_ button.
- The scripts should now be available under the folder _LEGO1_.
### Virtual environment
As of now, there must be a Python virtual environment set up under `$REPOSITORY_ROOT/.venv`, and the dependencies of `isledecomp` must be installed there; see [here](../README.md#tooling).
## Development
- Type hints for Ghidra (optional): Download a recent release from https://github.com/VDOO-Connected-Trust/ghidra-pyi-generator,
unpack it somewhere, and `pip install` that directory in this virtual environment. This provides types and headers for Python.
Be aware that some of these files contain errors - in particular, `from typing import overload` seems to be missing everywhere, leading to spurious type errors.
- Note that the imported modules persist across multiple runs of the script (see [here](https://github.com/mandiant/Ghidrathon/issues/103)).
If you intend to modify an imported library, you have to use `import importlib; importlib.reload(${library})` or restart Ghidra for your changes to take effect. Unfortunately, even that is not perfectly reliable, so you may still have to restart Ghidra for some changes in `isledecomp` to be applied.
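Concretely, the reload idiom from the note above, as run in the Ghidrathon console (module name illustrative):

```python
import importlib
import isledecomp             # stale module from a previous script run
importlib.reload(isledecomp)  # pick up source changes without restarting Ghidra
```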

View file

@ -1,285 +0,0 @@
# Imports types and function signatures from debug symbols (PDB file) of the recompilation.
#
# This script uses Python 3 and therefore requires Ghidrathon to be installed in Ghidra (see https://github.com/mandiant/Ghidrathon).
# Furthermore, the virtual environment must be set up beforehand under $REPOSITORY_ROOT/.venv, and all required packages must be installed
# (see $REPOSITORY_ROOT/tools/README.md).
# Also, the Python version of the virtual environment likely needs to match the Python version used by Ghidrathon.
# @author J. Schulz
# @category LEGO1
# @keybinding
# @menupath
# @toolbar
# In order to make this code run both within and outside of Ghidra, the import order is rather unorthodox in this file.
# That is why some of the lints below are disabled.
# pylint: disable=wrong-import-position,ungrouped-imports
# pylint: disable=undefined-variable # need to disable this one globally because pylint does not understand e.g. `askYesNo()`
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
import importlib
import logging.handlers
import sys
import logging
from pathlib import Path
import traceback
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
import ghidra
from lego_util.headers import * # pylint: disable=wildcard-import # these are just for headers
logger = logging.getLogger(__name__)
def reload_module(module: str):
"""
Due to a quirk in Jep (used by Ghidrathon), imported modules persist for the lifetime of the Ghidra process
and are not reloaded when relaunching the script. Therefore, in order to facilitate development
we force reload all our own modules at startup. See also https://github.com/mandiant/Ghidrathon/issues/103.
Note that as of 2024-05-30, this remedy does not work perfectly (yet): Some changes in isledecomp are
still not detected correctly and require a Ghidra restart to be applied.
"""
importlib.reload(importlib.import_module(module))
reload_module("lego_util.statistics")
reload_module("lego_util.globals")
from lego_util.globals import GLOBALS, SupportedModules
def setup_logging():
logging.root.handlers.clear()
formatter = logging.Formatter("%(levelname)-8s %(message)s")
# formatter = logging.Formatter("%(name)s %(levelname)-8s %(message)s") # use this to identify loggers
stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.setFormatter(formatter)
file_handler = logging.FileHandler(
Path(__file__).absolute().parent.joinpath("import.log"), mode="w"
)
file_handler.setFormatter(formatter)
logging.root.setLevel(GLOBALS.loglevel)
logging.root.addHandler(stdout_handler)
logging.root.addHandler(file_handler)
logger.info("Starting import...")
# This script can be run both from Ghidra and as a standalone.
# In the latter case, only the PDB parser will be used.
setup_logging()
try:
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.util.exception import CancelledException
GLOBALS.running_from_ghidra = True
except ImportError as importError:
logger.error(
"Failed to import Ghidra functions, doing a dry run for the source code parser. "
"Has this script been launched from Ghidra?"
)
logger.debug("Precise import error:", exc_info=importError)
GLOBALS.running_from_ghidra = False
CancelledException = None
def get_repository_root():
return Path(__file__).absolute().parent.parent.parent
def add_python_path(path: str):
"""
Scripts in Ghidra are executed from the tools/ghidra_scripts directory. We need to add
a few more paths to the Python path so we can import the other libraries.
"""
venv_path = get_repository_root().joinpath(path)
logger.info("Adding %s to Python Path", venv_path)
assert venv_path.exists()
sys.path.insert(1, str(venv_path))
# We need to quote the types here because they might not exist when running without Ghidra
def import_function_into_ghidra(
api: "FlatProgramAPI",
pdb_function: "PdbFunction",
type_importer: "PdbTypeImporter",
):
hex_original_address = f"{pdb_function.match_info.orig_addr:x}"
# Find the Ghidra function at that address
ghidra_address = getAddressFactory().getAddress(hex_original_address)
# pylint: disable=possibly-used-before-assignment
function_importer = PdbFunctionImporter.build(api, pdb_function, type_importer)
ghidra_function = getFunctionAt(ghidra_address)
if ghidra_function is None:
ghidra_function = createFunction(ghidra_address, "temp")
assert (
ghidra_function is not None
), f"Failed to create function at {ghidra_address}"
logger.info("Created new function at %s", ghidra_address)
logger.debug("Start handling function '%s'", function_importer.get_full_name())
if function_importer.matches_ghidra_function(ghidra_function):
logger.info(
"Skipping function '%s', matches already",
function_importer.get_full_name(),
)
return
logger.debug(
"Modifying function %s at 0x%s",
function_importer.get_full_name(),
hex_original_address,
)
function_importer.overwrite_ghidra_function(ghidra_function)
GLOBALS.statistics.functions_changed += 1
def process_functions(extraction: "PdbFunctionExtractor"):
pdb_functions = extraction.get_function_list()
if not GLOBALS.running_from_ghidra:
logger.info("Completed the dry run outside Ghidra.")
return
api = FlatProgramAPI(currentProgram())
# pylint: disable=possibly-used-before-assignment
type_importer = PdbTypeImporter(api, extraction)
for pdb_func in pdb_functions:
func_name = pdb_func.match_info.name
try:
import_function_into_ghidra(api, pdb_func, type_importer)
GLOBALS.statistics.successes += 1
except Lego1Exception as e:
log_and_track_failure(func_name, e)
except RuntimeError as e:
cause = e.args[0]
if CancelledException is not None and isinstance(cause, CancelledException):
# let Ghidra's CancelledException pass through
logging.critical("Import aborted by the user.")
return
log_and_track_failure(func_name, cause, unexpected=True)
logger.error(traceback.format_exc())
except Exception as e: # pylint: disable=broad-exception-caught
log_and_track_failure(func_name, e, unexpected=True)
logger.error(traceback.format_exc())
def log_and_track_failure(
function_name: Optional[str], error: Exception, unexpected: bool = False
):
if GLOBALS.statistics.track_failure_and_tell_if_new(error):
logger.error(
"%s(): %s%s",
function_name,
"Unexpected error: " if unexpected else "",
error,
)
def main():
if GLOBALS.running_from_ghidra:
origfile_name = getProgramFile().getName()
if origfile_name == "LEGO1.DLL":
GLOBALS.module = SupportedModules.LEGO1
elif origfile_name in ["LEGO1D.DLL", "BETA10.DLL"]:
GLOBALS.module = SupportedModules.BETA10
else:
raise Lego1Exception(
f"Unsupported file name in import script: {origfile_name}"
)
logger.info("Importing file: %s", GLOBALS.module.orig_filename())
repo_root = get_repository_root()
origfile_path = repo_root.joinpath("legobin").joinpath(
GLOBALS.module.orig_filename()
)
build_directory = repo_root.joinpath(GLOBALS.module.build_dir_name())
recompiledfile_name = f"{GLOBALS.module.recomp_filename_without_extension()}.DLL"
recompiledfile_path = build_directory.joinpath(recompiledfile_name)
pdbfile_name = f"{GLOBALS.module.recomp_filename_without_extension()}.PDB"
pdbfile_path = build_directory.joinpath(pdbfile_name)
if not GLOBALS.verbose:
logging.getLogger("isledecomp.bin").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.core").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.db").setLevel(logging.WARNING)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.WARNING)
logging.getLogger("isledecomp.cvdump.symbols").setLevel(logging.WARNING)
logger.info("Starting comparison")
with Bin(str(origfile_path), find_str=True) as origfile, Bin(
str(recompiledfile_path)
) as recompfile:
isle_compare = IsleCompare(
origfile, recompfile, str(pdbfile_path), str(repo_root)
)
logger.info("Comparison complete.")
# try to acquire matched functions
migration = PdbFunctionExtractor(isle_compare)
try:
process_functions(migration)
finally:
if GLOBALS.running_from_ghidra:
GLOBALS.statistics.log()
logger.info("Done")
# sys.path is not reset after running the script, so we should restore it
sys_path_backup = sys.path.copy()
try:
# make modules installed in the venv available in Ghidra
add_python_path(".venv/Lib/site-packages")
# This one is needed when isledecomp is installed in editable mode in the venv
add_python_path("tools/isledecomp")
import setuptools # pylint: disable=unused-import # required to fix a distutils issue in Python 3.12
reload_module("isledecomp")
from isledecomp import Bin
reload_module("isledecomp.compare")
from isledecomp.compare import Compare as IsleCompare
reload_module("isledecomp.compare.db")
reload_module("lego_util.exceptions")
from lego_util.exceptions import Lego1Exception
reload_module("lego_util.pdb_extraction")
from lego_util.pdb_extraction import (
PdbFunctionExtractor,
PdbFunction,
)
if GLOBALS.running_from_ghidra:
reload_module("lego_util.ghidra_helper")
reload_module("lego_util.function_importer")
from lego_util.function_importer import PdbFunctionImporter
reload_module("lego_util.type_importer")
from lego_util.type_importer import PdbTypeImporter
if __name__ == "__main__":
main()
finally:
sys.path = sys_path_backup

View file

@ -1,47 +0,0 @@
class Lego1Exception(Exception):
"""
Our own base class for exceptions.
Makes it easier to distinguish expected and unexpected errors.
"""
class TypeNotFoundError(Lego1Exception):
def __str__(self):
return f"Type not found in PDB: {self.args[0]}"
class TypeNotFoundInGhidraError(Lego1Exception):
def __str__(self):
return f"Type not found in Ghidra: {self.args[0]}"
class TypeNotImplementedError(Lego1Exception):
def __str__(self):
return f"Import not implemented for type: {self.args[0]}"
class ClassOrNamespaceNotFoundInGhidraError(Lego1Exception):
def __init__(self, namespaceHierachy: list[str]):
super().__init__(namespaceHierachy)
def get_namespace_str(self) -> str:
return "::".join(self.args[0])
def __str__(self):
return f"Class or namespace not found in Ghidra: {self.get_namespace_str()}"
class MultipleTypesFoundInGhidraError(Lego1Exception):
def __str__(self):
return (
f"Found multiple types matching '{self.args[0]}' in Ghidra: {self.args[1]}"
)
class StackOffsetMismatchError(Lego1Exception):
pass
class StructModificationError(Lego1Exception):
def __str__(self):
return f"Failed to modify struct in Ghidra: '{self.args[0]}'\nDetailed error: {self.__cause__}"

View file

@ -1,421 +0,0 @@
# This file can only be imported successfully when run from Ghidra using Ghidrathon.
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
import logging
from typing import Optional
from abc import ABC, abstractmethod
from ghidra.program.model.listing import Function, Parameter
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.listing import ParameterImpl
from ghidra.program.model.symbol import SourceType
from ghidra.program.model.data import (
TypeDef,
TypedefDataType,
Pointer,
ComponentOffsetSettingsDefinition,
)
from lego_util.pdb_extraction import (
PdbFunction,
CppRegisterSymbol,
CppStackSymbol,
)
from lego_util.ghidra_helper import (
add_data_type_or_reuse_existing,
create_ghidra_namespace,
get_or_add_pointer_type,
get_ghidra_namespace,
sanitize_name,
)
from lego_util.exceptions import StackOffsetMismatchError, Lego1Exception
from lego_util.type_importer import PdbTypeImporter
logger = logging.getLogger(__name__)
class PdbFunctionImporter(ABC):
"""A representation of a function from the PDB with each type replaced by a Ghidra type instance."""
def __init__(
self,
api: FlatProgramAPI,
func: PdbFunction,
type_importer: "PdbTypeImporter",
):
self.api = api
self.match_info = func.match_info
self.type_importer = type_importer
assert self.match_info.name is not None
colon_split = sanitize_name(self.match_info.name).split("::")
self.name = colon_split.pop()
namespace_hierachy = colon_split
self.namespace = self._do_get_namespace(namespace_hierachy)
def _do_get_namespace(self, namespace_hierarchy: list[str]):
return get_ghidra_namespace(self.api, namespace_hierarchy)
def get_full_name(self) -> str:
return f"{self.namespace.getName()}::{self.name}"
@staticmethod
def build(api: FlatProgramAPI, func: PdbFunction, type_importer: "PdbTypeImporter"):
return (
ThunkPdbFunctionImport(api, func, type_importer)
if func.signature is None
else FullPdbFunctionImporter(api, func, type_importer)
)
@abstractmethod
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
...
@abstractmethod
def overwrite_ghidra_function(self, ghidra_function: Function):
...
class ThunkPdbFunctionImport(PdbFunctionImporter):
"""For importing thunk functions (like vtordisp or debug build thunks) into Ghidra.
Only the name of the function will be imported."""
def _do_get_namespace(self, namespace_hierarchy: list[str]):
"""We need to create the namespace because we don't import the return type here"""
return create_ghidra_namespace(self.api, namespace_hierarchy)
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace()
logger.debug("Matches: namespace=%s name=%s", namespace_match, name_match)
return name_match and namespace_match
def overwrite_ghidra_function(self, ghidra_function: Function):
ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace)
# pylint: disable=too-many-instance-attributes
class FullPdbFunctionImporter(PdbFunctionImporter):
"""For importing functions into Ghidra where all information are available."""
def __init__(
self,
api: FlatProgramAPI,
func: PdbFunction,
type_importer: "PdbTypeImporter",
):
super().__init__(api, func, type_importer)
assert func.signature is not None
self.signature = func.signature
self.is_stub = func.is_stub
if self.signature.class_type is not None:
# Import the base class so the namespace exists
self.type_importer.import_pdb_type_into_ghidra(self.signature.class_type)
self.return_type = type_importer.import_pdb_type_into_ghidra(
self.signature.return_type
)
self.arguments = [
ParameterImpl(
f"param{index}",
type_importer.import_pdb_type_into_ghidra(type_name),
api.getCurrentProgram(),
)
for (index, type_name) in enumerate(self.signature.arglist)
]
def matches_ghidra_function(self, ghidra_function: Function) -> bool:
"""Checks whether this function declaration already matches the description in Ghidra"""
name_match = self.name == ghidra_function.getName(False)
namespace_match = self.namespace == ghidra_function.getParentNamespace()
ghidra_return_type = ghidra_function.getReturnType()
return_type_match = self.return_type == ghidra_return_type
# Handle edge case: Return type X that is larger than the return register.
# In that case, the function returns `X*` and has another argument `X* __return_storage_ptr__`.
if (
(not return_type_match)
and (self.return_type.getLength() > 4)
and (
get_or_add_pointer_type(self.api, self.return_type)
== ghidra_return_type
)
and any(
param
for param in ghidra_function.getParameters()
if param.getName() == "__return_storage_ptr__"
)
):
logger.debug(
"%s has a return type larger than 4 bytes", self.get_full_name()
)
return_type_match = True
# match arguments: decide if thiscall or not, and whether the `this` type matches
calling_convention_match = (
self.signature.call_type == ghidra_function.getCallingConventionName()
)
ghidra_params_without_this = list(ghidra_function.getParameters())
if calling_convention_match and self.signature.call_type == "__thiscall":
this_argument = ghidra_params_without_this.pop(0)
calling_convention_match = self._this_type_match(this_argument)
if self.is_stub:
# We do not import the argument list for stubs, so it should be excluded in matches
args_match = True
elif calling_convention_match:
args_match = self._parameter_lists_match(ghidra_params_without_this)
else:
args_match = False
logger.debug(
"Matches: namespace=%s name=%s return_type=%s calling_convention=%s args=%s",
namespace_match,
name_match,
return_type_match,
calling_convention_match,
"ignored" if self.is_stub else args_match,
)
return (
name_match
and namespace_match
and return_type_match
and calling_convention_match
and args_match
)
def _this_type_match(self, this_parameter: Parameter) -> bool:
if this_parameter.getName() != "this":
logger.info("Expected first argument to be `this` in __thiscall")
return False
if self.signature.this_adjust != 0:
# In this case, the `this` argument should be custom defined
if not isinstance(this_parameter.getDataType(), TypeDef):
logger.info(
"`this` argument is not a typedef while `this adjust` = %d",
self.signature.this_adjust,
)
return False
# We are not checking for the _correct_ `this` type here, which we could do in the future
return True
def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool:
# Remove return storage pointer from comparison if present.
# This is relevant to returning values larger than 4 bytes, and is not mentioned in the PDB
ghidra_params = [
param
for param in ghidra_params
if param.getName() != "__return_storage_ptr__"
]
if len(self.arguments) != len(ghidra_params):
logger.info("Mismatching argument count")
return False
for this_arg, ghidra_arg in zip(self.arguments, ghidra_params):
# compare argument types
if this_arg.getDataType() != ghidra_arg.getDataType():
logger.debug(
"Mismatching arg type: expected %s, found %s",
this_arg.getDataType(),
ghidra_arg.getDataType(),
)
return False
# compare argument names
stack_match = self.get_matching_stack_symbol(ghidra_arg.getStackOffset())
if stack_match is None:
logger.debug("Not found on stack: %s", ghidra_arg)
return False
if stack_match.name.startswith("__formal"):
# "__formal" is the placeholder for arguments without a name
continue
if stack_match.name == "__$ReturnUdt":
# These appear in templates and cannot be set automatically, as they are a NOTYPE
continue
if stack_match.name != ghidra_arg.getName():
logger.debug(
"Argument name mismatch: expected %s, found %s",
stack_match.name,
ghidra_arg.getName(),
)
return False
return True
def overwrite_ghidra_function(self, ghidra_function: Function):
"""Replace the function declaration in Ghidra by the one derived from C++."""
if ghidra_function.hasCustomVariableStorage():
# Unfortunately, calling `ghidra_function.setCustomVariableStorage(False)`
# leads to two `this` parameters. Therefore, we first need to remove all `this` parameters
# and then re-generate a new one
ghidra_function.replaceParameters(
Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS, # this implicitly sets custom variable storage to False
True,
SourceType.USER_DEFINED,
[
param
for param in ghidra_function.getParameters()
if param.getName() != "this"
],
)
if ghidra_function.hasCustomVariableStorage():
raise Lego1Exception("Failed to disable custom variable storage.")
ghidra_function.setName(self.name, SourceType.USER_DEFINED)
ghidra_function.setParentNamespace(self.namespace)
ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED)
ghidra_function.setCallingConvention(self.signature.call_type)
if self.is_stub:
logger.debug(
"%s is a stub, skipping parameter import", self.get_full_name()
)
else:
ghidra_function.replaceParameters(
Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS,
True, # force
SourceType.USER_DEFINED,
self.arguments,
)
self._import_parameter_names(ghidra_function)
# Special handling for `this adjust` and virtual inheritance
if self.signature.this_adjust != 0:
self._set_this_adjust(ghidra_function)
def _import_parameter_names(self, ghidra_function: Function):
# When we call `ghidra_function.replaceParameters`, Ghidra will generate the layout.
# Now we read the parameters again and match them against the stack layout in the PDB,
# both to verify the layout and to set the parameter names.
ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names
for index, param in enumerate(ghidra_parameters):
if param.isStackVariable():
self._rename_stack_parameter(index, param)
else:
if param.getName() == "this":
# 'this' parameters are auto-generated and cannot be changed
continue
# Appears to never happen - could in theory be relevant to __fastcall__ functions,
# which we haven't seen yet
logger.warning(
"Unhandled register variable in %s", self.get_full_name()
)
continue
def _rename_stack_parameter(self, index: int, param: Parameter):
match = self.get_matching_stack_symbol(param.getStackOffset())
if match is None:
raise StackOffsetMismatchError(
f"Could not find a matching symbol at offset {param.getStackOffset()} in {self.get_full_name()}"
)
if match.data_type == "T_NOTYPE(0000)":
logger.warning("Skipping stack parameter of type NOTYPE")
return
if param.getDataType() != self.type_importer.import_pdb_type_into_ghidra(
match.data_type
):
logger.error(
"Type mismatch for parameter: %s in Ghidra, %s in PDB", param, match
)
return
name = match.name
if name == "__formal":
# these can cause name collisions if multiple ones are present
name = f"__formal_{index}"
param.setName(name, SourceType.USER_DEFINED)
def get_matching_stack_symbol(self, stack_offset: int) -> Optional[CppStackSymbol]:
return next(
(
symbol
for symbol in self.signature.stack_symbols
if isinstance(symbol, CppStackSymbol)
and symbol.stack_offset == stack_offset
),
None,
)
def get_matching_register_symbol(
self, register: str
) -> Optional[CppRegisterSymbol]:
return next(
(
symbol
for symbol in self.signature.stack_symbols
if isinstance(symbol, CppRegisterSymbol) and symbol.register == register
),
None,
)
def _set_this_adjust(
self,
ghidra_function: Function,
):
"""
When `this adjust` is non-zero, the pointer type of `this` needs to be replaced by an offset version.
The offset can only be set on a typedef on the pointer. We also must enable custom storage so we can modify
the auto-generated `this` parameter.
"""
# Necessary in order to overwrite the auto-generated `this`
ghidra_function.setCustomVariableStorage(True)
this_parameter = next(
(
param
for param in ghidra_function.getParameters()
if param.isRegisterVariable() and param.getName() == "this"
),
None,
)
if this_parameter is None:
logger.error(
"Failed to find `this` parameter in a function with `this adjust = %d`",
self.signature.this_adjust,
)
else:
current_ghidra_type = this_parameter.getDataType()
assert isinstance(current_ghidra_type, Pointer)
class_name = current_ghidra_type.getDataType().getName()
typedef_name = f"{class_name}PtrOffset0x{self.signature.this_adjust:x}"
typedef_ghidra_type = TypedefDataType(
current_ghidra_type.getCategoryPath(),
typedef_name,
current_ghidra_type,
)
ComponentOffsetSettingsDefinition.DEF.setValue(
typedef_ghidra_type.getDefaultSettings(), self.signature.this_adjust
)
typedef_ghidra_type = add_data_type_or_reuse_existing(
self.api, typedef_ghidra_type
)
this_parameter.setDataType(typedef_ghidra_type, SourceType.USER_DEFINED)
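For context on the `__return_storage_ptr__` edge case handled above: on 32-bit x86, a return value wider than a register is returned through a hidden pointer argument, so in C++ terms (illustrative type):

```cpp
struct MxRect32 { int left, top, right, bottom; };  // 16 bytes: does not fit in EAX

// What the compiler effectively emits for `MxRect32 GetRect();`
MxRect32* GetRect(MxRect32* __return_storage_ptr__);
```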

View file

@ -1,129 +0,0 @@
"""A collection of helper functions for the interaction with Ghidra."""
import logging
import re
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundInGhidraError,
MultipleTypesFoundInGhidraError,
)
from lego_util.globals import GLOBALS, SupportedModules
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import DataType, DataTypeConflictHandler, PointerDataType
from ghidra.program.model.symbol import Namespace
logger = logging.getLogger(__name__)
def get_ghidra_type(api: FlatProgramAPI, type_name: str):
"""
Searches for the type named `type_name` in Ghidra.
Raises:
- TypeNotFoundInGhidraError
- MultipleTypesFoundInGhidraError
"""
result = api.getDataTypes(type_name)
if len(result) == 0:
raise TypeNotFoundInGhidraError(type_name)
if len(result) == 1:
return result[0]
raise MultipleTypesFoundInGhidraError(type_name, result)
def get_or_add_pointer_type(api: FlatProgramAPI, pointee: DataType) -> DataType:
new_pointer_data_type = PointerDataType(pointee)
new_pointer_data_type.setCategoryPath(pointee.getCategoryPath())
return add_data_type_or_reuse_existing(api, new_pointer_data_type)
def add_data_type_or_reuse_existing(
api: FlatProgramAPI, new_data_type: DataType
) -> DataType:
result_data_type = (
api.getCurrentProgram()
.getDataTypeManager()
.addDataType(new_data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
if result_data_type is not new_data_type:
logger.debug(
"Reusing existing data type instead of new one: %s (class: %s)",
result_data_type,
result_data_type.__class__,
)
return result_data_type
def get_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
raise ClassOrNamespaceNotFoundInGhidraError(namespace_hierachy)
return namespace
def create_ghidra_namespace(
api: FlatProgramAPI, namespace_hierachy: list[str]
) -> Namespace:
namespace = api.getCurrentProgram().getGlobalNamespace()
for part in namespace_hierachy:
namespace = api.getNamespace(namespace, part)
if namespace is None:
namespace = api.createNamespace(namespace, part)
return namespace
# These appear in debug builds
THUNK_OF_RE = re.compile(r"^Thunk of '(.*)'$")
def sanitize_name(name: str) -> str:
"""
Takes a full class or function name and replaces characters not accepted by Ghidra.
Applies mostly to templates, names like `vbase destructor`, and thunks in debug build.
"""
if (match := THUNK_OF_RE.fullmatch(name)) is not None:
is_thunk = True
name = match.group(1)
else:
is_thunk = False
# Replace characters forbidden in Ghidra
new_name = (
name.replace("<", "[")
.replace(">", "]")
.replace("*", "#")
.replace(" ", "_")
.replace("`", "'")
)
# Importing function names like `FUN_10001234` into BETA10 can be confusing
# because Ghidra's auto-generated functions look exactly the same.
# Therefore, such function names are replaced by `LEGO1_10001234` in the BETA10 import.
if GLOBALS.module == SupportedModules.BETA10:
new_name = re.sub(r"FUN_([0-9a-f]{8})", r"LEGO1_\1", new_name)
if "<" in name:
new_name = "_template_" + new_name
if is_thunk:
split = new_name.split("::")
split[-1] = "_thunk_" + split[-1]
new_name = "::".join(split)
if new_name != name:
logger.info(
"Changed class or function name from '%s' to '%s' to avoid Ghidra issues",
name,
new_name,
)
return new_name
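Tracing the rules above by hand (identifiers are illustrative isle-style names):

```python
sanitize_name("MxList<LegoPathActor *>")
# -> "_template_MxList[LegoPathActor_#]"  (angle brackets, '*' and spaces replaced)
sanitize_name("Thunk of 'LegoAnimActor::ClassName'")
# -> "LegoAnimActor::_thunk_ClassName"    (thunk marker moved onto the last component)
```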

View file

@ -1,42 +0,0 @@
import logging
from enum import Enum
from dataclasses import dataclass, field
from lego_util.statistics import Statistics
class SupportedModules(Enum):
LEGO1 = 1
BETA10 = 2
def orig_filename(self):
if self == self.LEGO1:
return "LEGO1.DLL"
return "BETA10.DLL"
def recomp_filename_without_extension(self):
# in case we want to support more functions
return "LEGO1"
def build_dir_name(self):
if self == self.BETA10:
return "build_debug"
return "build"
@dataclass
class Globals:
verbose: bool
loglevel: int
module: SupportedModules
running_from_ghidra: bool = False
# statistics
statistics: Statistics = field(default_factory=Statistics)
# hard-coded settings that we don't want to prompt in Ghidra every time
GLOBALS = Globals(
verbose=False,
# loglevel=logging.INFO,
loglevel=logging.DEBUG,
module=SupportedModules.LEGO1, # this default value will be used when run outside of Ghidra
)

View file

@ -1,20 +0,0 @@
from typing import TypeVar, Any
import ghidra
# pylint: disable=invalid-name,unused-argument
T = TypeVar("T")
# from ghidra.app.script.GhidraScript
def currentProgram() -> "ghidra.program.model.listing.Program": ...
def getAddressFactory() -> "ghidra.program.model.address.AddressFactory": ...
def state() -> "ghidra.app.script.GhidraState": ...
def askChoice(title: str, message: str, choices: list[T], defaultValue: T) -> T: ...
def askYesNo(title: str, question: str) -> bool: ...
def getFunctionAt(
entryPoint: ghidra.program.model.address.Address,
) -> ghidra.program.model.listing.Function: ...
def createFunction(
entryPoint: ghidra.program.model.address.Address, name: str
) -> ghidra.program.model.listing.Function: ...
def getProgramFile() -> Any: ... # actually java.io.File

View file

@ -1,183 +0,0 @@
from dataclasses import dataclass
import re
from typing import Any, Optional
import logging
from isledecomp.bin import InvalidVirtualAddressError
from isledecomp.cvdump.symbols import SymbolsEntry
from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.db import MatchInfo
logger = logging.getLogger(__file__)
@dataclass
class CppStackOrRegisterSymbol:
name: str
data_type: str
@dataclass
class CppStackSymbol(CppStackOrRegisterSymbol):
stack_offset: int
"""Should have a value iff `symbol_type=='S_BPREL32'."""
@dataclass
class CppRegisterSymbol(CppStackOrRegisterSymbol):
register: str
"""Should have a value iff `symbol_type=='S_REGISTER'.` Should always be set/converted to lowercase."""
@dataclass
class FunctionSignature:
original_function_symbol: SymbolsEntry
call_type: str
arglist: list[str]
return_type: str
class_type: Optional[str]
stack_symbols: list[CppStackOrRegisterSymbol]
# if non-zero: an offset to the `this` parameter in a __thiscall
this_adjust: int
@dataclass
class PdbFunction:
match_info: MatchInfo
signature: Optional[FunctionSignature]
is_stub: bool
class PdbFunctionExtractor:
"""
Extracts all information on a given function from the parsed PDB
and prepares the data for the import in Ghidra.
"""
def __init__(self, compare: IsleCompare):
self.compare = compare
scalar_type_regex = re.compile(r"t_(?P<typename>\w+)(?:\((?P<type_id>\d+)\))?")
_call_type_map = {
"ThisCall": "__thiscall",
"C Near": "default",
"STD Near": "__stdcall",
}
def _get_cvdump_type(self, type_name: Optional[str]) -> Optional[dict[str, Any]]:
return (
None
if type_name is None
else self.compare.cv.types.keys.get(type_name.lower())
)
def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
function_type_str = fn.func_type
if function_type_str == "T_NOTYPE(0000)":
logger.debug("Treating NOTYPE function as thunk: %s", fn.name)
return None
# get corresponding function type
function_type = self.compare.cv.types.keys.get(function_type_str.lower())
if function_type is None:
logger.error(
"Could not find function type %s for function %s", fn.func_type, fn.name
)
return None
class_type = function_type.get("class_type")
arg_list_type = self._get_cvdump_type(function_type.get("arg_list_type"))
assert arg_list_type is not None
arg_list_pdb_types = arg_list_type.get("args", [])
assert arg_list_type["argcount"] == len(arg_list_pdb_types)
stack_symbols: list[CppStackOrRegisterSymbol] = []
# for some unexplained reason, the reported stack is offset by 4 when this flag is set.
# Note that this affects the arguments (ebp + ...) but not the function stack (ebp - ...)
stack_offset_delta = -4 if fn.frame_pointer_present else 0
for symbol in fn.stack_symbols:
if symbol.symbol_type == "S_REGISTER":
stack_symbols.append(
CppRegisterSymbol(
symbol.name,
symbol.data_type,
symbol.location,
)
)
elif symbol.symbol_type == "S_BPREL32":
stack_offset = int(symbol.location[1:-1], 16)
stack_symbols.append(
CppStackSymbol(
symbol.name,
symbol.data_type,
stack_offset + stack_offset_delta,
)
)
call_type = self._call_type_map[function_type["call_type"]]
# parse as hex number, default to 0
this_adjust = int(function_type.get("this_adjust", "0"), 16)
return FunctionSignature(
original_function_symbol=fn,
call_type=call_type,
arglist=arg_list_pdb_types,
return_type=function_type["return_type"],
class_type=class_type,
stack_symbols=stack_symbols,
this_adjust=this_adjust,
)
def get_function_list(self) -> list[PdbFunction]:
handled = (
self.handle_matched_function(match)
for match in self.compare.get_functions()
)
return [signature for signature in handled if signature is not None]
def handle_matched_function(self, match_info: MatchInfo) -> Optional[PdbFunction]:
assert match_info.orig_addr is not None
match_options = self.compare.get_match_options(match_info.orig_addr)
assert match_options is not None
function_data = next(
(
y
for y in self.compare.cvdump_analysis.nodes
if y.addr == match_info.recomp_addr
),
None,
)
if function_data is None:
try:
# this can be either a thunk (which we want) or an external function
# (which we don't want), so we tell them apart based on the validity of their address.
self.compare.orig_bin.get_relative_addr(match_info.orig_addr)
return PdbFunction(match_info, None, False)
except InvalidVirtualAddressError:
logger.debug(
"Skipping external function %s (address 0x%x not in original binary)",
match_info.name,
match_info.orig_addr,
)
return None
function_symbol = function_data.symbol_entry
if function_symbol is None:
logger.debug(
"Could not find function symbol (likely a PUBLICS entry): %s",
match_info.name,
)
return None
function_signature = self.get_func_signature(function_symbol)
is_stub = match_options.get("stub", False)
return PdbFunction(match_info, function_signature, is_stub)
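# Minimal driver sketch (hypothetical; assumes `compare` is a fully prepared
# IsleCompare instance):
#   extractor = PdbFunctionExtractor(compare)
#   for func in extractor.get_function_list():
#       print(hex(func.match_info.orig_addr), func.is_stub)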

View file

@ -1,68 +0,0 @@
from dataclasses import dataclass, field
import logging
from lego_util.exceptions import (
TypeNotFoundInGhidraError,
ClassOrNamespaceNotFoundInGhidraError,
)
logger = logging.getLogger(__name__)
@dataclass
class Statistics:
functions_changed: int = 0
successes: int = 0
failures: dict[str, int] = field(default_factory=dict)
known_missing_types: dict[str, int] = field(default_factory=dict)
known_missing_namespaces: dict[str, int] = field(default_factory=dict)
def track_failure_and_tell_if_new(self, error: Exception) -> bool:
"""
Adds the error to the statistics. Returns `False` if logging the error would be redundant
(e.g. because it is a `TypeNotFoundInGhidraError` with a type that has been logged before).
"""
error_type_name = error.__class__.__name__
self.failures[error_type_name] = (
self.failures.setdefault(error_type_name, 0) + 1
)
if isinstance(error, TypeNotFoundInGhidraError):
return self._add_occurrence_and_check_if_new(
self.known_missing_types, error.args[0]
)
if isinstance(error, ClassOrNamespaceNotFoundInGhidraError):
return self._add_occurrence_and_check_if_new(
self.known_missing_namespaces, error.get_namespace_str()
)
# We do not have detailed tracking for other errors, so we want to log them every time
return True
def _add_occurrence_and_check_if_new(self, target: dict[str, int], key: str) -> bool:
old_count = target.setdefault(key, 0)
target[key] = old_count + 1
return old_count == 0
def log(self):
logger.info("Statistics:\n~~~~~")
logger.info(
"Missing types (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_types),
)
logger.info(
"Missing classes/namespaces (with number of occurences): %s\n~~~~~",
self.format_statistics(self.known_missing_namespaces),
)
logger.info("Successes: %d", self.successes)
logger.info("Failures: %s", self.failures)
logger.info("Functions changed: %d", self.functions_changed)
def format_statistics(self, stats: dict[str, int]) -> str:
if len(stats) == 0:
return "<none>"
return ", ".join(
f"{entry[0]} ({entry[1]})"
for entry in sorted(stats.items(), key=lambda x: x[1], reverse=True)
)
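# A minimal smoke test (not part of the original file; assumes the lego_util
# package is importable).
if __name__ == "__main__":
    stats = Statistics()
    # Generic errors are not deduplicated, so both calls return True.
    print(stats.track_failure_and_tell_if_new(ValueError("example")))
    print(stats.track_failure_and_tell_if_new(ValueError("example")))
    stats.successes += 1
    stats.log()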

View file

@ -1,541 +0,0 @@
import logging
from typing import Any, Callable, Iterator, Optional, TypeVar
# Disable spurious warnings in vscode / pylance
# pyright: reportMissingModuleSource=false
# pylint: disable=too-many-return-statements # a `match` would be better, but for now we are stuck with Python 3.9
# pylint: disable=no-else-return # Not sure why this rule even is a thing, this is great for checking exhaustiveness
from isledecomp.cvdump.types import VirtualBasePointer
from lego_util.exceptions import (
ClassOrNamespaceNotFoundInGhidraError,
TypeNotFoundError,
TypeNotFoundInGhidraError,
TypeNotImplementedError,
StructModificationError,
)
from lego_util.ghidra_helper import (
add_data_type_or_reuse_existing,
get_or_add_pointer_type,
create_ghidra_namespace,
get_ghidra_namespace,
get_ghidra_type,
sanitize_name,
)
from lego_util.pdb_extraction import PdbFunctionExtractor
from ghidra.program.flatapi import FlatProgramAPI
from ghidra.program.model.data import (
ArrayDataType,
CategoryPath,
DataType,
DataTypeConflictHandler,
Enum,
EnumDataType,
StructureDataType,
StructureInternal,
TypedefDataType,
ComponentOffsetSettingsDefinition,
)
from ghidra.util.task import ConsoleTaskMonitor
logger = logging.getLogger(__name__)
class PdbTypeImporter:
"""Allows PDB types to be imported into Ghidra."""
def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor):
self.api = api
self.extraction = extraction
# tracks the structs/classes we have already started to import, otherwise we run into infinite recursion
self.handled_structs: set[str] = set()
# tracks the enums we have already handled for the sake of efficiency
self.handled_enums: dict[str, Enum] = {}
@property
def types(self):
return self.extraction.compare.cv.types
def import_pdb_type_into_ghidra(
self, type_index: str, slim_for_vbase: bool = False
) -> DataType:
"""
Recursively imports a type from the PDB into Ghidra.
@param type_index Either a scalar type like `T_INT4(...)` or a PDB reference like `0x10ba`
@param slim_for_vbase If true, the current invocation
imports a superclass of some class where virtual inheritance is involved (directly or indirectly).
This case requires special handling: Let's say we have `class C: B` and `class B: virtual A`. Then cvdump
reports a size for B that includes both B's fields as well as the A contained at an offset within B,
which is not the correct structure to be contained in C. Therefore, we need to create a "slim" version of B
that fits inside C.
This value should always be `False` when the referenced type is not (a pointer to) a class.
"""
type_index_lower = type_index.lower()
if type_index_lower.startswith("t_"):
return self._import_scalar_type(type_index_lower)
try:
type_pdb = self.extraction.compare.cv.types.keys[type_index_lower]
except KeyError as e:
raise TypeNotFoundError(
f"Failed to find referenced type '{type_index_lower}'"
) from e
type_category = type_pdb["type"]
# follow forward reference (class, struct, union)
if type_pdb.get("is_forward_ref", False):
return self._import_forward_ref_type(
type_index_lower, type_pdb, slim_for_vbase
)
if type_category == "LF_POINTER":
return get_or_add_pointer_type(
self.api,
self.import_pdb_type_into_ghidra(
type_pdb["element_type"], slim_for_vbase
),
)
elif type_category in ["LF_CLASS", "LF_STRUCTURE"]:
return self._import_class_or_struct(type_pdb, slim_for_vbase)
elif type_category == "LF_ARRAY":
return self._import_array(type_pdb)
elif type_category == "LF_ENUM":
return self._import_enum(type_pdb)
elif type_category == "LF_PROCEDURE":
logger.warning(
"Not implemented: Function-valued argument or return type will be replaced by void pointer: %s",
type_pdb,
)
return get_ghidra_type(self.api, "void")
elif type_category == "LF_UNION":
return self._import_union(type_pdb)
else:
raise TypeNotImplementedError(type_pdb)
_scalar_type_map = {
"rchar": "char",
"int4": "int",
"uint4": "uint",
"real32": "float",
"real64": "double",
}
def _scalar_type_to_cpp(self, scalar_type: str) -> str:
if scalar_type.startswith("32p"):
return f"{self._scalar_type_to_cpp(scalar_type[3:])} *"
return self._scalar_type_map.get(scalar_type, scalar_type)
def _import_scalar_type(self, type_index_lower: str) -> DataType:
if (match := self.extraction.scalar_type_regex.match(type_index_lower)) is None:
raise TypeNotFoundError(f"Type has unexpected format: {type_index_lower}")
scalar_cpp_type = self._scalar_type_to_cpp(match.group("typename"))
return get_ghidra_type(self.api, scalar_cpp_type)
def _import_forward_ref_type(
self,
type_index,
type_pdb: dict[str, Any],
slim_for_vbase: bool = False,
) -> DataType:
referenced_type = type_pdb.get("udt") or type_pdb.get("modifies")
if referenced_type is None:
try:
# Example: HWND__, needs to be created manually
return get_ghidra_type(self.api, type_pdb["name"])
except TypeNotFoundInGhidraError as e:
raise TypeNotImplementedError(
f"{type_index}: forward ref without target, needs to be created manually: {type_pdb}"
) from e
logger.debug(
"Following forward reference from %s to %s",
type_index,
referenced_type,
)
return self.import_pdb_type_into_ghidra(referenced_type, slim_for_vbase)
def _import_array(self, type_pdb: dict[str, Any]) -> DataType:
inner_type = self.import_pdb_type_into_ghidra(type_pdb["array_type"])
array_total_bytes: int = type_pdb["size"]
data_type_size = inner_type.getLength()
array_length, modulus = divmod(array_total_bytes, data_type_size)
assert (
modulus == 0
), f"Data type size {data_type_size} does not divide array size {array_total_bytes}"
return ArrayDataType(inner_type, array_length, 0)
def _import_union(self, type_pdb: dict[str, Any]) -> DataType:
try:
logger.debug("Dereferencing union %s", type_pdb)
union_type = get_ghidra_type(self.api, type_pdb["name"])
assert (
union_type.getLength() == type_pdb["size"]
), f"Wrong size of existing union type '{type_pdb['name']}': expected {type_pdb['size']}, got {union_type.getLength()}"
return union_type
except TypeNotFoundInGhidraError as e:
# We have so few instances, it is not worth implementing this
raise TypeNotImplementedError(
f"Writing union types is not supported. Please add by hand: {type_pdb}"
) from e
def _import_enum(self, type_pdb: dict[str, Any]) -> DataType:
underlying_type = self.import_pdb_type_into_ghidra(type_pdb["underlying_type"])
field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"])
assert field_list is not None, f"Failed to find field list for enum {type_pdb}"
result = self._get_or_create_enum_data_type(
type_pdb["name"], underlying_type.getLength()
)
# clear existing variants, if there are any
for existing_variant in result.getNames():
result.remove(existing_variant)
variants: list[dict[str, Any]] = field_list["variants"]
for variant in variants:
result.add(variant["name"], variant["value"])
return result
def _import_class_or_struct(
self,
type_in_pdb: dict[str, Any],
slim_for_vbase: bool = False,
) -> DataType:
field_list_type: str = type_in_pdb["field_list_type"]
field_list = self.types.keys[field_list_type.lower()]
class_size: int = type_in_pdb["size"]
class_name_with_namespace: str = sanitize_name(type_in_pdb["name"])
if slim_for_vbase:
class_name_with_namespace += "_vbase_slim"
if class_name_with_namespace in self.handled_structs:
logger.debug(
"Class has been handled or is being handled: %s",
class_name_with_namespace,
)
return get_ghidra_type(self.api, class_name_with_namespace)
logger.debug(
"--- Beginning to import class/struct '%s'", class_name_with_namespace
)
# Add as soon as we start to avoid infinite recursion
self.handled_structs.add(class_name_with_namespace)
self._get_or_create_namespace(class_name_with_namespace)
new_ghidra_struct = self._get_or_create_struct_data_type(
class_name_with_namespace, class_size
)
if (old_size := new_ghidra_struct.getLength()) != class_size:
logger.warning(
"Existing class %s had incorrect size %d. Setting to %d...",
class_name_with_namespace,
old_size,
class_size,
)
logger.info("Adding class data type %s", class_name_with_namespace)
logger.debug("Class information: %s", type_in_pdb)
components: list[dict[str, Any]] = []
components.extend(self._get_components_from_base_classes(field_list))
# can be missing when no new fields are declared
components.extend(self._get_components_from_members(field_list))
components.extend(
self._get_components_from_vbase(
field_list, class_name_with_namespace, new_ghidra_struct
)
)
components.sort(key=lambda c: c["offset"])
if slim_for_vbase:
# Make a "slim" version: shrink the size to the fields that are actually present.
# This makes a difference when the current class uses virtual inheritance
assert (
len(components) > 0
), f"Error: {class_name_with_namespace} should not be empty. There must be at least one direct or indirect vbase pointer."
last_component = components[-1]
class_size = last_component["offset"] + last_component["type"].getLength()
self._overwrite_struct(
class_name_with_namespace,
new_ghidra_struct,
class_size,
components,
)
logger.info("Finished importing class %s", class_name_with_namespace)
return new_ghidra_struct
def _get_components_from_base_classes(self, field_list) -> Iterator[dict[str, Any]]:
non_virtual_base_classes: dict[str, int] = field_list.get("super", {})
for super_type, offset in non_virtual_base_classes.items():
# If we have virtual inheritance _and_ a non-virtual base class here, we play it safe and import the slim version.
# This is technically not needed if only one of the superclasses uses virtual inheritance, but I am not aware of any such instance.
import_slim_vbase_version_of_superclass = "vbase" in field_list
ghidra_type = self.import_pdb_type_into_ghidra(
super_type, slim_for_vbase=import_slim_vbase_version_of_superclass
)
yield {
"type": ghidra_type,
"offset": offset,
"name": "base" if offset == 0 else f"base_{ghidra_type.getName()}",
}
def _get_components_from_members(self, field_list: dict[str, Any]):
members: list[dict[str, Any]] = field_list.get("members") or []
for member in members:
yield member | {"type": self.import_pdb_type_into_ghidra(member["type"])}
def _get_components_from_vbase(
self,
field_list: dict[str, Any],
class_name_with_namespace: str,
current_type: StructureInternal,
) -> Iterator[dict[str, Any]]:
vbasepointer: Optional[VirtualBasePointer] = field_list.get("vbase", None)
if vbasepointer is not None and any(x.direct for x in vbasepointer.bases):
vbaseptr_type = get_or_add_pointer_type(
self.api,
self._import_vbaseptr(
current_type, class_name_with_namespace, vbasepointer
),
)
yield {
"type": vbaseptr_type,
"offset": vbasepointer.vboffset,
"name": "vbase_offset",
}
def _import_vbaseptr(
self,
current_type: StructureInternal,
class_name_with_namespace: str,
vbasepointer: VirtualBasePointer,
) -> StructureInternal:
pointer_size = 4 # hard-coded to 4 because this is a 32-bit binary
components = [
{
"offset": 0,
"type": get_or_add_pointer_type(self.api, current_type),
"name": "o_self",
}
]
for vbase in vbasepointer.bases:
vbase_ghidra_type = self.import_pdb_type_into_ghidra(vbase.type)
type_name = vbase_ghidra_type.getName()
vbase_ghidra_pointer = get_or_add_pointer_type(self.api, vbase_ghidra_type)
vbase_ghidra_pointer_typedef = TypedefDataType(
vbase_ghidra_pointer.getCategoryPath(),
f"{type_name}PtrOffset",
vbase_ghidra_pointer,
)
# Set a default value of -4 for the pointer offset. While this appears to be correct in many cases,
# it does not always lead to the best decompile. It can be fine-tuned by hand; the next function call
# makes sure that we don't overwrite this value on re-running the import.
ComponentOffsetSettingsDefinition.DEF.setValue(
vbase_ghidra_pointer_typedef.getDefaultSettings(), -4
)
vbase_ghidra_pointer_typedef = add_data_type_or_reuse_existing(
self.api, vbase_ghidra_pointer_typedef
)
components.append(
{
"offset": vbase.index * pointer_size,
"type": vbase_ghidra_pointer_typedef,
"name": f"o_{type_name}",
}
)
size = len(components) * pointer_size
new_ghidra_struct = self._get_or_create_struct_data_type(
f"{class_name_with_namespace}::VBasePtr", size
)
self._overwrite_struct(
f"{class_name_with_namespace}::VBasePtr",
new_ghidra_struct,
size,
components,
)
return new_ghidra_struct
def _overwrite_struct(
self,
class_name_with_namespace: str,
new_ghidra_struct: StructureInternal,
class_size: int,
components: list[dict[str, Any]],
):
new_ghidra_struct.deleteAll()
new_ghidra_struct.growStructure(class_size)
# this case happened e.g. for IUnknown, which linked to an (incorrect) existing library, and some other types as well.
# Unfortunately, we don't get proper error handling for read-only types.
# However, we really do NOT want to do this every time because the type might be self-referential and partially imported.
if new_ghidra_struct.getLength() != class_size:
new_ghidra_struct = self._delete_and_recreate_struct_data_type(
class_name_with_namespace, class_size, new_ghidra_struct
)
for component in components:
offset: int = component["offset"]
logger.debug(
"Adding component %s to class: %s", component, class_name_with_namespace
)
try:
# Make sure there is room for the new structure and that we have no collision.
existing_type = new_ghidra_struct.getComponentAt(offset)
assert (
existing_type is not None
), f"Struct collision: Offset {offset} in {class_name_with_namespace} is overlapped by another component"
if existing_type.getDataType().getName() != "undefined":
# collision of structs beginning in the same place -> likely due to unions
logger.warning(
"Struct collision: Offset %d of %s already has a field (likely an inline union)",
offset,
class_name_with_namespace,
)
new_ghidra_struct.replaceAtOffset(
offset,
component["type"],
-1, # set to -1 for fixed-size components
component["name"], # name
None, # comment
)
except Exception as e:
raise StructModificationError(class_name_with_namespace) from e
def _get_or_create_namespace(self, class_name_with_namespace: str):
colon_split = class_name_with_namespace.split("::")
try:
get_ghidra_namespace(self.api, colon_split)
logger.debug("Found existing class/namespace %s", class_name_with_namespace)
except ClassOrNamespaceNotFoundInGhidraError:
logger.info("Creating class/namespace %s", class_name_with_namespace)
class_name = colon_split.pop()
parent_namespace = create_ghidra_namespace(self.api, colon_split)
self.api.createClass(parent_namespace, class_name)
def _get_or_create_enum_data_type(
self, enum_type_name: str, enum_type_size: int
) -> Enum:
if (known_enum := self.handled_enums.get(enum_type_name, None)) is not None:
return known_enum
result = self._get_or_create_data_type(
enum_type_name,
"enum",
Enum,
lambda: EnumDataType(
CategoryPath("/imported"), enum_type_name, enum_type_size
),
)
self.handled_enums[enum_type_name] = result
return result
def _get_or_create_struct_data_type(
self, class_name_with_namespace: str, class_size: int
) -> StructureInternal:
return self._get_or_create_data_type(
class_name_with_namespace,
"class/struct",
StructureInternal,
lambda: StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
),
)
T = TypeVar("T", bound=DataType)
def _get_or_create_data_type(
self,
type_name: str,
readable_name_of_type_category: str,
expected_type: type[T],
new_instance_callback: Callable[[], T],
) -> T:
"""
Checks if a data type provided under the given name exists in Ghidra.
Creates one using `new_instance_callback` if there is not.
Also verifies the data type.
Note that the return value of `addDataType()` is not the same instance as the input
even if there is no name collision.
"""
try:
data_type = get_ghidra_type(self.api, type_name)
logger.debug(
"Found existing %s type %s under category path %s",
readable_name_of_type_category,
type_name,
data_type.getCategoryPath(),
)
except TypeNotFoundInGhidraError:
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(
new_instance_callback(), DataTypeConflictHandler.KEEP_HANDLER
)
)
logger.info(
"Created new %s data type %s", readable_name_of_type_category, type_name
)
assert isinstance(
data_type, expected_type
), f"Found existing type named {type_name} that is not a {readable_name_of_type_category}"
return data_type
def _delete_and_recreate_struct_data_type(
self,
class_name_with_namespace: str,
class_size: int,
existing_data_type: DataType,
) -> StructureInternal:
logger.warning(
"Failed to modify data type %s. Will try to delete the existing one and re-create the imported one.",
class_name_with_namespace,
)
assert (
self.api.getCurrentProgram()
.getDataTypeManager()
.remove(existing_data_type, ConsoleTaskMonitor())
), f"Failed to delete and re-create data type {class_name_with_namespace}"
data_type = StructureDataType(
CategoryPath("/imported"), class_name_with_namespace, class_size
)
data_type = (
self.api.getCurrentProgram()
.getDataTypeManager()
.addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER)
)
assert isinstance(data_type, StructureInternal) # for type checking
return data_type
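# Usage sketch (hypothetical; only meaningful inside a Ghidra script where
# `api` and `extraction` already exist):
#   importer = PdbTypeImporter(api, extraction)
#   data_type = importer.import_pdb_type_into_ghidra("0x10ba")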

View file

@ -1,2 +0,0 @@
isledecomp.egg-info/
build

View file

@ -1,4 +0,0 @@
from .bin import *
from .dir import *
from .parser import *
from .utils import *

View file

@ -1,574 +0,0 @@
import logging
import struct
import bisect
from functools import cached_property
from typing import Iterator, List, Optional, Tuple
from dataclasses import dataclass
from collections import namedtuple
class MZHeaderNotFoundError(Exception):
"""MZ magic string not found at the start of the binary."""
class PEHeaderNotFoundError(Exception):
"""PE magic string not found at the offset given in 0x3c."""
class SectionNotFoundError(KeyError):
"""The specified section was not found in the file."""
class InvalidVirtualAddressError(IndexError):
"""The given virtual address is too high or low
to point to something in the binary file."""
PEHeader = namedtuple(
"PEHeader",
[
"Signature",
"Machine",
"NumberOfSections",
"TimeDateStamp",
"PointerToSymbolTable", # deprecated
"NumberOfSymbols", # deprecated
"SizeOfOptionalHeader",
"Characteristics",
],
)
ImageSectionHeader = namedtuple(
"ImageSectionHeader",
[
"name",
"virtual_size",
"virtual_address",
"size_of_raw_data",
"pointer_to_raw_data",
"pointer_to_relocations",
"pointer_to_line_numbers",
"number_of_relocations",
"number_of_line_numbers",
"characteristics",
],
)
@dataclass
class Section:
name: str
virtual_size: int
virtual_address: int
view: memoryview
@cached_property
def size_of_raw_data(self) -> int:
return len(self.view)
@cached_property
def extent(self):
"""Get the highest possible offset of this section"""
return max(self.size_of_raw_data, self.virtual_size)
def match_name(self, name: str) -> bool:
return self.name == name
def contains_vaddr(self, vaddr: int) -> bool:
return self.virtual_address <= vaddr < self.virtual_address + self.extent
def read_virtual(self, vaddr: int, size: int) -> memoryview:
ofs = vaddr - self.virtual_address
# Negative index will read from the end, which we don't want
if ofs < 0:
raise InvalidVirtualAddressError
try:
return self.view[ofs : ofs + size]
except IndexError as ex:
raise InvalidVirtualAddressError from ex
def addr_is_uninitialized(self, vaddr: int) -> bool:
"""We cannot rely on the IMAGE_SCN_CNT_UNINITIALIZED_DATA flag (0x80) in
the characteristics field so instead we determine it this way."""
if not self.contains_vaddr(vaddr):
return False
# Should include the case where size_of_raw_data == 0,
# meaning the entire section is uninitialized
return (self.virtual_size > self.size_of_raw_data) and (
vaddr - self.virtual_address >= self.size_of_raw_data
)
logger = logging.getLogger(__name__)
class Bin:
"""Parses a PE format EXE and allows reading data from a virtual address.
Reference: https://learn.microsoft.com/en-us/windows/win32/debug/pe-format"""
# pylint: disable=too-many-instance-attributes
def __init__(self, filename: str, find_str: bool = False) -> None:
logger.debug('Parsing headers of "%s"... ', filename)
self.filename = filename
self.view: Optional[memoryview] = None
self.imagebase = None
self.entry = None
self.sections: List[Section] = []
self._section_vaddr: List[int] = []
self.find_str = find_str
self._potential_strings = {}
self._relocations = set()
self._relocated_addrs = set()
self.imports = []
self.thunks = []
self.exports: List[Tuple[int, str]] = []
self.is_debug: bool = False
def __enter__(self):
logger.debug("Bin %s Enter", self.filename)
with open(self.filename, "rb") as f:
self.view = memoryview(f.read())
(mz_str,) = struct.unpack("2s", self.view[0:2])
if mz_str != b"MZ":
raise MZHeaderNotFoundError
# Skip to PE header offset in MZ header.
(pe_header_start,) = struct.unpack("<I", self.view[0x3C:0x40])
# PE header offset is absolute, so seek there
pe_header_view = self.view[pe_header_start:]
pe_hdr = PEHeader(*struct.unpack("<2s2x2H3I2H", pe_header_view[:0x18]))
if pe_hdr.Signature != b"PE":
raise PEHeaderNotFoundError
optional_hdr = pe_header_view[0x18:]
# These optional header fields are unsigned dwords.
(self.imagebase,) = struct.unpack("<I", optional_hdr[0x1C:0x20])
(entry,) = struct.unpack("<I", optional_hdr[0x10:0x14])
self.entry = entry + self.imagebase
(number_of_rva,) = struct.unpack("<I", optional_hdr[0x5C:0x60])
data_dictionaries = [
*struct.iter_unpack("<2I", optional_hdr[0x60 : 0x60 + number_of_rva * 8])
]
# Check for presence of .debug subsection in .rdata
try:
if data_dictionaries[6][0] != 0:
self.is_debug = True
except IndexError:
pass
headers_view = optional_hdr[
pe_hdr.SizeOfOptionalHeader : pe_hdr.SizeOfOptionalHeader
+ 0x28 * pe_hdr.NumberOfSections
]
section_headers = [
ImageSectionHeader(*h) for h in struct.iter_unpack("<8s6I2HI", headers_view)
]
self.sections = [
Section(
name=hdr.name.decode("ascii").rstrip("\x00"),
virtual_address=self.imagebase + hdr.virtual_address,
virtual_size=hdr.virtual_size,
view=self.view[
hdr.pointer_to_raw_data : hdr.pointer_to_raw_data
+ hdr.size_of_raw_data
],
)
for hdr in section_headers
]
# bisect does not support the `key` argument on the GitHub CI version of Python
self._section_vaddr = [section.virtual_address for section in self.sections]
self._populate_relocations()
self._populate_imports()
self._populate_thunks()
# Export dir is always first
self._populate_exports(*data_dictionaries[0])
# This is a (semi) expensive lookup that is not necessary in every case.
# We can find strings in the original if we have coverage using STRING markers.
# For the recomp, we can find strings using the PDB.
if self.find_str:
self._prepare_string_search()
logger.debug("... Parsing finished")
return self
def __exit__(self, exc_type, exc_value, exc_traceback):
logger.debug("Bin %s Exit", self.filename)
self.view.release()
def get_relocated_addresses(self) -> List[int]:
return sorted(self._relocated_addrs)
def find_string(self, target: bytes) -> Optional[int]:
# Pad with null terminator to make sure we don't
# match on a subset of the full string
if not target.endswith(b"\x00"):
target += b"\x00"
c = target[0]
if c not in self._potential_strings:
return None
for addr in self._potential_strings[c]:
if target == self.read(addr, len(target)):
return addr
return None
def is_relocated_addr(self, vaddr: int) -> bool:
return vaddr in self._relocated_addrs
def _prepare_string_search(self):
"""We are intersted in deduplicated string constants found in the
.rdata and .data sections. For each relocated address in these sections,
read the first byte and save the address if that byte is an ASCII character.
When we search for an arbitrary string later, we can narrow down the list
of potential locations by a lot."""
def is_ascii(b):
return b" " <= b < b"\x7f"
sect_data = self.get_section_by_name(".data")
sect_rdata = self.get_section_by_name(".rdata")
potentials = filter(
lambda a: sect_data.contains_vaddr(a) or sect_rdata.contains_vaddr(a),
self.get_relocated_addresses(),
)
for addr in potentials:
c = self.read(addr, 1)
if c is not None and is_ascii(c):
k = ord(c)
if k not in self._potential_strings:
self._potential_strings[k] = set()
self._potential_strings[k].add(addr)
def _populate_relocations(self):
"""The relocation table in .reloc gives each virtual address where the next four
bytes are, itself, another virtual address. During loading, these values will be
patched according to the virtual address space for the image, as provided by Windows.
We can use this information to get a list of where each significant "thing"
in the file is located. Anything that is referenced absolutely (i.e. excluding
jump destinations given by local offset) will be here.
One use case is to tell whether an immediate value in an operand represents
a virtual address or just a big number."""
reloc = self.get_section_by_name(".reloc").view
ofs = 0
reloc_addrs = []
# Parse the structure in .reloc to get the list of locations to check.
# The first 8 bytes are 2 dwords that give the base page address
# and the total block size (including this header).
# The page address is used to compact the list; each entry is only
# 2 bytes, and these are added to the base to get the full location.
# If the entry read in is zero, we are at the end of this section and
# these are padding bytes.
while True:
(page_base, block_size) = struct.unpack("<2I", reloc[ofs : ofs + 8])
if block_size == 0:
break
# HACK: ignore the relocation type for now (the top 4 bits of the value).
values = list(struct.iter_unpack("<H", reloc[ofs + 8 : ofs + block_size]))
reloc_addrs += [
self.imagebase + page_base + (v[0] & 0xFFF) for v in values if v[0] != 0
]
ofs += block_size
# We are now interested in the relocated addresses themselves. Seek to the
# address where there is a relocation, then read the four bytes into our set.
reloc_addrs.sort()
self._relocations = set(reloc_addrs)
for section_id, offset in map(self.get_relative_addr, reloc_addrs):
section = self.get_section_by_index(section_id)
(relocated_addr,) = struct.unpack("<I", section.view[offset : offset + 4])
self._relocated_addrs.add(relocated_addr)
def find_float_consts(self) -> Iterator[Tuple[int, int, float]]:
"""Floating point instructions that refer to a memory address can
point to constant values. Search the code sections to find FP
instructions and check whether the pointer address refers to
read-only data."""
# TODO: Should check any section that has code, not just .text
text = self.get_section_by_name(".text")
rdata = self.get_section_by_name(".rdata")
# These are the addresses where a relocation occurs.
# Meaning: it points to an absolute address of something
for addr in self._relocations:
if not text.contains_vaddr(addr):
continue
# Read the two bytes before the relocated address.
# We will check against possible float opcodes
raw = text.read_virtual(addr - 2, 6)
(opcode, opcode_ext, const_addr) = struct.unpack("<BBL", raw)
# Skip right away if this is not const data
if not rdata.contains_vaddr(const_addr):
continue
if opcode_ext in (0x5, 0xD, 0x15, 0x1D, 0x25, 0x2D, 0x35, 0x3D):
if opcode in (0xD8, 0xD9):
# dword ptr -- single precision
(float_value,) = struct.unpack("<f", self.read(const_addr, 4))
yield (const_addr, 4, float_value)
elif opcode in (0xDC, 0xDD):
# qword ptr -- double precision
(float_value,) = struct.unpack("<d", self.read(const_addr, 8))
yield (const_addr, 8, float_value)
def _populate_imports(self):
"""Parse .idata to find imported DLLs and their functions."""
idata_ofs = self.get_section_offset_by_name(".idata")
def iter_image_import():
ofs = idata_ofs
while True:
# Read 5 dwords until all are zero.
image_import_descriptor = struct.unpack("<5I", self.read(ofs, 20))
ofs += 20
if all(x == 0 for x in image_import_descriptor):
break
(rva_ilt, _, __, dll_name, rva_iat) = image_import_descriptor
# Convert relative virtual addresses into absolute
yield (
self.imagebase + rva_ilt,
self.imagebase + dll_name,
self.imagebase + rva_iat,
)
image_import_descriptors = list(iter_image_import())
def iter_imports():
# ILT = Import Lookup Table
# IAT = Import Address Table
# ILT gives us the symbol name of the import.
# IAT gives the address. The compiler generated a thunk function
# that jumps to the value of this address.
for start_ilt, dll_addr, start_iat in image_import_descriptors:
dll_name = self.read_string(dll_addr).decode("ascii")
ofs_ilt = start_ilt
# Address of "__imp__*" symbols.
ofs_iat = start_iat
while True:
(lookup_addr,) = struct.unpack("<L", self.read(ofs_ilt, 4))
(import_addr,) = struct.unpack("<L", self.read(ofs_iat, 4))
if lookup_addr == 0 or import_addr == 0:
break
# MSB set if this is an ordinal import; the ordinal is the low 16 bits
if lookup_addr & 0x80000000 != 0:
ordinal_num = lookup_addr & 0xFFFF
symbol_name = f"Ordinal_{ordinal_num}"
else:
# Skip the "Hint" field, 2 bytes
name_ofs = lookup_addr + self.imagebase + 2
symbol_name = self.read_string(name_ofs).decode("ascii")
yield (dll_name, symbol_name, ofs_iat)
ofs_ilt += 4
ofs_iat += 4
self.imports = list(iter_imports())
def _populate_thunks(self):
"""For each imported function, we generate a thunk function. The only
instruction in the function is a jmp to the address in .idata.
Search .text to find these functions."""
text_sect = self.get_section_by_name(".text")
text_start = text_sect.virtual_address
# If this is a debug build, read the thunks at the start of .text
# Terminated by a big block of 0xcc padding bytes before the first
# real function in the section.
if self.is_debug:
ofs = 0
while True:
(opcode, operand) = struct.unpack("<Bi", text_sect.view[ofs : ofs + 5])
if opcode != 0xE9:
break
thunk_ofs = text_start + ofs
jmp_ofs = text_start + ofs + 5 + operand
self.thunks.append((thunk_ofs, jmp_ofs))
ofs += 5
# Now check for import thunks which are present in debug and release.
# These use an absolute JMP with the 2 byte opcode: 0xff 0x25
idata_sect = self.get_section_by_name(".idata")
ofs = text_start
for shift in (0, 2, 4):
window = text_sect.view[shift:]
win_end = 6 * (len(window) // 6)
for i, (b0, b1, jmp_ofs) in enumerate(
struct.iter_unpack("<2BL", window[:win_end])
):
if (b0, b1) == (0xFF, 0x25) and idata_sect.contains_vaddr(jmp_ofs):
# Record the address of the jmp instruction and the destination in .idata
thunk_ofs = ofs + shift + i * 6
self.thunks.append((thunk_ofs, jmp_ofs))
def _populate_exports(self, export_rva: int, _: int):
"""If you are missing a lot of annotations in your file
(e.g. debug builds) then you can at least match up the
export symbol names."""
# Null = no exports
if export_rva == 0:
return
export_start = self.imagebase + export_rva
# TODO: namedtuple
export_table = struct.unpack("<2L2H7L", self.read(export_start, 40))
# TODO: if the number of functions doesn't match the number of names,
# are the remaining functions ordinals?
n_functions = export_table[6]
func_start = export_start + 40
func_addrs = [
self.imagebase + rva
for rva, in struct.iter_unpack("<L", self.read(func_start, 4 * n_functions))
]
name_start = func_start + 4 * n_functions
name_addrs = [
self.imagebase + rva
for rva, in struct.iter_unpack("<L", self.read(name_start, 4 * n_functions))
]
combined = zip(func_addrs, name_addrs)
self.exports = [
(func_addr, self.read_string(name_addr))
for (func_addr, name_addr) in combined
]
def iter_string(self, encoding: str = "ascii") -> Iterator[Tuple[int, str]]:
"""Search for possible strings at each verified address in .data."""
section = self.get_section_by_name(".data")
for addr in self._relocated_addrs:
if section.contains_vaddr(addr):
raw = self.read_string(addr)
if raw is None:
continue
try:
string = raw.decode(encoding)
except UnicodeDecodeError:
continue
yield (addr, string)
def get_section_by_name(self, name: str) -> Section:
section = next(
filter(lambda section: section.match_name(name), self.sections),
None,
)
if section is None:
raise SectionNotFoundError
return section
def get_section_by_index(self, index: int) -> Section:
"""Convert 1-based index into 0-based."""
return self.sections[index - 1]
def get_section_extent_by_index(self, index: int) -> int:
return self.get_section_by_index(index).extent
def get_section_offset_by_index(self, index: int) -> int:
"""The symbols output from cvdump gives addresses in this format: AAAA.BBBBBBBB
where A is the index (1-based) into the section table and B is the local offset.
This will return the virtual address for the start of the section at the given index
so you can get the virtual address for whatever symbol you are looking at.
"""
return self.get_section_by_index(index).virtual_address
def get_section_offset_by_name(self, name: str) -> int:
"""Same as above, but use the section name as the lookup"""
section = self.get_section_by_name(name)
return section.virtual_address
def get_abs_addr(self, section: int, offset: int) -> int:
"""Convenience function for converting section:offset pairs from cvdump
into an absolute vaddr."""
return self.get_section_offset_by_index(section) + offset
def get_relative_addr(self, addr: int) -> Tuple[int, int]:
"""Convert an absolute address back into a (section, offset) pair."""
i = bisect.bisect_right(self._section_vaddr, addr) - 1
i = max(0, i)
section = self.sections[i]
if section.contains_vaddr(addr):
return (i + 1, addr - section.virtual_address)
raise InvalidVirtualAddressError(f"{self.filename} : {hex(addr)}")
def is_valid_section(self, section_id: int) -> bool:
"""The PDB will refer to sections that are not listed in the headers
and so should ignore these references."""
try:
_ = self.get_section_by_index(section_id)
return True
except IndexError:
return False
def is_valid_vaddr(self, vaddr: int) -> bool:
"""Does this virtual address point to anything in the exe?"""
try:
(_, __) = self.get_relative_addr(vaddr)
except InvalidVirtualAddressError:
return False
return True
def read_string(self, offset: int, chunk_size: int = 1000) -> Optional[bytes]:
"""Read until we find a zero byte."""
b = self.read(offset, chunk_size)
if b is None:
return None
try:
return b[: b.index(b"\x00")]
except ValueError:
# No terminator found, just return what we have
return b
def read(self, vaddr: int, size: int) -> Optional[bytes]:
"""Read (at most) the given number of bytes at the given virtual address.
If we return None, the given address points to uninitialized data."""
(section_id, offset) = self.get_relative_addr(vaddr)
section = self.sections[section_id - 1]
if section.addr_is_uninitialized(vaddr):
return None
# Clamp the read within the extent of the current section.
# Reading off the end will most likely misrepresent the virtual addressing.
_size = min(size, section.size_of_raw_data - offset)
return bytes(section.view[offset : offset + _size])
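# A minimal usage sketch (not part of the original file; the path below is a
# placeholder for any 32-bit PE file).
if __name__ == "__main__":
    with Bin("legobin/LEGO1.DLL", find_str=True) as origfile:
        print(f"entry point: {origfile.entry:#x}")
        text = origfile.get_section_by_name(".text")
        print(f".text at {text.virtual_address:#x}, extent {text.extent}")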

View file

@ -1 +0,0 @@
from .core import Compare

View file

@ -1,2 +0,0 @@
from .parse import ParseAsm
from .swap import can_resolve_register_differences

View file

@ -1,27 +0,0 @@
# Duplicates removed, according to the mnemonics capstone uses.
# e.g. je and jz are the same instruction. capstone uses je.
# See: /arch/X86/X86GenAsmWriter.inc in the capstone repo.
JUMP_MNEMONICS = {
"ja",
"jae",
"jb",
"jbe",
"jcxz", # unused?
"je",
"jecxz",
"jg",
"jge",
"jl",
"jle",
"jmp",
"jne",
"jno",
"jnp",
"jns",
"jo",
"jp",
"js",
}
# Guaranteed to be a single operand.
SINGLE_OPERAND_INSTS = {"push", "call", *JUMP_MNEMONICS}

View file

@ -1,314 +0,0 @@
import re
from typing import List, Tuple, Set
DiffOpcode = Tuple[str, int, int, int, int]
REG_FIND = re.compile(r"(?: |\[)(e?[a-d]x|e?[sd]i|[a-d][lh]|e?[bs]p)")
ALLOWED_JUMP_SWAPS = (
("ja", "jb"),
("jae", "jbe"),
("jb", "ja"),
("jbe", "jae"),
("jg", "jl"),
("jge", "jle"),
("jl", "jg"),
("jle", "jge"),
("je", "je"),
("jne", "jne"),
)
def jump_swap_ok(a: str, b: str) -> bool:
"""For the instructions a,b, are they both jump instructions
that are compatible with a swapped cmp operand order?"""
# Grab the mnemonic
(jmp_a, _, __) = a.partition(" ")
(jmp_b, _, __) = b.partition(" ")
return (jmp_a, jmp_b) in ALLOWED_JUMP_SWAPS
def is_operand_swap(a: str, b: str) -> bool:
"""This is a hack to avoid parsing the operands. It's not as simple as
breaking on the comma because templates or string literals interfere
with this. Instead we check:
1. Do both strings use the exact same set of characters?
2. If we do break on ', ', is the first token of each different?
2 is needed to catch an edge case like:
cmp eax, dword ptr [ecx + 0x1234]
cmp ecx, dword ptr [eax + 0x1234]
"""
return a.partition(", ")[0] != b.partition(", ")[0] and sorted(a) == sorted(b)
def can_cmp_swap(orig: List[str], recomp: List[str]) -> bool:
# Make sure we have 1 cmp and 1 jmp for both
if len(orig) != 2 or len(recomp) != 2:
return False
if not orig[0].startswith("cmp") or not recomp[0].startswith("cmp"):
return False
if not orig[1].startswith("j") or not recomp[1].startswith("j"):
return False
# Checking two things:
# Are the cmp operands flipped?
# Is the jump instruction compatible with a flip?
return is_operand_swap(orig[0], recomp[0]) and jump_swap_ok(orig[1], recomp[1])
def patch_jump(a: str, b: str) -> str:
"""For jump instructions a, b, return `(mnemonic_a) (operand_b)`.
The reason to do it this way (instead of just returning `a`) is that
the jump instructions might use different displacement offsets
or labels. If we just replace `b` with `a`, this diff would be
incorrectly eliminated."""
(mnemonic_a, _, __) = a.partition(" ")
(_, __, operand_b) = b.partition(" ")
return mnemonic_a + " " + operand_b
def patch_cmp_swaps(
codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
"""Can we resolve the diffs between orig and recomp by patching
swapped cmp instructions?
For example:
cmp eax, ebx cmp ebx, eax
je .label je .label
cmp eax, ebx cmp ebx, eax
ja .label jb .label
"""
fixed_lines = set()
for code, i1, i2, j1, j2 in codes:
# To save us the trouble of finding "compatible" cmp instructions
# use the diff information we already have.
if code != "replace":
continue
# If the ranges in orig and recomp are not equal, use the shorter one
for i, j in zip(range(i1, i2), range(j1, j2)):
if can_cmp_swap(orig_asm[i : i + 2], recomp_asm[j : j + 2]):
# Patch cmp
fixed_lines.add(j)
# Patch the jump if necessary
patched = patch_jump(orig_asm[i + 1], recomp_asm[j + 1])
# We only register a fix if it actually matches
if orig_asm[i + 1] == patched:
fixed_lines.add(j + 1)
return fixed_lines
def effective_match_possible(orig_asm: List[str], recomp_asm: List[str]) -> bool:
# We can only declare an effective match based on the text
# so you need the same amount of "stuff" in each
if len(orig_asm) != len(recomp_asm):
return False
# mnemonic_orig = [inst.partition(" ")[0] for inst in orig_asm]
# mnemonic_recomp = [inst.partition(" ")[0] for inst in recomp_asm]
# Cannot change mnemonics. Must be same starting list
# TODO: Fine idea but this will exclude jump swaps for cmp operand order
# if sorted(mnemonic_orig) != sorted(mnemonic_recomp):
# return False
return True
def find_regs_used(inst: str) -> List[str]:
return REG_FIND.findall(inst)
def find_regs_changed(a: str, b: str) -> List[Tuple[str, str]]:
"""For instructions a, b, return the pairs of registers that were used.
This is not a very precise way to compare the instructions, so it depends
on the input being two instructions that would match *except* for
the register choice."""
return zip(REG_FIND.findall(a), REG_FIND.findall(b))
def bad_register_swaps(
swaps: Set[int], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
"""The list of recomp indices in `swaps` tells which instructions are
a match for orig except for the registers used. From that list, check
whether a register swap should not be allowed.
For now, this means checking for `push` instructions where the register
was not used in any other register swaps on previous instructions."""
rejects = set()
# For each `push` instruction where we have excused the diff
pushes = [j for j in swaps if recomp_asm[j].startswith("push")]
for j in pushes:
okay = False
# Get the operands in each
reg = (orig_asm[j].partition(" ")[2], recomp_asm[j].partition(" ")[2])
# If this isn't a register at all, ignore it
try:
int(reg[0], 16)
continue
except ValueError:
pass
# For every other excused diff that is *not* a push:
# Assumes same index in orig as in recomp, but so does our naive match
for k in swaps.difference(pushes):
changed_regs = find_regs_changed(orig_asm[k], recomp_asm[k])
if reg in changed_regs or reg[::-1] in changed_regs:
okay = True
break
if not okay:
rejects.add(j)
return rejects
# Instructions that result in a change to the first operand
MODIFIER_INSTRUCTIONS = ("adc", "add", "lea", "mov", "neg", "sbb", "sub", "pop", "xor")
def instruction_alters_regs(inst: str, regs: Set[str]) -> bool:
(mnemonic, _, op_str) = inst.partition(" ")
(first_operand, _, __) = op_str.partition(", ")
return (mnemonic in MODIFIER_INSTRUCTIONS and first_operand in regs) or (
mnemonic == "call" and "eax" in regs
)
def relocate_instructions(
codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> Set[int]:
"""Collect the list of instructions deleted from orig and inserted
into recomp, according to the diff opcodes. Using this list, match up
any pairs of instructions that we assume to be relocated and return
the indices in recomp where this has occurred.
For now, we are checking only for an exact match on the instruction.
We are not checking whether the given instruction can be moved from
point A to B. (i.e. does this set a register that is used by the
instructions between A and B?)"""
deletes = {
i for code, i1, i2, _, __ in codes for i in range(i1, i2) if code == "delete"
}
inserts = [
j for code, _, __, j1, j2 in codes for j in range(j1, j2) if code == "insert"
]
relocated = set()
for j in inserts:
line = recomp_asm[j]
recomp_regs_used = set(find_regs_used(line))
for i in deletes:
# Check for exact match.
# TODO: This will grab the first instruction that matches.
# We should probably use the nearest index instead, if it matters
if orig_asm[i] == line:
# To account for a move in either direction
reloc_start = min(i, j)
reloc_end = max(i, j)
if not any(
instruction_alters_regs(orig_asm[k], recomp_regs_used)
for k in range(reloc_start, reloc_end)
):
relocated.add(j)
deletes.remove(i)
break
return relocated
DWORD_REGS = ("eax", "ebx", "ecx", "edx", "esi", "edi", "ebp", "esp")
WORD_REGS = ("ax", "bx", "cx", "dx", "si", "di", "bp", "sp")
BYTE_REGS = ("ah", "al", "bh", "bl", "ch", "cl", "dh", "dl")
def naive_register_replacement(orig_asm: List[str], recomp_asm: List[str]) -> Set[int]:
"""Replace all registers of the same size with a placeholder string.
After doing that, compare orig and recomp again.
Return indices from recomp that are now equal to the same index in orig.
This requires orig and recomp to have the same number of instructions,
but this is already a requirement for effective match."""
orig_raw = "\n".join(orig_asm)
recomp_raw = "\n".join(recomp_asm)
# TODO: hardly the most elegant way to do this.
for rdw in DWORD_REGS:
orig_raw = orig_raw.replace(rdw, "~reg4")
recomp_raw = recomp_raw.replace(rdw, "~reg4")
for rw in WORD_REGS:
orig_raw = orig_raw.replace(rw, "~reg2")
recomp_raw = recomp_raw.replace(rw, "~reg2")
for rb in BYTE_REGS:
orig_raw = orig_raw.replace(rb, "~reg1")
recomp_raw = recomp_raw.replace(rb, "~reg1")
orig_scrubbed = orig_raw.split("\n")
recomp_scrubbed = recomp_raw.split("\n")
return {
j for j in range(len(recomp_scrubbed)) if orig_scrubbed[j] == recomp_scrubbed[j]
}
def find_effective_match(
codes: List[DiffOpcode], orig_asm: List[str], recomp_asm: List[str]
) -> bool:
"""Check whether the two sequences of instructions are an effective match.
Meaning: do they differ only by instruction order or register selection?"""
if not effective_match_possible(orig_asm, recomp_asm):
return False
already_equal = {
j for code, _, __, j1, j2 in codes for j in range(j1, j2) if code == "equal"
}
# We need to come up with some answer for each of these lines
recomp_lines_disputed = {
j
for code, _, __, j1, j2 in codes
for j in range(j1, j2)
if code in ("insert", "replace")
}
cmp_swaps = patch_cmp_swaps(codes, orig_asm, recomp_asm)
# This naive result includes lines that already match, so remove those
naive_swaps = naive_register_replacement(orig_asm, recomp_asm).difference(
already_equal
)
relocates = relocate_instructions(codes, orig_asm, recomp_asm)
bad_swaps = bad_register_swaps(naive_swaps, orig_asm, recomp_asm)
corrections = set().union(
naive_swaps.difference(bad_swaps),
cmp_swaps,
relocates,
)
return corrections.issuperset(recomp_lines_disputed)
def assert_fixup(asm: List[Tuple[str, str]]):
"""Detect assert calls and replace the code filename and line number
values with macros (from assert.h)."""
for i, (_, line) in enumerate(asm):
if "_assert" in line and line.startswith("call"):
# Guard explicitly: a negative index would not raise IndexError
# but silently wrap around to the end of the list.
if i < 3:
continue
asm[i - 3] = (asm[i - 3][0], "push __LINE__")
asm[i - 2] = (asm[i - 2][0], "push __FILE__")
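# A minimal demo (not part of the original file): a pure register swap that
# should be accepted as an effective match.
if __name__ == "__main__":
    from difflib import SequenceMatcher

    orig = ["mov eax, 1", "push eax", "mov ecx, 2"]
    recomp = ["mov edx, 1", "push edx", "mov ecx, 2"]
    codes = SequenceMatcher(None, orig, recomp).get_opcodes()
    print(find_effective_match(codes, orig, recomp))  # True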

View file

@ -1,249 +0,0 @@
"""Pre-parser for x86 instructions. Will identify data/jump tables used with
switch statements and local jump/call destinations."""
import re
import bisect
import struct
from enum import Enum, auto
from collections import namedtuple
from typing import Iterable, List, NamedTuple, Optional, Tuple, Union
from capstone import Cs, CS_ARCH_X86, CS_MODE_32
from .const import JUMP_MNEMONICS
disassembler = Cs(CS_ARCH_X86, CS_MODE_32)
DisasmLiteTuple = Tuple[int, int, str, str]
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
displacement_regex = re.compile(r".*\+ (0x[0-9a-f]+)\]")
class SectionType(Enum):
CODE = auto()
DATA_TAB = auto()
ADDR_TAB = auto()
class FuncSection(NamedTuple):
type: SectionType
contents: List[Union[DisasmLiteInst, Tuple[str, int]]]
def stop_at_int3(
disasm_lite_gen: Iterable[DisasmLiteTuple],
) -> Iterable[DisasmLiteTuple]:
"""Wrapper for capstone disasm_lite generator. We want to stop reading
instructions if we hit the int3 instruction."""
for inst in disasm_lite_gen:
# inst[2] is the mnemonic
if inst[2] == "int3":
break
yield inst
class InstructGen:
# pylint: disable=too-many-instance-attributes
def __init__(self, blob: bytes, start: int) -> None:
self.blob = blob
self.start = start
self.end = len(blob) + start
self.section_end: int = self.end
self.code_tracks: List[List[DisasmLiteInst]] = []
# Todo: Could be refactored later
self.cur_addr: int = 0
self.cur_section_type: SectionType = SectionType.CODE
self.section_start = start
self.sections: List[FuncSection] = []
self.confirmed_addrs = {}
self.analysis()
def _finish_section(self, type_: SectionType, stuff):
sect = FuncSection(type_, stuff)
self.sections.append(sect)
def _insert_confirmed_addr(self, addr: int, type_: SectionType):
# Ignore address outside the bounds of the function
if not self.start <= addr < self.end:
return
self.confirmed_addrs[addr] = type_
# This newly inserted address might signal the end of this section.
# For example, a jump table at the end of the function means we should
# stop reading instructions once we hit that address.
# However, if there is a jump table in between code sections, we might
# read a jump to an address back to the beginning of the function
# (e.g. a loop that spans the entire function)
# so ignore this address because we have already passed it.
if type_ != self.cur_section_type and addr > self.cur_addr:
self.section_end = min(self.section_end, addr)
def _next_section(self, addr: int) -> Optional[SectionType]:
"""We have reached the start of a new section. Tell what kind of
data we are looking at (code or other) and how much we should read."""
# Assume the start of every function is code.
if addr == self.start:
self.section_end = self.end
return SectionType.CODE
# The start of a new section must be an address that we've seen.
new_type = self.confirmed_addrs.get(addr)
if new_type is None:
return None
self.cur_section_type = new_type
# The confirmed addrs dict is sorted by insertion order
# i.e. the order in which we read the addresses
# So we have to sort and then find the next item
# to see where this section should end.
# If we are in a CODE section, ignore contiguous CODE addresses.
# These are not the start of a new section.
# However: if we are not in CODE, any upcoming address is a new section.
# Do this so we can detect contiguous non-CODE sections.
confirmed = [
conf_addr
for (conf_addr, conf_type) in sorted(self.confirmed_addrs.items())
if self.cur_section_type != SectionType.CODE
or conf_type != self.cur_section_type
]
index = bisect.bisect_right(confirmed, addr)
if index < len(confirmed):
self.section_end = confirmed[index]
else:
self.section_end = self.end
return new_type
def _get_code_for(self, addr: int) -> List[DisasmLiteInst]:
"""Start disassembling at the given address."""
# If we are reading a code block beyond the first, see if we already
# have disassembled instructions beginning at the specified address.
# For a CODE/ADDR/CODE function, we might get lucky and produce the
# correct instruction after the jump table's junk instructions.
for track in self.code_tracks:
for i, inst in enumerate(track):
if inst.address == addr:
return track[i:]
# If we are here, we don't have the instructions.
# Todo: Could try to be clever here and disassemble only
# as much as we probably need (i.e. if a jump table is between CODE
# blocks, there are probably only a few bad instructions after the
# jump table is finished. We could disassemble up to the next verified
# code address and stitch it together)
blob_cropped = self.blob[addr - self.start :]
instructions = [
DisasmLiteInst(*inst)
for inst in stop_at_int3(disassembler.disasm_lite(blob_cropped, addr))
]
self.code_tracks.append(instructions)
return instructions
def _handle_jump(self, inst: DisasmLiteInst):
# If this is a regular jump and its destination is within the
# bounds of the binary data (i.e. presumed function size)
# add it to our list of confirmed addresses.
if inst.op_str[0] == "0":
value = int(inst.op_str, 16)
self._insert_confirmed_addr(value, SectionType.CODE)
# If this is jumping into a table of addresses, save the destination
elif (match := displacement_regex.match(inst.op_str)) is not None:
value = int(match.group(1), 16)
self._insert_confirmed_addr(value, SectionType.ADDR_TAB)
def analysis(self):
self.cur_addr = self.start
while (sect_type := self._next_section(self.cur_addr)) is not None:
self.section_start = self.cur_addr
if sect_type == SectionType.CODE:
instructions = self._get_code_for(self.cur_addr)
# If we didn't get any instructions back, something is wrong.
# i.e. We can only read part of the full instruction that is up next.
if len(instructions) == 0:
# Nudge the current addr so we will eventually move on to the
# next section.
# Todo: Maybe we could just call it quits here
self.cur_addr += 1
break
for inst in instructions:
# section_end is updated as we read instructions.
# If we are into a jump/data table and would read
# a junk instruction, stop here.
if self.cur_addr >= self.section_end:
break
# print(f"{inst.address:x} : {inst.mnemonic} {inst.op_str}")
if inst.mnemonic in JUMP_MNEMONICS:
self._handle_jump(inst)
# Todo: log calls too (unwind section)
elif inst.mnemonic == "mov":
# Todo: maintain pairing of data/jump tables
if (match := displacement_regex.match(inst.op_str)) is not None:
value = int(match.group(1), 16)
self._insert_confirmed_addr(value, SectionType.DATA_TAB)
# Advance by the instruction size instead of copying the next
# instruction's address; with a single instruction we would
# otherwise get stuck here.
self.cur_addr += inst.size
# End of for loop on instructions.
# We are at the end of the section or the entire function.
# Cut out only the valid instructions for this section
# and save it for later.
# Todo: don't need to iter on every instruction here.
# They are already in order.
instruction_slice = [
inst for inst in instructions if inst.address < self.section_end
]
self._finish_section(SectionType.CODE, instruction_slice)
elif sect_type == SectionType.ADDR_TAB:
# Clamp to multiple of 4 (dwords)
read_size = ((self.section_end - self.cur_addr) // 4) * 4
offsets = range(self.section_start, self.section_start + read_size, 4)
dwords = self.blob[
self.cur_addr - self.start : self.cur_addr - self.start + read_size
]
addrs = [addr for addr, in struct.iter_unpack("<L", dwords)]
for addr in addrs:
# Todo: the fact that these are jump table destinations
# should factor into the label name.
self._insert_confirmed_addr(addr, SectionType.CODE)
jump_table = list(zip(offsets, addrs))
# for (t0,t1) in jump_table:
# print(f"{t0:x} : --> {t1:x}")
self._finish_section(SectionType.ADDR_TAB, jump_table)
self.cur_addr = self.section_end
else:
# Todo: variable data size?
read_size = self.section_end - self.cur_addr
offsets = range(self.section_start, self.section_start + read_size)
bytes_ = self.blob[
self.cur_addr - self.start : self.cur_addr - self.start + read_size
]
data = [b for b, in struct.iter_unpack("<B", bytes_)]
data_table = list(zip(offsets, data))
# for (t0,t1) in data_table:
# print(f"{t0:x} : value {t1:02x}")
self._finish_section(SectionType.DATA_TAB, data_table)
self.cur_addr = self.section_end
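# Minimal usage sketch, assuming capstone is available and that construction
# runs the analysis (as the parse_asm caller in the next module suggests).
# The bytes and load address here are made up.
if __name__ == "__main__":
    sample = bytes.fromhex("558bec5dc3")  # push ebp; mov ebp, esp; pop ebp; ret
    for sect_type, contents in InstructGen(sample, 0x100D8000).sections:
        print(sect_type, contents)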

View file

@ -1,243 +0,0 @@
"""Converts x86 machine code into text (i.e. assembly). The end goal is to
compare the code in the original and recomp binaries, using longest common
subsequence (LCS), i.e. difflib.SequenceMatcher.
The capstone library takes the raw bytes and gives us the mnemonic
and operand(s) for each instruction. We need to "sanitize" the text further
so that virtual addresses are replaced by symbol name or a generic
placeholder string."""
import re
import struct
from functools import cache
from typing import Callable, List, Optional, Tuple
from collections import namedtuple
from .const import JUMP_MNEMONICS, SINGLE_OPERAND_INSTS
from .instgen import InstructGen, SectionType
ptr_replace_regex = re.compile(r"\[(0x[0-9a-f]+)\]")
displace_replace_regex = re.compile(r"\+ (0x[0-9a-f]+)\]")
# For matching an immediate value on its own.
# Preceded by start-of-string (first operand) or comma-space (second operand)
immediate_replace_regex = re.compile(r"(?:^|, )(0x[0-9a-f]+)")
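# Illustrative captures for the three patterns above, using made-up
# capstone operand strings:
#   ptr_replace_regex:       "dword ptr [0x100f4528]"  -> "0x100f4528"
#   displace_replace_regex:  "[ecx + 0x100f4528]"      -> "0x100f4528"
#   immediate_replace_regex: "eax, 0x100f4528"         -> "0x100f4528"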
DisasmLiteInst = namedtuple("DisasmLiteInst", "address, size, mnemonic, op_str")
@cache
def from_hex(string: str) -> Optional[int]:
try:
return int(string, 16)
except ValueError:
pass
return None
def bytes_to_dword(b: bytes) -> Optional[int]:
if len(b) == 4:
return struct.unpack("<L", b)[0]
return None
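# e.g. bytes_to_dword(b"\x28\x45\x0f\x10") == 0x100f4528 (little-endian).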
class ParseAsm:
def __init__(
self,
relocate_lookup: Optional[Callable[[int], bool]] = None,
name_lookup: Optional[Callable[[int, bool], str]] = None,
bin_lookup: Optional[Callable[[int, int], Optional[bytes]]] = None,
) -> None:
self.relocate_lookup = relocate_lookup
self.name_lookup = name_lookup
self.bin_lookup = bin_lookup
self.replacements = {}
self.number_placeholders = True
def reset(self):
self.replacements = {}
def is_relocated(self, addr: int) -> bool:
if callable(self.relocate_lookup):
return self.relocate_lookup(addr)
return False
def lookup(
self, addr: int, use_cache: bool = True, exact: bool = False
) -> Optional[str]:
"""Return a replacement name for this address if we find one."""
if use_cache and (cached := self.replacements.get(addr, None)) is not None:
return cached
if callable(self.name_lookup):
if (name := self.name_lookup(addr, exact)) is not None:
if use_cache:
self.replacements[addr] = name
return name
return None
def replace(self, addr: int) -> str:
"""Same function as lookup above, but here we return a placeholder
if there is no better name to use."""
if (name := self.lookup(addr)) is not None:
return name
# The placeholder number corresponds to the number of addresses we have
# already replaced. This is so the number will be consistent across the diff
# if we can replace some symbols with actual names in recomp but not orig.
idx = len(self.replacements) + 1
placeholder = f"<OFFSET{idx}>" if self.number_placeholders else "<OFFSET>"
self.replacements[addr] = placeholder
return placeholder
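# For example, with no name_lookup wired in: the first unknown address
# replaced becomes <OFFSET1>, the next distinct one <OFFSET2>, and a
# repeated address reuses the placeholder cached in self.replacements.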
def hex_replace_always(self, match: re.Match) -> str:
"""If a pointer value was matched, always insert a placeholder"""
value = int(match.group(1), 16)
return match.group(0).replace(match.group(1), self.replace(value))
def hex_replace_relocated(self, match: re.Match) -> str:
"""For replacing immediate value operands. We only want to
use the placeholder if we are certain that this is a valid address.
We can check the relocation table to find out."""
value = int(match.group(1), 16)
if self.is_relocated(value):
return match.group(0).replace(match.group(1), self.replace(value))
return match.group(0)
def hex_replace_annotated(self, match: re.Match) -> str:
"""For replacing immediate value operands. Here we replace the value
only if the name lookup returns something. Do not use a placeholder."""
value = int(match.group(1), 16)
placeholder = self.lookup(value, use_cache=False)
if placeholder is not None:
return match.group(0).replace(match.group(1), placeholder)
return match.group(0)
def hex_replace_indirect(self, match: re.Match) -> str:
"""Edge case for hex_replace_always. The context of the instruction
tells us that the pointer value is an absolute indirect.
So we go to that location in the binary to get the address.
If we cannot identify the indirect address, fall back to a lookup
on the original pointer value so we might display something useful."""
value = int(match.group(1), 16)
indirect_value = None
if callable(self.bin_lookup):
indirect_value = self.bin_lookup(value, 4)
if indirect_value is not None:
indirect_addr = bytes_to_dword(indirect_value)
if (
indirect_addr is not None
and self.lookup(indirect_addr, use_cache=False) is not None
):
return match.group(0).replace(
match.group(1), "->" + self.replace(indirect_addr)
)
return match.group(0).replace(match.group(1), self.replace(value))
def sanitize(self, inst: DisasmLiteInst) -> Tuple[str, str]:
# For jumps or calls, if the entire op_str is a hex number, the value
# is a relative offset.
# Otherwise (i.e. it looks like `dword ptr [address]`) it is an
# absolute indirect that we will handle below.
# Providing the starting address of the function to capstone.disasm has
# automatically resolved relative offsets to an absolute address.
# We will have to undo this for some of the jumps or they will not match.
if (
inst.mnemonic in SINGLE_OPERAND_INSTS
and (op_str_address := from_hex(inst.op_str)) is not None
):
if inst.mnemonic == "call":
return (inst.mnemonic, self.replace(op_str_address))
if inst.mnemonic == "push":
if self.is_relocated(op_str_address):
return (inst.mnemonic, self.replace(op_str_address))
# To avoid falling into jump handling
return (inst.mnemonic, inst.op_str)
if inst.mnemonic == "jmp":
# The unwind section contains JMPs to other functions.
# If we have a name for this address, use it. If not,
# do not create a new placeholder. We will instead
# fall through to generic jump handling below.
potential_name = self.lookup(op_str_address, exact=True)
if potential_name is not None:
return (inst.mnemonic, potential_name)
# Else: this is any jump
# Show the jump offset rather than the absolute address
jump_displacement = op_str_address - (inst.address + inst.size)
return (inst.mnemonic, hex(jump_displacement))
if inst.mnemonic == "call":
# Special handling for absolute indirect CALL.
op_str = ptr_replace_regex.sub(self.hex_replace_indirect, inst.op_str)
else:
op_str = ptr_replace_regex.sub(self.hex_replace_always, inst.op_str)
# We only want relocated addresses for pointer displacement.
# i.e. ptr [register + something]
# Otherwise we would use a placeholder for every stack variable,
# vtable call, or this->member access.
op_str = displace_replace_regex.sub(self.hex_replace_relocated, op_str)
# In the event of pointer comparison, only replace the immediate value
# if it is a known address.
if inst.mnemonic == "cmp":
op_str = immediate_replace_regex.sub(self.hex_replace_annotated, op_str)
else:
op_str = immediate_replace_regex.sub(self.hex_replace_relocated, op_str)
return (inst.mnemonic, op_str)
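# Illustrative results, assuming a fresh ParseAsm with no lookups and the
# instructions below processed in this order (addresses are made up):
#   call 0x100f1000                      -> call <OFFSET1>
#   jmp 0x1000a0 (at 0x100000, 5 bytes)  -> jmp 0x9b   (relative displacement)
#   mov eax, dword ptr [0x100f4528]      -> mov eax, dword ptr [<OFFSET2>]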
def parse_asm(self, data: bytes, start_addr: Optional[int] = 0) -> List[Tuple[str, str]]:
asm = []
ig = InstructGen(data, start_addr)
for sect_type, sect_contents in ig.sections:
if sect_type == SectionType.CODE:
for inst in sect_contents:
# Use heuristics to disregard some differences that aren't representative
# of the accuracy of a function (e.g. global offsets)
# If there is no pointer or immediate value in the op_str,
# there is nothing to sanitize.
# This leaves us with cases where a small immediate value or
# small displacement (this.member or vtable calls) appears.
# If we assume that instructions we want to sanitize need to be 5
# bytes -- 1 for the opcode and 4 for the address -- exclude cases
# where the hex value could not be an address.
# The exception is jumps which are as small as 2 bytes
# but are still useful to sanitize.
if "0x" in inst.op_str and (
inst.mnemonic in JUMP_MNEMONICS or inst.size > 4
):
result = self.sanitize(inst)
else:
result = (inst.mnemonic, inst.op_str)
# mnemonic + " " + op_str
asm.append((hex(inst.address), " ".join(result)))
elif sect_type == SectionType.ADDR_TAB:
asm.append(("", "Jump table:"))
for i, (ofs, _) in enumerate(sect_contents):
asm.append((hex(ofs), f"Jump_dest_{i}"))
elif sect_type == SectionType.DATA_TAB:
asm.append(("", "Data table:"))
for ofs, b in sect_contents:
asm.append((hex(ofs), hex(b)))
return asm
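# Minimal usage sketch with no lookups wired in; the code bytes and start
# address are made up. The call target should become <OFFSET1>.
if __name__ == "__main__":
    sample = bytes.fromhex("e8fb000000c3")  # call rel32; ret
    for addr, line in ParseAsm().parse_asm(sample, start_addr=0x100F1000):
        print(addr, line)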

View file

@ -1,80 +0,0 @@
import re
REGISTER_LIST = set(
[
"ax",
"bp",
"bx",
"cx",
"di",
"dx",
"eax",
"ebp",
"ebx",
"ecx",
"edi",
"edx",
"esi",
"esp",
"si",
"sp",
]
)
WORDS = re.compile(r"\w+")
def get_registers(line: str):
to_replace = []
# use words regex to find all matching positions:
for match in WORDS.finditer(line):
reg = match.group(0)
if reg in REGISTER_LIST:
to_replace.append((reg, match.start()))
return to_replace
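# e.g. get_registers("mov eax, dword ptr [esp + 8]") == [("eax", 4), ("esp", 20)]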
def replace_register(
lines: list[str], start_line: int, reg: str, replacement: str
) -> list[str]:
return [
line.replace(reg, replacement) if i >= start_line else line
for i, line in enumerate(lines)
]
# Is it possible to make new_asm the same as original_asm by swapping registers?
def can_resolve_register_differences(original_asm, new_asm):
# Split the ASM on spaces to get more granularity, and so
# that we don't modify the original arrays passed in.
original_asm = [part for line in original_asm for part in line.split()]
new_asm = [part for line in new_asm for part in line.split()]
# Swapping ain't gonna help if the lengths are different
if len(original_asm) != len(new_asm):
return False
# Look for the mismatching lines
for i, original_line in enumerate(original_asm):
new_line = new_asm[i]
if new_line != original_line:
# Find all the registers to replace
to_replace = get_registers(original_line)
for replace in to_replace:
(reg, reg_index) = replace
replacing_reg = new_line[reg_index : reg_index + len(reg)]
if replacing_reg in REGISTER_LIST:
if replacing_reg != reg:
# Do a three-way swap replacing in all the subsequent lines
temp_reg = "&" * len(reg)
new_asm = replace_register(new_asm, i, replacing_reg, temp_reg)
new_asm = replace_register(new_asm, i, reg, replacing_reg)
new_asm = replace_register(new_asm, i, temp_reg, reg)
else:
# No replacement to do, different code, bail out
return False
# Check if the lines are now the same
for i, original_line in enumerate(original_asm):
if new_asm[i] != original_line:
return False
return True
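# Two tiny checks, assuming this module is run directly: a consistent
# eax<->ecx rename resolves, while a change to a non-register operand
# does not.
if __name__ == "__main__":
    assert can_resolve_register_differences(["mov eax, ebx"], ["mov ecx, ebx"])
    assert not can_resolve_register_differences(["mov eax, ebx"], ["mov eax, 0x1"])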

View file

@ -1,921 +0,0 @@
import os
import logging
import difflib
import struct
import uuid
from dataclasses import dataclass
from typing import Any, Callable, Iterable, List, Optional
from isledecomp.bin import Bin as IsleBin, InvalidVirtualAddressError
from isledecomp.cvdump.demangler import demangle_string_const
from isledecomp.cvdump import Cvdump, CvdumpAnalysis
from isledecomp.cvdump.types import scalar_type_pointer
from isledecomp.parser import DecompCodebase
from isledecomp.dir import walk_source_dir
from isledecomp.types import SymbolType
from isledecomp.compare.asm import ParseAsm
from isledecomp.compare.asm.fixes import assert_fixup, find_effective_match
from .db import CompareDb, MatchInfo
from .diff import combined_diff, CombinedDiffOutput
from .lines import LinesDb
logger = logging.getLogger(__name__)
@dataclass
class DiffReport:
# pylint: disable=too-many-instance-attributes
match_type: SymbolType
orig_addr: int
recomp_addr: int
name: str
udiff: Optional[CombinedDiffOutput] = None
ratio: float = 0.0
is_effective_match: bool = False
is_stub: bool = False
@property
def effective_ratio(self) -> float:
return 1.0 if self.is_effective_match else self.ratio
def __str__(self) -> str:
"""For debug purposes. Proper diff printing (with coloring) is in another module."""
return f"{self.name} (0x{self.orig_addr:x}) {self.ratio*100:.02f}%{'*' if self.is_effective_match else ''}"
def create_reloc_lookup(bin_file: IsleBin) -> Callable[[int], bool]:
"""Function generator for relocation table lookup"""
def lookup(addr: int) -> bool:
return addr > bin_file.imagebase and bin_file.is_relocated_addr(addr)
return lookup
def create_bin_lookup(bin_file: IsleBin) -> Callable[[int, int], Optional[bytes]]:
"""Function generator for reading from the bin file"""
def lookup(addr: int, size: int) -> Optional[bytes]:
try:
return bin_file.read(addr, size)
except InvalidVirtualAddressError:
return None
return lookup
class Compare:
# pylint: disable=too-many-instance-attributes
def __init__(
self, orig_bin: IsleBin, recomp_bin: IsleBin, pdb_file: str, code_dir: str
):
self.orig_bin = orig_bin
self.recomp_bin = recomp_bin
self.pdb_file = pdb_file
self.code_dir = code_dir
# Controls whether we dump the asm output to a file
self.debug: bool = False
self.runid: str = uuid.uuid4().hex[:8]
self._lines_db = LinesDb(code_dir)
self._db = CompareDb()
self._load_cvdump()
self._load_markers()
# Detect floats first to eliminate potential overlap with string data
self._find_float_const()
self._find_original_strings()
self._match_imports()
self._match_exports()
self._match_thunks()
self._find_vtordisp()
def _load_cvdump(self):
logger.info("Parsing %s ...", self.pdb_file)
self.cv = (
Cvdump(self.pdb_file)
.lines()
.globals()
.publics()
.symbols()
.section_contributions()
.types()
.run()
)
self.cvdump_analysis = CvdumpAnalysis(self.cv)
for sym in self.cvdump_analysis.nodes:
# Skip nodes where we have almost no information.
# These probably came from SECTION CONTRIBUTIONS.
if sym.name() is None and sym.node_type is None:
continue
# The PDB might contain sections that do not line up with the
# actual binary. The symbol "__except_list" is one example.
# In these cases, just skip this symbol and move on because
# we can't do much with it.
if not self.recomp_bin.is_valid_section(sym.section):
continue
addr = self.recomp_bin.get_abs_addr(sym.section, sym.offset)
sym.addr = addr
# If this symbol is the final one in its section, we were not able to
# estimate its size because we didn't have the total size of that section.
# We can get this estimate now and assume that the final symbol occupies
# the remainder of the section.
if sym.estimated_size is None:
sym.estimated_size = (
self.recomp_bin.get_section_extent_by_index(sym.section)
- sym.offset
)
if sym.node_type == SymbolType.STRING:
string_info = demangle_string_const(sym.decorated_name)
if string_info is None:
logger.debug(
"Could not demangle string symbol: %s", sym.decorated_name
)
continue
# TODO: skip unicode for now. will need to handle these differently.
if string_info.is_utf16:
continue
raw = self.recomp_bin.read(addr, sym.size())
try:
# We use the string length reported in the mangled symbol as the
# data size, but this is not always accurate with respect to the
# null terminator.
# e.g. ??_C@_0BA@EFDM@MxObjectFactory?$AA@
# reported length: 16 (includes null terminator)
# c.f. ??_C@_03DPKJ@enz?$AA@
# reported length: 3 (does NOT include terminator)
# This will handle the case where the entire string contains "\x00"
# because those are distinct from the empty string of length 0.
decoded_string = raw.decode("latin1")
rstrip_string = decoded_string.rstrip("\x00")
if decoded_string != "" and rstrip_string != "":
sym.friendly_name = rstrip_string
else:
sym.friendly_name = decoded_string
except UnicodeDecodeError:
pass
self._db.set_recomp_symbol(
addr, sym.node_type, sym.name(), sym.decorated_name, sym.size()
)
for (section, offset), (
filename,
line_no,
) in self.cvdump_analysis.verified_lines.items():
addr = self.recomp_bin.get_abs_addr(section, offset)
self._lines_db.add_line(filename, line_no, addr)
# The _entry symbol is referenced in the PE header so we get this match for free.
self._db.set_function_pair(self.orig_bin.entry, self.recomp_bin.entry)
def _load_markers(self):
# Assume module name is the base filename of the original binary.
(module, _) = os.path.splitext(os.path.basename(self.orig_bin.filename))
codefiles = list(walk_source_dir(self.code_dir))
codebase = DecompCodebase(codefiles, module.upper())
def orig_bin_checker(addr: int) -> bool:
return self.orig_bin.is_valid_vaddr(addr)
# If the address of any annotation would cause an exception,
# remove it and report an error.
bad_annotations = codebase.prune_invalid_addrs(orig_bin_checker)
for sym in bad_annotations:
logger.error(
"Invalid address 0x%x on %s annotation in file: %s",
sym.offset,
sym.type.name,
sym.filename,
)
# Match lineref functions first because this is a guaranteed match.
# If we have two functions that share the same name, and one is
# a lineref, we can match the nameref correctly because the lineref
# was already removed from consideration.
for fun in codebase.iter_line_functions():
recomp_addr = self._lines_db.search_line(fun.filename, fun.line_number)
if recomp_addr is not None:
self._db.set_function_pair(fun.offset, recomp_addr)
if fun.should_skip():
self._db.mark_stub(fun.offset)
for fun in codebase.iter_name_functions():
self._db.match_function(fun.offset, fun.name)
if fun.should_skip():
self._db.mark_stub(fun.offset)
for var in codebase.iter_variables():
if var.is_static and var.parent_function is not None:
self._db.match_static_variable(
var.offset, var.name, var.parent_function
)
else:
if self._db.match_variable(var.offset, var.name):
self._check_if_array_and_match_elements(var.offset, var.name)
for tbl in codebase.iter_vtables():
self._db.match_vtable(tbl.offset, tbl.name, tbl.base_class)
for string in codebase.iter_strings():
# Not that we don't trust you, but we're checking the string
# annotation to make sure it is accurate.
try:
# TODO: would presumably fail for wchar_t strings
orig = self.orig_bin.read_string(string.offset).decode("latin1")
string_correct = string.name == orig
except UnicodeDecodeError:
string_correct = False
if not string_correct:
logger.error(
"Data at 0x%x does not match string %s",
string.offset,
repr(string.name),
)
continue
self._db.match_string(string.offset, string.name)
def _check_if_array_and_match_elements(self, orig_addr: int, name: str):
"""
Checks if the global variable at `orig_addr` is an array.
If yes, adds a match for all its elements. If it is an array of structs, all fields in that struct are also matched.
Note that there is no recursion, so an array of arrays would not be handled entirely.
This step is necessary e.g. for `0x100f0a20` (LegoRacers.cpp).
"""
def _add_match_in_array(
name: str, type_id: str, orig_addr: int, recomp_addr: int
):
self._db.set_recomp_symbol(
recomp_addr,
SymbolType.POINTER if scalar_type_pointer(type_id) else SymbolType.DATA,
name,
name,
# we only need the matches when they are referenced elsewhere, hence we don't need the size
size=None,
)
self._db.set_pair(orig_addr, recomp_addr)
matchinfo = self._db.get_by_orig(orig_addr)
if matchinfo is None or matchinfo.recomp_addr is None:
return
recomp_addr = matchinfo.recomp_addr
node = next(
(x for x in self.cvdump_analysis.nodes if x.addr == recomp_addr),
None,
)
if node is None or node.data_type is None:
return
if not node.data_type.key.startswith("0x"):
# scalar type, so clearly not an array
return
data_type = self.cv.types.keys[node.data_type.key.lower()]
if data_type["type"] == "LF_ARRAY":
array_element_type = self.cv.types.get(data_type["array_type"])
assert node.data_type.members is not None
for array_element in node.data_type.members:
orig_element_base_addr = orig_addr + array_element.offset
recomp_element_base_addr = recomp_addr + array_element.offset
if array_element_type.members is None:
_add_match_in_array(
f"{name}{array_element.name}",
array_element_type.key,
orig_element_base_addr,
recomp_element_base_addr,
)
else:
for member in array_element_type.members:
_add_match_in_array(
f"{name}{array_element.name}.{member.name}",
array_element_type.key,
orig_element_base_addr + member.offset,
recomp_element_base_addr + member.offset,
)
def _find_original_strings(self):
"""Go to the original binary and look for the specified string constants
to find a match. This is a (relatively) expensive operation so we only
look at strings that we have not already matched via a STRING annotation."""
# Release builds give each de-duped string a symbol so they are easy to find and match.
for string in self._db.get_unmatched_strings():
addr = self.orig_bin.find_string(string.encode("latin1"))
if addr is None:
escaped = repr(string)
logger.debug("Failed to find this string in the original: %s", escaped)
continue
self._db.match_string(addr, string)
def is_real_string(s: str) -> bool:
"""Heuristic to ignore values that only look like strings.
This is mostly about short strings (len <= 4) that could be byte or word values.
"""
# 0x10 is the MSB of the DLL (LEGO1) address space, so a string
# containing it is probably a fragment of a pointer, not text.
if len(s) == 0 or "\x10" in s:
return False
# Reject other single characters, but keep "0" because assert(0) is common.
if len(s) == 1 and s[0] != "0":
return False
# Hack because str.isprintable() will fail on strings with newlines or tabs
if len(s) <= 4 and "\\x" in repr(s):
return False
return True
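# e.g. is_real_string("") is False; is_real_string("0") is True because
# assert(0) is common; is_real_string("\x01ab") is False since repr()
# renders the control character as an \x escape.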
# Debug builds do not de-dupe the strings, so we need to find them via brute force scan.
# We could try to match the string addrs if there is only one in orig and recomp.
# When we sanitize the asm, the result is the same regardless.
if self.orig_bin.is_debug:
for addr, string in self.orig_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_orig_symbol(
addr, SymbolType.STRING, string, len(string)
)
for addr, string in self.recomp_bin.iter_string("latin1"):
if is_real_string(string):
self._db.set_recomp_symbol(
addr, SymbolType.STRING, string, None, len(string)
)
def _find_float_const(self):
"""Add floating point constants in each binary to the database.
We are not matching anything right now because these values are not
deduped like strings."""
for addr, size, float_value in self.orig_bin.find_float_consts():
self._db.set_orig_symbol(addr, SymbolType.FLOAT, str(float_value), size)
for addr, size, float_value in self.recomp_bin.find_float_consts():
self._db.set_recomp_symbol(
addr, SymbolType.FLOAT, str(float_value), None, size
)
def _match_imports(self):
"""We can match imported functions based on the DLL name and
function symbol name."""
orig_byaddr = {
addr: (dll.upper(), name) for (dll, name, addr) in self.orig_bin.imports
}
recomp_byname = {
(dll.upper(), name): addr for (dll, name, addr) in self.recomp_bin.imports
}
# Combine these two dictionaries. We don't care about imports from recomp
# not found in orig because:
# 1. They shouldn't be there
# 2. They are already identified via cvdump
orig_to_recomp = {
addr: recomp_byname.get(pair, None) for addr, pair in orig_byaddr.items()
}
# Now: we have the IAT offset in each matched up, so we need to make
# the connection between the thunk functions.
# We already have the symbol name we need from the PDB.
for orig, recomp in orig_to_recomp.items():
if orig is None or recomp is None:
continue
# Match the __imp__ symbol
self._db.set_pair(orig, recomp, SymbolType.POINTER)
# Read the relative address from .idata
try:
(recomp_rva,) = struct.unpack("<L", self.recomp_bin.read(recomp, 4))
(orig_rva,) = struct.unpack("<L", self.orig_bin.read(orig, 4))
except ValueError:
# Bail out if there's a problem with struct.unpack
continue
# Strictly speaking, this is a hack to support asm sanitize.
# When calling an import, we will recognize that the address for the
# CALL instruction is a pointer to the actual address, but this is
# not only not the address of a function, it is not an address at all.
# To make the asm display work correctly (i.e. to match what you see
# in ghidra) create a function match on the RVA. This is not a valid
# virtual address because it is before the imagebase, but it will
# do what we need it to do in the sanitize function.
(dll_name, func_name) = orig_byaddr[orig]
fullname = dll_name + ":" + func_name
self._db.set_recomp_symbol(
recomp_rva, SymbolType.FUNCTION, fullname, None, 4
)
self._db.set_pair(orig_rva, recomp_rva, SymbolType.FUNCTION)
self._db.skip_compare(orig_rva)
def _match_thunks(self):
"""Thunks are (by nature) matched by indirection. If a thunk from orig
points at a function we have already matched, we can find the matching
thunk in recomp because it points to the same place."""
# Mark all recomp thunks first. This allows us to use their name
# when we sanitize the asm.
for recomp_thunk, recomp_addr in self.recomp_bin.thunks:
recomp_func = self._db.get_by_recomp(recomp_addr)
if recomp_func is None:
continue
self._db.create_recomp_thunk(recomp_thunk, recomp_func.name)
# Thunks may be non-unique, so use a list as dict value when
# inverting the list of tuples from self.recomp_bin.
recomp_thunks = {}
for thunk_addr, func_addr in self.recomp_bin.thunks:
recomp_thunks.setdefault(func_addr, []).append(thunk_addr)
# Now match the thunks from orig where we can.
for orig_thunk, orig_addr in self.orig_bin.thunks:
orig_func = self._db.get_by_orig(orig_addr)
if orig_func is None:
continue
# Check whether the thunk destination is a matched symbol
if orig_func.recomp_addr not in recomp_thunks:
self._db.create_orig_thunk(orig_thunk, orig_func.name)
continue
# If there are multiple thunks, they are already in v.addr order.
# Pop the earliest one and match it.
recomp_thunk = recomp_thunks[orig_func.recomp_addr].pop(0)
if len(recomp_thunks[orig_func.recomp_addr]) == 0:
del recomp_thunks[orig_func.recomp_addr]
self._db.set_function_pair(orig_thunk, recomp_thunk)
# Don't compare thunk functions for now. The comparison isn't
# "useful" in the usual sense. We are only looking at the
# bytes of the jmp instruction and not the larger context of
# where this function is. Also: these will always match 100%
# because we are searching for a match to register this as a
# function in the first place.
self._db.skip_compare(orig_thunk)
def _match_exports(self):
# invert for name lookup
orig_exports = {y: x for (x, y) in self.orig_bin.exports}
for recomp_addr, export_name in self.recomp_bin.exports:
orig_addr = orig_exports.get(export_name)
if orig_addr is None:
continue
try:
# Check whether either of the addresses is actually a thunk.
# This is a quirk of the debug builds. Technically the export
# *is* the thunk, but it's more helpful to mark the actual function.
# It could be the case that only one side is a thunk, but we can
# deal with that.
(opcode, rel_addr) = struct.unpack(
"<Bl", self.recomp_bin.read(recomp_addr, 5)
)
if opcode == 0xE9:
recomp_addr += 5 + rel_addr
(opcode, rel_addr) = struct.unpack(
"<Bl", self.orig_bin.read(orig_addr, 5)
)
if opcode == 0xE9:
orig_addr += 5 + rel_addr
except ValueError:
# Bail out if there's a problem with struct.unpack
continue
if self._db.set_pair_tentative(orig_addr, recomp_addr):
logger.debug("Matched export %s", repr(export_name))
def _find_vtordisp(self):
"""If there are any cases of virtual inheritance, we can read
through the vtables for those classes and find the vtable thunk
functions (vtordisp).
Our approach is this: walk both vtables and check where we have a
vtordisp in the recomp table. Inspect the function at that vtable
position (in both) and check whether we jump to the same function.
One potential pitfall here is that the virtual displacement could
differ between the thunks. We are not (yet) checking for this, so the
result is that the vtable will appear to match but we will have a diff
on the thunk in our regular function comparison.
We could do this differently and check only the original vtable,
construct the name of the vtordisp function and match based on that."""
for match in self._db.get_matches_by_type(SymbolType.VTABLE):
assert (
match.name is not None
and match.orig_addr is not None
and match.recomp_addr is not None
and match.size is not None
)
# We need some method of identifying vtables that
# might have thunks, and this ought to work okay.
if "{for" not in match.name:
continue
next_orig = self._db.get_next_orig_addr(match.orig_addr)
assert next_orig is not None
orig_upper_size_limit = next_orig - match.orig_addr
if orig_upper_size_limit < match.size:
# This could happen in debug builds due to code changes between BETA10 and LEGO1,
# but we have not seen it yet as of 2024-08-28.
logger.warning(
"Recomp vtable is larger than orig vtable for %s",
match.name,
)
# TODO: We might want to fix this at the source (cvdump) instead.
# Any problem will be logged later when we compare the vtable.
vtable_size = 4 * (min(match.size, orig_upper_size_limit) // 4)
orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)
raw_addrs = zip(
[t for (t,) in struct.iter_unpack("<L", orig_table)],
[t for (t,) in struct.iter_unpack("<L", recomp_table)],
)
# Now walk both vtables looking for thunks.
for orig_addr, recomp_addr in raw_addrs:
if orig_addr == 0:
# This happens in debug builds due to code changes between BETA10 and LEGO1.
# Note that there is a risk of running into the next vtable if there is no gap in between,
# which we cannot protect against at the moment.
logger.warning(
"Recomp vtable is larger than orig vtable for %s", match.name
)
break
if self._db.is_vtordisp(recomp_addr):
self._match_vtordisp_in_vtable(orig_addr, recomp_addr)
def _match_vtordisp_in_vtable(self, orig_addr, recomp_addr):
thunk_fn = self.get_by_recomp(recomp_addr)
assert thunk_fn is not None
assert thunk_fn.size is not None
# Read the function bytes here.
# In practice, the adjuster thunk will be under 16 bytes.
# If we have thunks of unequal size, we can still tell whether they are thunking
# the same function by grabbing the JMP instruction at the end.
thunk_presumed_size = max(thunk_fn.size, 16)
# Strip off MSVC padding 0xcc bytes.
# This should be safe to do; it is highly unlikely that
# the MSB of the jump displacement would be 0xcc. (huge jump)
orig_thunk_bin = self.orig_bin.read(orig_addr, thunk_presumed_size).rstrip(
b"\xcc"
)
recomp_thunk_bin = self.recomp_bin.read(
recomp_addr, thunk_presumed_size
).rstrip(b"\xcc")
# Read jump opcode and displacement (last 5 bytes)
(orig_jmp, orig_disp) = struct.unpack("<Bi", orig_thunk_bin[-5:])
(recomp_jmp, recomp_disp) = struct.unpack("<Bi", recomp_thunk_bin[-5:])
# Make sure it's a JMP
if orig_jmp != 0xE9 or recomp_jmp != 0xE9:
logger.warning(
"Not a jump in vtordisp at (0x%x, 0x%x)", orig_addr, recomp_addr
)
return
# Calculate jump destination from the end of the JMP instruction
# i.e. the end of the function
orig_actual = orig_addr + len(orig_thunk_bin) + orig_disp
recomp_actual = recomp_addr + len(recomp_thunk_bin) + recomp_disp
# If they are thunking the same function, then this must be a match.
if self.is_pointer_match(orig_actual, recomp_actual):
if len(orig_thunk_bin) != len(recomp_thunk_bin):
logger.warning(
"Adjuster thunk %s (0x%x) is not exact",
thunk_fn.name,
orig_addr,
)
self._db.set_function_pair(orig_addr, recomp_addr)
def _dump_asm(self, orig_combined, recomp_combined):
"""Append the provided assembly output to the debug files"""
with open(f"orig-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in orig_combined:
f.write(f"{addr}: {line}\n")
with open(f"recomp-{self.runid}.txt", "a", encoding="utf-8") as f:
for addr, line in recomp_combined:
f.write(f"{addr}: {line}\n")
def _compare_function(self, match: MatchInfo) -> DiffReport:
# Detect when the recomp function size would cause us to read
# enough bytes from the original function that we cross into
# the next annotated function.
next_orig = self._db.get_next_orig_addr(match.orig_addr)
if next_orig is not None:
orig_size = min(next_orig - match.orig_addr, match.size)
else:
orig_size = match.size
orig_raw = self.orig_bin.read(match.orig_addr, orig_size)
recomp_raw = self.recomp_bin.read(match.recomp_addr, match.size)
# It's unlikely that a function other than an adjuster thunk would
# start with a SUB instruction, so alert to a possible wrong
# annotation here.
# There's probably a better place to do this, but we're reading
# the function bytes here already.
try:
if orig_raw[0] == 0x2B and recomp_raw[0] != 0x2B:
logger.warning(
"Possible thunk at 0x%x (%s)", match.orig_addr, match.name
)
except IndexError:
pass
def orig_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_orig(addr, exact)
if m is None:
return None
if m.orig_addr == addr:
return m.match_name()
offset = addr - m.orig_addr
if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
def recomp_lookup(addr: int, exact: bool) -> Optional[str]:
m = self._db.get_by_recomp(addr, exact)
if m is None:
return None
if m.recomp_addr == addr:
return m.match_name()
offset = addr - m.recomp_addr
if m.compare_type != SymbolType.DATA or offset >= m.size:
return None
return m.offset_name(offset)
orig_should_replace = create_reloc_lookup(self.orig_bin)
recomp_should_replace = create_reloc_lookup(self.recomp_bin)
orig_bin_lookup = create_bin_lookup(self.orig_bin)
recomp_bin_lookup = create_bin_lookup(self.recomp_bin)
orig_parse = ParseAsm(
relocate_lookup=orig_should_replace,
name_lookup=orig_lookup,
bin_lookup=orig_bin_lookup,
)
recomp_parse = ParseAsm(
relocate_lookup=recomp_should_replace,
name_lookup=recomp_lookup,
bin_lookup=recomp_bin_lookup,
)
orig_combined = orig_parse.parse_asm(orig_raw, match.orig_addr)
recomp_combined = recomp_parse.parse_asm(recomp_raw, match.recomp_addr)
if self.debug:
self._dump_asm(orig_combined, recomp_combined)
# Check for assert calls only if we expect to find them
if self.orig_bin.is_debug or self.recomp_bin.is_debug:
assert_fixup(orig_combined)
assert_fixup(recomp_combined)
# Detach addresses from asm lines for the text diff.
orig_asm = [x[1] for x in orig_combined]
recomp_asm = [x[1] for x in recomp_combined]
diff = difflib.SequenceMatcher(None, orig_asm, recomp_asm, autojunk=False)
ratio = diff.ratio()
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
codes = diff.get_opcodes()
is_effective_match = find_effective_match(codes, orig_asm, recomp_asm)
unified_diff = combined_diff(
diff, orig_combined, recomp_combined, context_size=10
)
else:
is_effective_match = False
unified_diff = []
return DiffReport(
match_type=SymbolType.FUNCTION,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
udiff=unified_diff,
ratio=ratio,
is_effective_match=is_effective_match,
)
def _compare_vtable(self, match: MatchInfo) -> DiffReport:
vtable_size = match.size
# The vtable size should always be a multiple of 4 because that
# is the pointer size. If it is not (for whatever reason)
# it would cause iter_unpack to blow up so let's just fix it.
if vtable_size % 4 != 0:
logger.warning(
"Vtable for class %s has irregular size %d", match.name, vtable_size
)
vtable_size = 4 * (vtable_size // 4)
orig_table = self.orig_bin.read(match.orig_addr, vtable_size)
recomp_table = self.recomp_bin.read(match.recomp_addr, vtable_size)
raw_addrs = zip(
[t for (t,) in struct.iter_unpack("<L", orig_table)],
[t for (t,) in struct.iter_unpack("<L", recomp_table)],
)
def match_text(m: Optional[MatchInfo], raw_addr: Optional[int] = None) -> str:
"""Format the function reference at this vtable index as text.
If we have not identified this function, we have the option to
display the raw address. This is only worth doing for the original addr
because we should always be able to identify the recomp function.
If the original function is missing then this probably means that the class
should override the given function from the superclass, but we have not
implemented this yet.
"""
if m is not None:
orig = hex(m.orig_addr) if m.orig_addr is not None else "no orig"
recomp = (
hex(m.recomp_addr) if m.recomp_addr is not None else "no recomp"
)
return f"({orig} / {recomp}) : {m.name}"
if raw_addr is not None:
return f"0x{raw_addr:x} from orig not annotated."
return "(no match)"
orig_text = []
recomp_text = []
ratio = 0
n_entries = 0
# Now compare each pointer from the two vtables.
for i, (raw_orig, raw_recomp) in enumerate(raw_addrs):
orig = self._db.get_by_orig(raw_orig)
recomp = self._db.get_by_recomp(raw_recomp)
if (
orig is not None
and recomp is not None
and orig.recomp_addr == recomp.recomp_addr
):
ratio += 1
n_entries += 1
index = f"vtable0x{i*4:02x}"
orig_text.append((index, match_text(orig, raw_orig)))
recomp_text.append((index, match_text(recomp)))
ratio = ratio / float(n_entries) if n_entries > 0 else 0
# n=100: Show the entire table if there is a diff to display.
# Otherwise it would be confusing if the table got cut off.
sm = difflib.SequenceMatcher(
None,
[x[1] for x in orig_text],
[x[1] for x in recomp_text],
)
unified_diff = combined_diff(sm, orig_text, recomp_text, context_size=100)
return DiffReport(
match_type=SymbolType.VTABLE,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
udiff=unified_diff,
ratio=ratio,
)
def _compare_match(self, match: MatchInfo) -> Optional[DiffReport]:
"""Router for comparison type"""
if match.size is None or match.size == 0:
return None
options = self._db.get_match_options(match.orig_addr)
if options.get("skip", False):
return None
if options.get("stub", False):
return DiffReport(
match_type=match.compare_type,
orig_addr=match.orig_addr,
recomp_addr=match.recomp_addr,
name=match.name,
is_stub=True,
)
if match.compare_type == SymbolType.FUNCTION:
return self._compare_function(match)
if match.compare_type == SymbolType.VTABLE:
return self._compare_vtable(match)
return None
## Public API
def is_pointer_match(self, orig_addr, recomp_addr) -> bool:
"""Check whether these pointers point at the same thing"""
# Null pointers considered matching
if orig_addr == 0 and recomp_addr == 0:
return True
match = self._db.get_by_orig(orig_addr)
if match is None:
return False
return match.recomp_addr == recomp_addr
def get_by_orig(self, addr: int) -> Optional[MatchInfo]:
return self._db.get_by_orig(addr)
def get_by_recomp(self, addr: int) -> Optional[MatchInfo]:
return self._db.get_by_recomp(addr)
def get_all(self) -> List[MatchInfo]:
return self._db.get_all()
def get_functions(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.FUNCTION)
def get_vtables(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.VTABLE)
def get_variables(self) -> List[MatchInfo]:
return self._db.get_matches_by_type(SymbolType.DATA)
def get_match_options(self, addr: int) -> dict[str, Any]:
return self._db.get_match_options(addr)
def compare_address(self, addr: int) -> Optional[DiffReport]:
match = self._db.get_one_match(addr)
if match is None:
return None
return self._compare_match(match)
def compare_all(self) -> Iterable[DiffReport]:
for match in self._db.get_matches():
diff = self._compare_match(match)
if diff is not None:
yield diff
def compare_functions(self) -> Iterable[DiffReport]:
for match in self.get_functions():
diff = self._compare_match(match)
if diff is not None:
yield diff
def compare_variables(self):
pass
def compare_pointers(self):
pass
def compare_strings(self):
pass
def compare_vtables(self) -> Iterable[DiffReport]:
for match in self.get_vtables():
diff = self._compare_match(match)
if diff is not None:
yield diff
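# Minimal driver sketch; the PDB path and source directory are made up,
# and constructing the IsleBin objects is out of scope here.
def _demo_compare(orig_bin: IsleBin, recomp_bin: IsleBin) -> None:
    compare = Compare(orig_bin, recomp_bin, "build/LEGO1.PDB", "LEGO1")
    for report in compare.compare_all():
        print(report)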

View file

@ -1,554 +0,0 @@
"""Wrapper for database (here an in-memory sqlite database) that collects the
addresses/symbols that we want to compare between the original and recompiled binaries."""
import sqlite3
import logging
from typing import Any, List, Optional
from isledecomp.types import SymbolType
from isledecomp.cvdump.demangler import get_vtordisp_name
_SETUP_SQL = """
DROP TABLE IF EXISTS `symbols`;
DROP TABLE IF EXISTS `match_options`;
CREATE TABLE `symbols` (
compare_type int,
orig_addr int,
recomp_addr int,
name text,
decorated_name text,
size int
);
CREATE TABLE `match_options` (
addr int not null,
name text not null,
value text,
primary key (addr, name)
) without rowid;
CREATE VIEW IF NOT EXISTS `match_info`
(compare_type, orig_addr, recomp_addr, name, size) AS
SELECT compare_type, orig_addr, recomp_addr, name, size
FROM `symbols`
ORDER BY orig_addr NULLS LAST;
CREATE INDEX `symbols_or` ON `symbols` (orig_addr);
CREATE INDEX `symbols_re` ON `symbols` (recomp_addr);
CREATE INDEX `symbols_na` ON `symbols` (name);
"""
class MatchInfo:
def __init__(
self,
ctype: Optional[int],
orig: Optional[int],
recomp: Optional[int],
name: Optional[str],
size: Optional[int],
) -> None:
self.compare_type = SymbolType(ctype) if ctype is not None else None
self.orig_addr = orig
self.recomp_addr = recomp
self.name = name
self.size = size
def match_name(self) -> Optional[str]:
"""Combination of the name and compare type.
Intended for name substitution in the diff. If there is a diff,
it will be more obvious what this symbol indicates."""
if self.name is None:
return None
ctype = self.compare_type.name if self.compare_type is not None else "UNK"
name = repr(self.name) if ctype == "STRING" else self.name
return f"{name} ({ctype})"
def offset_name(self, ofs: int) -> Optional[str]:
if self.name is None:
return None
return f"{self.name}+{ofs} (OFFSET)"
def matchinfo_factory(_, row):
return MatchInfo(*row)
logger = logging.getLogger(__name__)
class CompareDb:
# pylint: disable=too-many-public-methods
def __init__(self):
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
def set_orig_symbol(
self,
addr: int,
compare_type: Optional[SymbolType],
name: Optional[str],
size: Optional[int],
):
# Ignore collisions here.
if self._orig_used(addr):
return
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (orig_addr, compare_type, name, size) VALUES (?,?,?,?)",
(addr, compare_value, name, size),
)
def set_recomp_symbol(
self,
addr: int,
compare_type: Optional[SymbolType],
name: Optional[str],
decorated_name: Optional[str],
size: Optional[int],
):
# Ignore collisions here. The same recomp address can have
# multiple names (e.g. _strlwr and __strlwr)
if self._recomp_used(addr):
return
compare_value = compare_type.value if compare_type is not None else None
self._db.execute(
"INSERT INTO `symbols` (recomp_addr, compare_type, name, decorated_name, size) VALUES (?,?,?,?,?)",
(addr, compare_value, name, decorated_name, size),
)
def get_unmatched_strings(self) -> List[str]:
"""Return any strings not already identified by STRING markers."""
cur = self._db.execute(
"SELECT name FROM `symbols` WHERE compare_type = ? AND orig_addr IS NULL",
(SymbolType.STRING.value,),
)
return [string for (string,) in cur.fetchall()]
def get_all(self) -> List[MatchInfo]:
cur = self._db.execute("SELECT * FROM `match_info`")
cur.row_factory = matchinfo_factory
return cur.fetchall()
def get_matches(self) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
""",
)
cur.row_factory = matchinfo_factory
return cur.fetchall()
def get_one_match(self, addr: int) -> Optional[MatchInfo]:
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE orig_addr = ?
AND recomp_addr IS NOT NULL
""",
(addr,),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def _get_closest_orig(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(orig_addr) FROM `symbols`
WHERE ? >= orig_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def _get_closest_recomp(self, addr: int) -> Optional[int]:
value = self._db.execute(
"""SELECT max(recomp_addr) FROM `symbols`
WHERE ? >= recomp_addr
LIMIT 1
""",
(addr,),
).fetchone()
return value[0] if value is not None else None
def get_by_orig(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._orig_used(addr):
addr = self._get_closest_orig(addr)
if addr is None:
return None
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE orig_addr = ?
""",
(addr,),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_by_recomp(self, addr: int, exact: bool = True) -> Optional[MatchInfo]:
if not exact and not self._recomp_used(addr):
addr = self._get_closest_recomp(addr)
if addr is None:
return None
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE recomp_addr = ?
""",
(addr,),
)
cur.row_factory = matchinfo_factory
return cur.fetchone()
def get_matches_by_type(self, compare_type: SymbolType) -> List[MatchInfo]:
cur = self._db.execute(
"""SELECT * FROM `match_info`
WHERE compare_type = ?
AND orig_addr IS NOT NULL
AND recomp_addr IS NOT NULL
""",
(compare_type.value,),
)
cur.row_factory = matchinfo_factory
return cur.fetchall()
def _orig_used(self, addr: int) -> bool:
cur = self._db.execute("SELECT 1 FROM symbols WHERE orig_addr = ?", (addr,))
return cur.fetchone() is not None
def _recomp_used(self, addr: int) -> bool:
cur = self._db.execute("SELECT 1 FROM symbols WHERE recomp_addr = ?", (addr,))
return cur.fetchone() is not None
def set_pair(
self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
) -> bool:
if self._orig_used(orig):
logger.debug("Original address %s not unique!", hex(orig))
return False
compare_value = compare_type.value if compare_type is not None else None
cur = self._db.execute(
"UPDATE `symbols` SET orig_addr = ?, compare_type = ? WHERE recomp_addr = ?",
(orig, compare_value, recomp),
)
return cur.rowcount > 0
def set_pair_tentative(
self, orig: int, recomp: int, compare_type: Optional[SymbolType] = None
) -> bool:
"""Declare a match for the original and recomp addresses given, but only if:
1. The original address is not used elsewhere (as with set_pair)
2. The recomp address has not already been matched
If the compare_type is given, update this also, but only if NULL in the db.
The purpose here is to set matches found via some automated analysis
but to not overwrite a match provided by the human operator."""
if self._orig_used(orig):
# Probable and expected situation. Just ignore it.
return False
compare_value = compare_type.value if compare_type is not None else None
cur = self._db.execute(
"""UPDATE `symbols`
SET orig_addr = ?, compare_type = coalesce(compare_type, ?)
WHERE recomp_addr = ?
AND orig_addr IS NULL""",
(orig, compare_value, recomp),
)
return cur.rowcount > 0
def set_function_pair(self, orig: int, recomp: int) -> bool:
"""For lineref match or _entry"""
return self.set_pair(orig, recomp, SymbolType.FUNCTION)
def create_orig_thunk(self, addr: int, name: str) -> bool:
"""Create a thunk function reference using the orig address.
We are here because we have a match on the thunked function,
but it is not thunked in the recomp build."""
if self._orig_used(addr):
return False
thunk_name = f"Thunk of '{name}'"
# Assuming relative jump instruction for thunks (5 bytes)
cur = self._db.execute(
"""INSERT INTO `symbols`
(orig_addr, compare_type, name, size)
VALUES (?,?,?,?)""",
(addr, SymbolType.FUNCTION.value, thunk_name, 5),
)
return cur.rowcount > 0
def create_recomp_thunk(self, addr: int, name: str) -> bool:
"""Create a thunk function reference using the recomp address.
We start from the recomp side for this because we are guaranteed
to have full information from the PDB. We can use a regular function
match later to pull in the orig address."""
if self._recomp_used(addr):
return False
thunk_name = f"Thunk of '{name}'"
# Assuming relative jump instruction for thunks (5 bytes)
cur = self._db.execute(
"""INSERT INTO `symbols`
(recomp_addr, compare_type, name, size)
VALUES (?,?,?,?)""",
(addr, SymbolType.FUNCTION.value, thunk_name, 5),
)
return cur.rowcount > 0
def _set_opt_bool(self, addr: int, option: str, enabled: bool = True):
if enabled:
self._db.execute(
"""INSERT OR IGNORE INTO `match_options`
(addr, name)
VALUES (?, ?)""",
(addr, option),
)
else:
self._db.execute(
"""DELETE FROM `match_options` WHERE addr = ? AND name = ?""",
(addr, option),
)
def mark_stub(self, orig: int):
self._set_opt_bool(orig, "stub")
def skip_compare(self, orig: int):
self._set_opt_bool(orig, "skip")
def get_match_options(self, addr: int) -> dict[str, Any]:
cur = self._db.execute(
"""SELECT name, value FROM `match_options` WHERE addr = ?""", (addr,)
)
return {
option: value if value is not None else True
for (option, value) in cur.fetchall()
}
def is_vtordisp(self, recomp_addr: int) -> bool:
"""Check whether this function is a vtordisp based on its
decorated name. If its demangled name is missing the vtordisp
indicator, correct that."""
row = self._db.execute(
"""SELECT name, decorated_name
FROM `symbols`
WHERE recomp_addr = ?""",
(recomp_addr,),
).fetchone()
if row is None:
return False
(name, decorated_name) = row
if "`vtordisp" in name:
return True
if decorated_name is None:
# happens in debug builds, e.g. for "Thunk of 'LegoAnimActor::ClassName'"
return False
new_name = get_vtordisp_name(decorated_name)
if new_name is None:
return False
self._db.execute(
"""UPDATE `symbols`
SET name = ?
WHERE recomp_addr = ?""",
(new_name, recomp_addr),
)
return True
def _find_potential_match(
self, name: str, compare_type: SymbolType
) -> Optional[int]:
"""Name lookup"""
match_decorate = compare_type != SymbolType.STRING and name.startswith("?")
if match_decorate:
sql = """
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND decorated_name = ?
AND (compare_type IS NULL OR compare_type = ?)
LIMIT 1
"""
else:
sql = """
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND name = ?
AND (compare_type IS NULL OR compare_type = ?)
LIMIT 1
"""
row = self._db.execute(sql, (name, compare_type.value)).fetchone()
return row[0] if row is not None else None
def _find_static_variable(
self, variable_name: str, function_sym: str
) -> Optional[int]:
"""Get the recomp address of a static function variable.
Matches using a LIKE clause on the combination of:
1. The variable name read from decomp marker.
2. The decorated name of the enclosing function.
For example, the variable "g_startupDelay" from function "IsleApp::Tick"
has symbol: `?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA`
The function's decorated name is: `?Tick@IsleApp@@QAEXH@Z`"""
row = self._db.execute(
"""SELECT recomp_addr FROM `symbols`
WHERE decorated_name LIKE '%' || ? || '%' || ? || '%'
AND orig_addr IS NULL
AND (compare_type = ? OR compare_type = ? OR compare_type IS NULL)""",
(
variable_name,
function_sym,
SymbolType.DATA.value,
SymbolType.POINTER.value,
),
).fetchone()
return row[0] if row is not None else None
def _match_on(self, compare_type: SymbolType, addr: int, name: str) -> bool:
# Update the compare_type here too since the marker tells us what we should do
# Truncate the name to 255 characters. It will not be possible to match a name
# longer than that because MSVC truncates the debug symbols to this length.
# See also: warning C4786.
name = name[:255]
logger.debug("Looking for %s %s", compare_type.name.lower(), name)
recomp_addr = self._find_potential_match(name, compare_type)
if recomp_addr is None:
return False
return self.set_pair(addr, recomp_addr, compare_type)
def get_next_orig_addr(self, addr: int) -> Optional[int]:
"""Return the original address (matched or not) that follows
the one given. If our recomp function size would cause us to read
too many bytes for the original function, we can adjust it."""
result = self._db.execute(
"""SELECT orig_addr
FROM `symbols`
WHERE orig_addr > ?
ORDER BY orig_addr
LIMIT 1""",
(addr,),
).fetchone()
return result[0] if result is not None else None
def match_function(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.FUNCTION, addr, name)
if not did_match:
logger.error("Failed to find function symbol with name: %s", name)
return did_match
def match_vtable(
self, addr: int, name: str, base_class: Optional[str] = None
) -> bool:
# Set up our potential match names
bare_vftable = f"{name}::`vftable'"
for_name = base_class if base_class is not None else name
for_vftable = f"{name}::`vftable'{{for `{for_name}'}}"
# Only allow a match against "Class:`vftable'"
# if this is the derived class.
if base_class is None or base_class == name:
name_options = (for_vftable, bare_vftable)
else:
name_options = (for_vftable, for_vftable)
row = self._db.execute(
"""
SELECT recomp_addr
FROM `symbols`
WHERE orig_addr IS NULL
AND (name = ? OR name = ?)
AND (compare_type = ?)
LIMIT 1
""",
(*name_options, SymbolType.VTABLE.value),
).fetchone()
if row is not None and self.set_pair(addr, row[0], SymbolType.VTABLE):
return True
logger.error("Failed to find vtable for class: %s", name)
return False
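# e.g. name "LegoAnimActor" with base_class "LegoActor" searches only for
# "LegoAnimActor::`vftable'{for `LegoActor'}" (passed twice to fill both
# SQL placeholders); with no base class, the candidates are
# "LegoAnimActor::`vftable'{for `LegoAnimActor'}" and the bare
# "LegoAnimActor::`vftable'".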
def match_static_variable(self, addr: int, name: str, function_addr: int) -> bool:
"""Matching a static function variable by combining the variable name
with the decorated (mangled) name of its parent function."""
cur = self._db.execute(
"""SELECT name, decorated_name
FROM `symbols`
WHERE orig_addr = ?""",
(function_addr,),
)
if (result := cur.fetchone()) is None:
logger.error("No function for static variable: %s", name)
return False
# Get the friendly name for the "failed to match" error message
(function_name, decorated_name) = result
recomp_addr = self._find_static_variable(name, decorated_name)
if recomp_addr is not None:
# TODO: This variable could be a pointer, but I don't think we
# have a way to tell that right now.
if self.set_pair(addr, recomp_addr, SymbolType.DATA):
return True
logger.error(
"Failed to match static variable %s from function %s",
name,
function_name,
)
return False
def match_variable(self, addr: int, name: str) -> bool:
did_match = self._match_on(SymbolType.DATA, addr, name) or self._match_on(
SymbolType.POINTER, addr, name
)
if not did_match:
logger.error("Failed to find variable: %s", name)
return did_match
def match_string(self, addr: int, value: str) -> bool:
did_match = self._match_on(SymbolType.STRING, addr, value)
if not did_match:
escaped = repr(value)
logger.error("Failed to find string: %s", escaped)
return did_match
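# Minimal sketch of the matching flow; the addresses, name, and decorated
# name are made up.
if __name__ == "__main__":
    db = CompareDb()
    db.set_recomp_symbol(
        0x5001000, SymbolType.FUNCTION, "IsleApp::Tick", "?Tick@IsleApp@@QAEXH@Z", 64
    )
    assert db.match_function(0x10001000, "IsleApp::Tick")
    print(db.get_by_orig(0x10001000).match_name())  # IsleApp::Tick (FUNCTION)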

View file

@ -1,104 +0,0 @@
from difflib import SequenceMatcher
from typing import Dict, List, Tuple
CombinedDiffInput = List[Tuple[str, str]]
# from inner to outer:
# Tuple[str, ...]: either (orig_addr, instruction, recomp_addr) or (addr, instruction)
# List[...]: a contiguous block of instructions, all matching or all mismatching
# Dict[...]: either {"both": List[...]} or {"orig": [...], "recomp": [...]}
# Tuple[str, List[...]]: One contiguous part of the diff (without skipping matching code)
# List[...]: The list of all the contiguous diffs of a given function
CombinedDiffOutput = List[Tuple[str, List[Dict[str, List[Tuple[str, ...]]]]]]
def combined_diff(
diff: SequenceMatcher,
orig_combined: CombinedDiffInput,
recomp_combined: CombinedDiffInput,
context_size: int = 3,
) -> CombinedDiffOutput:
"""We want to diff the original and recomp assembly. The "combined" assembly
input has two components: the address of the instruction and the assembly text.
We have already diffed the text only. This is the SequenceMatcher object.
The SequenceMatcher can generate "opcodes" that describe how to turn "Text A"
into "Text B". These refer to list indices of the original arrays, so we can
use those to create the final diff and include the address for each line of assembly.
This is almost the same procedure as the difflib.unified_diff function, but we
are reusing the already generated SequenceMatcher object.
"""
unified_diff = []
for group in diff.get_grouped_opcodes(context_size):
subgroups = []
# Keep track of the addresses we've seen in this diff group.
# This helps create the "@@" line (the "hunk header", in unified diff terms).
# Do it this way because not every line in each list will have an
# address. If our context begins or ends on a line that does not
# have one, we will have an incomplete range string.
orig_addrs = set()
recomp_addrs = set()
first, last = group[0], group[-1]
orig_range = len(orig_combined[first[1] : last[2]])
recomp_range = len(recomp_combined[first[3] : last[4]])
for code, i1, i2, j1, j2 in group:
if code == "equal":
# The sections are equal, so the list slices are guaranteed
# to have the same length. We only need the diffed value (asm text)
# from one of the lists, but we need the addresses from both.
# Use zip to put the two lists together and then take out what we want.
both = [
(a, b, c)
for ((a, b), (c, _)) in zip(
orig_combined[i1:i2], recomp_combined[j1:j2]
)
]
for orig_addr, _, recomp_addr in both:
if orig_addr is not None:
orig_addrs.add(orig_addr)
if recomp_addr is not None:
recomp_addrs.add(recomp_addr)
subgroups.append({"both": both})
else:
for orig_addr, _ in orig_combined[i1:i2]:
if orig_addr is not None:
orig_addrs.add(orig_addr)
for recomp_addr, _ in recomp_combined[j1:j2]:
if recomp_addr is not None:
recomp_addrs.add(recomp_addr)
subgroups.append(
{
"orig": orig_combined[i1:i2],
"recomp": recomp_combined[j1:j2],
}
)
orig_sorted = sorted(orig_addrs)
recomp_sorted = sorted(recomp_addrs)
# We could get a diff group that has no original addresses.
# This might happen for a stub function where we are not able to
# produce even a single instruction from the original.
# In that case, show the best slug line that we can.
def peek_front(list_, default=""):
try:
return list_[0]
except IndexError:
return default
orig_first = peek_front(orig_sorted)
recomp_first = peek_front(recomp_sorted)
diff_slug = f"@@ -{orig_first},{orig_range} +{recomp_first},{recomp_range} @@"
unified_diff.append((diff_slug, subgroups))
return unified_diff
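# Illustrative sketch (addresses and instructions below are made up): drive
# combined_diff() with a SequenceMatcher built over the text column only,
# as the docstring above describes.
from difflib import SequenceMatcher

orig = [("0x1000", "push ebp"), ("0x1001", "mov ebp, esp"), ("0x1003", "xor eax, eax")]
recomp = [("0x2000", "push ebp"), ("0x2001", "mov ebp, esp"), ("0x2003", "mov eax, 0")]
matcher = SequenceMatcher(
    a=[text for _, text in orig], b=[text for _, text in recomp], autojunk=False
)
for slug, subgroups in combined_diff(matcher, orig, recomp, context_size=2):
    print(slug)  # @@ -0x1000,3 +0x2000,3 @@
    for sub in subgroups:
        if "both" in sub:
            for orig_addr, text, recomp_addr in sub["both"]:
                print(f"  {orig_addr} {recomp_addr}  {text}")
        else:
            for addr, text in sub["orig"]:
                print(f"  -{addr}  {text}")
            for addr, text in sub["recomp"]:
                print(f"  +{addr}  {text}")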

View file

@ -1,69 +0,0 @@
"""Database used to match (filename, line_number) pairs
between FUNCTION markers and PDB analysis."""
import sqlite3
import logging
from functools import cache
from typing import Optional
from pathlib import Path
from isledecomp.dir import PathResolver
_SETUP_SQL = """
DROP TABLE IF EXISTS `lineref`;
CREATE TABLE `lineref` (
path text not null,
filename text not null,
line int not null,
addr int not null
);
CREATE INDEX `file_line` ON `lineref` (filename, line);
"""
logger = logging.getLogger(__name__)
@cache
def my_samefile(path: str, source_path: str) -> bool:
return Path(path).samefile(source_path)
@cache
def my_basename_lower(path: str) -> str:
return Path(path).name.lower()
class LinesDb:
def __init__(self, code_dir) -> None:
self._db = sqlite3.connect(":memory:")
self._db.executescript(_SETUP_SQL)
self._path_resolver = PathResolver(code_dir)
def add_line(self, path: str, line_no: int, addr: int):
"""To be added from the LINES section of cvdump."""
sourcepath = self._path_resolver.resolve_cvdump(path)
filename = my_basename_lower(sourcepath)
self._db.execute(
"INSERT INTO `lineref` (path, filename, line, addr) VALUES (?,?,?,?)",
(sourcepath, filename, line_no, addr),
)
def search_line(self, path: str, line_no: int) -> Optional[int]:
"""Using path and line number from FUNCTION marker,
get the address of this function in the recomp."""
filename = my_basename_lower(path)
cur = self._db.execute(
"SELECT path, addr FROM `lineref` WHERE filename = ? AND line = ?",
(filename, line_no),
)
for source_path, addr in cur.fetchall():
if my_samefile(path, source_path):
return addr
logger.error(
"Failed to find function symbol with filename and line: %s:%d",
path,
line_no,
)
return None

View file

@ -1,5 +0,0 @@
from .symbols import SymbolsEntry
from .analysis import CvdumpAnalysis
from .parser import CvdumpParser
from .runner import Cvdump
from .types import CvdumpTypesParser

View file

@ -1,187 +0,0 @@
"""For collating the results from parsing cvdump.exe into a more directly useful format."""
from typing import Dict, List, Tuple, Optional
from isledecomp.cvdump import SymbolsEntry
from isledecomp.types import SymbolType
from .parser import CvdumpParser
from .demangler import demangle_string_const, demangle_vtable
from .types import CvdumpKeyError, CvdumpIntegrityError, TypeInfo
class CvdumpNode:
# pylint: disable=too-many-instance-attributes
# These two are required and allow us to identify the symbol
section: int
offset: int
# aka the mangled name from the PUBLICS section
decorated_name: Optional[str] = None
# optional "nicer" name (e.g. of a function from SYMBOLS section)
friendly_name: Optional[str] = None
# To be determined by context after inserting data, unless the decorated
# name makes this obvious. (i.e. string constants or vtables)
# We choose not to assume that section 1 (probably ".text") contains only
# functions. Smacker functions are linked to their own section "_UNSTEXT"
node_type: Optional[SymbolType] = None
# Function size can be read from the LINES section so use this over any
# other value if we have it.
# TYPES section can tell us the size of structs and other complex types.
confirmed_size: Optional[int] = None
# Estimated by reading the distance between this symbol and the one that
# follows in the same section.
# If this is the last symbol in the section, we cannot estimate a size.
estimated_size: Optional[int] = None
# Size as reported by SECTION CONTRIBUTIONS section. Not guaranteed to be
# accurate.
section_contribution: Optional[int] = None
addr: Optional[int] = None
symbol_entry: Optional[SymbolsEntry] = None
# Preliminary - only used for non-static variables at the moment
data_type: Optional[TypeInfo] = None
def __init__(self, section: int, offset: int) -> None:
self.section = section
self.offset = offset
def set_decorated(self, name: str):
self.decorated_name = name
if self.decorated_name.startswith("??_7"):
self.node_type = SymbolType.VTABLE
self.friendly_name = demangle_vtable(self.decorated_name)
elif self.decorated_name.startswith("??_8"):
# This is the `vbtable' symbol for virtual inheritance.
# Should be okay to reuse demangle_vtable. We still want to
# remove things like "const" from the output.
self.node_type = SymbolType.DATA
self.friendly_name = demangle_vtable(self.decorated_name)
elif self.decorated_name.startswith("??_C@"):
self.node_type = SymbolType.STRING
(strlen, _) = demangle_string_const(self.decorated_name)
self.confirmed_size = strlen
elif not self.decorated_name.startswith("?") and "@" in self.decorated_name:
# C mangled symbol. The trailing at-sign plus number gives the size in bytes
# of the parameter list for __stdcall, __fastcall, or __vectorcall.
# For __cdecl it is more ambiguous and we would have to know which section we are in.
# https://learn.microsoft.com/en-us/cpp/build/reference/decorated-names?view=msvc-170#FormatC
self.node_type = SymbolType.FUNCTION
def name(self) -> Optional[str]:
"""Prefer "friendly" name if we have it.
This is what we have been using to match functions."""
return (
self.friendly_name
if self.friendly_name is not None
else self.decorated_name
)
def size(self) -> Optional[int]:
if self.confirmed_size is not None:
return self.confirmed_size
# Better to undershoot the size because we can identify a comparison gap easily
if self.estimated_size is not None and self.section_contribution is not None:
return min(self.estimated_size, self.section_contribution)
# Return whichever one we have, or neither
return self.estimated_size or self.section_contribution
class CvdumpAnalysis:
"""Collects the results from CvdumpParser into a list of nodes (i.e. symbols).
These can then be analyzed by a downstream tool."""
verified_lines: Dict[Tuple[int, int], Tuple[str, int]]
def __init__(self, parser: CvdumpParser):
"""Read in as much information as we have from the parser.
The more sections we have, the better our information will be."""
node_dict: Dict[Tuple[int, int], CvdumpNode] = {}
# PUBLICS is our roadmap for everything that follows.
for pub in parser.publics:
key = (pub.section, pub.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].set_decorated(pub.name)
for sizeref in parser.sizerefs:
key = (sizeref.section, sizeref.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].section_contribution = sizeref.size
for glo in parser.globals:
key = (glo.section, glo.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
node_dict[key].node_type = SymbolType.DATA
node_dict[key].friendly_name = glo.name
try:
# Check our types database for type information.
# If we did not parse the TYPES section, we can only
# get information for built-in "T_" types.
g_info = parser.types.get(glo.type)
node_dict[key].confirmed_size = g_info.size
node_dict[key].data_type = g_info
# Previously we set the symbol type to POINTER here if
# the variable was known to be a pointer. We can derive this
# information later when it's time to compare the variable,
# so let's set these to symbol type DATA instead.
# POINTER will be reserved for non-variable pointer data.
# e.g. thunks, unwind section.
except (CvdumpKeyError, CvdumpIntegrityError):
# No big deal if we don't have complete type information.
pass
for key, _ in parser.lines.items():
# Here we only set if the section:offset already exists
# because our values include offsets inside of the function.
if key in node_dict:
node_dict[key].node_type = SymbolType.FUNCTION
# The LINES section contains every code line in the file, naturally.
# There isn't an obvious separation between functions, so we have to
# read everything. However, any function that would be in LINES
# has to be somewhere else in the PDB (probably PUBLICS).
# Isolate the lines that we actually care about for matching.
self.verified_lines = {
key: value for (key, value) in parser.lines.items() if key in node_dict
}
for sym in parser.symbols:
key = (sym.section, sym.offset)
if key not in node_dict:
node_dict[key] = CvdumpNode(*key)
if sym.type == "S_GPROC32":
node_dict[key].friendly_name = sym.name
node_dict[key].confirmed_size = sym.size
node_dict[key].node_type = SymbolType.FUNCTION
node_dict[key].symbol_entry = sym
self.nodes: List[CvdumpNode] = [
v for _, v in dict(sorted(node_dict.items())).items()
]
self._estimate_size()
def _estimate_size(self):
"""Get the distance between one section:offset value and the next one
in the same section. This gives a rough estimate of the size of the symbol.
If we have information from SECTION CONTRIBUTIONS, take whichever one is
less to get the best approximate size."""
for i in range(len(self.nodes) - 1):
this_node = self.nodes[i]
next_node = self.nodes[i + 1]
# If they are in different sections, we can't compare them
if this_node.section != next_node.section:
continue
this_node.estimated_size = next_node.offset - this_node.offset
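# Illustrative sketch (made-up section:offset pairs): the estimation rule
# above in isolation. Within one section, a symbol's size is bounded by the
# distance to the next symbol; the last symbol in a section gets no estimate.
syms = [(1, 0x100), (1, 0x140), (1, 0x200), (2, 0x000)]  # sorted (section, offset)
for (sec, off), (next_sec, next_off) in zip(syms, syms[1:]):
    if sec == next_sec:
        print(f"{sec:04}:{off:08x} estimated size = {next_off - off:#x}")
# 0001:00000100 estimated size = 0x40
# 0001:00000140 estimated size = 0xc0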

View file

@ -1,121 +0,0 @@
"""For demangling a subset of MSVC mangled symbols.
Some unofficial information about the mangling scheme is here:
https://en.wikiversity.org/wiki/Visual_C%2B%2B_name_mangling
"""
import re
from collections import namedtuple
from typing import Optional
import pydemangler
class InvalidEncodedNumberError(Exception):
pass
_encoded_number_translate = str.maketrans("ABCDEFGHIJKLMNOP", "0123456789ABCDEF")
def parse_encoded_number(string: str) -> int:
# TODO: assert string ends in "@"?
if string.endswith("@"):
string = string[:-1]
try:
return int(string.translate(_encoded_number_translate), 16)
except ValueError as e:
raise InvalidEncodedNumberError(string) from e
string_const_regex = re.compile(
r"\?\?_C@\_(?P<is_utf16>[0-1])(?P<len>\d|[A-P]+@)(?P<hash>\w+)@(?P<value>.+)@"
)
StringConstInfo = namedtuple("StringConstInfo", "len is_utf16")
def demangle_string_const(symbol: str) -> Optional[StringConstInfo]:
"""Don't bother to decode the string text from the symbol.
We can just read it from the binary once we have the length."""
match = string_const_regex.match(symbol)
if match is None:
return None
try:
strlen = (
parse_encoded_number(match.group("len"))
if "@" in match.group("len")
else int(match.group("len"))
)
except (ValueError, InvalidEncodedNumberError):
return None
is_utf16 = match.group("is_utf16") == "1"
return StringConstInfo(len=strlen, is_utf16=is_utf16)
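# Illustrative sketch (the hash portion of the symbol is made up): a
# single-digit length is read directly, while an encoded length like "BA@"
# goes through parse_encoded_number ("BA" -> hex 0x10 == 16).
print(demangle_string_const("??_C@_05CJBACGMB@hello@"))
# StringConstInfo(len=5, is_utf16=False)
print(parse_encoded_number("BA@"))  # 16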
def get_vtordisp_name(symbol: str) -> Optional[str]:
# pylint: disable=c-extension-no-member
"""For adjuster thunk functions, the PDB will sometimes use a name
that contains "vtordisp" but often will just reuse the name of the
function being thunked. We want to use the vtordisp name if possible."""
name = pydemangler.demangle(symbol)
if name is None:
return None
if "`vtordisp" not in name:
return None
# Now we remove the parts of the friendly name that we don't need
try:
# Assuming this is the last of the function prefixes
thiscall_idx = name.index("__thiscall")
# To match the end of the `vtordisp{x,y}' string
end_idx = name.index("}'")
return name[thiscall_idx + 11 : end_idx + 2]
except ValueError:
return name
def demangle_vtable(symbol: str) -> str:
# pylint: disable=c-extension-no-member
"""Get the class name referenced in the vtable symbol."""
raw = pydemangler.demangle(symbol)
if raw is None:
# TODO: This shouldn't happen if MSVC behaves.
# Fall back to the mangled name rather than crash on raw.replace() below.
return symbol
# Remove storage class and other stuff we don't care about
return (
raw.replace("<class ", "<")
.replace("<struct ", "<")
.replace("const ", "")
.replace("volatile ", "")
)
def demangle_vtable_ourselves(symbol: str) -> str:
"""Parked implementation of MSVC symbol demangling.
We only use this for vtables and it works okay with the simple cases or
templates that refer to other classes/structs. Some namespace support.
Does not support backrefs, primitive types, or vtables with
virtual inheritance."""
# Seek ahead 4 chars to strip off "??_7" prefix
t = symbol[4:].split("@")
# "?$" indicates a template class
if t[0].startswith("?$"):
class_name = t[0][2:]
# PA = Pointer/reference
# V or U = class or struct
if t[1].startswith("PA"):
generic = f"{t[1][3:]} *"
else:
generic = t[1][1:]
return f"{class_name}<{generic}>::`vftable'"
# If we have two classes listed, it is a namespace hierarchy.
# @@6B@ is a common generic suffix for these vtable symbols.
if t[1] != "" and t[1] != "6B":
return t[1] + "::" + t[0] + "::`vftable'"
return t[0] + "::`vftable'"
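# Illustrative sketch (symbols modeled on classes from this codebase):
print(demangle_vtable_ourselves("??_7LegoEntity@@6B@"))
# LegoEntity::`vftable'
print(demangle_vtable_ourselves("??_7ROI@Tgl@@6B@"))
# Tgl::ROI::`vftable'
print(demangle_vtable_ourselves("??_7?$MxCollection@PAVLegoWorld@@@@6B@"))
# MxCollection<LegoWorld *>::`vftable'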

View file

@ -1,182 +0,0 @@
import re
from typing import Iterable, Tuple
from collections import namedtuple
from .types import CvdumpTypesParser
from .symbols import CvdumpSymbolsParser
# e.g. `*** PUBLICS`
_section_change_regex = re.compile(r"\*\*\* (?P<section>[A-Z/ ]{2,})")
# e.g. ` 27 00034EC0 28 00034EE2 29 00034EE7 30 00034EF4`
_line_addr_pairs_findall = re.compile(r"\s+(?P<line_no>\d+) (?P<addr>[A-F0-9]{8})")
# We assume no spaces in the file name
# e.g. ` Z:\lego-island\isle\LEGO1\viewmanager\viewroi.cpp (None), 0001:00034E90-00034E97, line/addr pairs = 2`
_lines_subsection_header = re.compile(
r"^\s*(?P<filename>\S+).*?, (?P<section>[A-F0-9]{4}):(?P<start>[A-F0-9]{8})-(?P<end>[A-F0-9]{8}), line/addr pairs = (?P<len>\d+)"
)
# e.g. `S_PUB32: [0001:0003FF60], Flags: 00000000, __read`
_publics_line_regex = re.compile(
r"^(?P<type>\w+): \[(?P<section>\w{4}):(?P<offset>\w{8})], Flags: (?P<flags>\w{8}), (?P<name>\S+)"
)
# e.g. ` Debug start: 00000008, Debug end: 0000016E`
_gproc_debug_regex = re.compile(
r"\s*Debug start: (?P<start>\w{8}), Debug end: (?P<end>\w{8})"
)
# e.g. ` 00DA 0001:00000000 00000073 60501020`
_section_contrib_regex = re.compile(
r"\s*(?P<module>\w{4}) (?P<section>\w{4}):(?P<offset>\w{8}) (?P<size>\w{8}) (?P<flags>\w{8})"
)
# e.g. `S_GDATA32: [0003:000004A4], Type: T_32PRCHAR(0470), g_set`
_gdata32_regex = re.compile(
r"S_GDATA32: \[(?P<section>\w{4}):(?P<offset>\w{8})\], Type:\s*(?P<type>\S+), (?P<name>.+)"
)
# e.g. 0003 "CMakeFiles/isle.dir/ISLE/res/isle.rc.res"
# e.g. 0004 "C:\work\lego-island\isle\3rdparty\smartheap\SHLW32MT.LIB" "check.obj"
_module_regex = re.compile(r"(?P<id>\w{4})(?: \"(?P<lib>.+?)\")?(?: \"(?P<obj>.+?)\")")
# User functions only
LinesEntry = namedtuple("LinesEntry", "filename line_no section offset")
# Strings, vtables, functions.
# A superset of everything else; the only place you can find
# the C symbols (library functions, smacker, etc.)
PublicsEntry = namedtuple("PublicsEntry", "type section offset flags name")
# (Estimated) size of any symbol
SizeRefEntry = namedtuple("SizeRefEntry", "module section offset size")
# global variables
GdataEntry = namedtuple("GdataEntry", "section offset type name")
ModuleEntry = namedtuple("ModuleEntry", "id lib obj")
class CvdumpParser:
# pylint: disable=too-many-instance-attributes
def __init__(self) -> None:
self._section: str = ""
self._lines_function: Tuple[str, int] = ("", 0)
self.lines = {}
self.publics = []
self.sizerefs = []
self.globals = []
self.modules = []
self.types = CvdumpTypesParser()
self.symbols_parser = CvdumpSymbolsParser()
@property
def symbols(self):
return self.symbols_parser.symbols
def _lines_section(self, line: str):
"""Parsing entries from the LINES section. We only care about the pairs of
line_number and address and the subsection header to indicate which code file
we are in."""
# Subheader indicates a new function and possibly a new code filename.
# Save the section here because it is not given on the lines that follow.
if (match := _lines_subsection_header.match(line)) is not None:
self._lines_function = (
match.group("filename"),
int(match.group("section"), 16),
)
return
# Match any pairs as we find them
for line_no, offset in _line_addr_pairs_findall.findall(line):
key = (self._lines_function[1], int(offset, 16))
self.lines[key] = (self._lines_function[0], int(line_no))
def _publics_section(self, line: str):
"""Match each line from PUBLICS and pull out the symbol information.
These are MSVC mangled symbol names. String constants and vtable
addresses can only be found here."""
if (match := _publics_line_regex.match(line)) is not None:
self.publics.append(
PublicsEntry(
type=match.group("type"),
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
flags=int(match.group("flags"), 16),
name=match.group("name"),
)
)
def _globals_section(self, line: str):
"""S_PROCREF may be useful later.
Right now we just want S_GDATA32 symbols because it is the simplest
way to access global variables."""
if (match := _gdata32_regex.match(line)) is not None:
self.globals.append(
GdataEntry(
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
type=match.group("type"),
name=match.group("name"),
)
)
def _section_contributions(self, line: str):
"""Gives the size of elements across all sections of the binary.
This is the easiest way to get the data size for .data and .rdata
members that do not have a primitive data type."""
if (match := _section_contrib_regex.match(line)) is not None:
self.sizerefs.append(
SizeRefEntry(
module=int(match.group("module"), 16),
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
size=int(match.group("size"), 16),
)
)
def _modules_section(self, line: str):
"""Record the object file (and lib file, if used) linked into the binary.
The auto-incrementing id is cross-referenced in SECTION CONTRIBUTIONS
(and perhaps other locations)"""
if (match := _module_regex.match(line)) is not None:
self.modules.append(
ModuleEntry(
id=int(match.group("id"), 16),
lib=match.group("lib"),
obj=match.group("obj"),
)
)
def read_line(self, line: str):
if (match := _section_change_regex.match(line)) is not None:
self._section = match.group(1)
return
if self._section == "TYPES":
self.types.read_line(line)
elif self._section == "SYMBOLS":
self.symbols_parser.read_line(line)
elif self._section == "LINES":
self._lines_section(line)
elif self._section == "PUBLICS":
self._publics_section(line)
elif self._section == "SECTION CONTRIBUTIONS":
self._section_contributions(line)
elif self._section == "GLOBALS":
self._globals_section(line)
elif self._section == "MODULES":
self._modules_section(line)
def read_lines(self, lines: Iterable[str]):
for line in lines:
self.read_line(line)
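# Illustrative sketch (two lines of fabricated cvdump output): the section
# banner routes the lines that follow to the PUBLICS handler.
parser = CvdumpParser()
parser.read_lines(
    [
        "*** PUBLICS\n",
        "S_PUB32: [0001:0003FF60], Flags: 00000000, __read\n",
    ]
)
entry = parser.publics[0]
print(entry.name, entry.section, hex(entry.offset))  # __read 1 0x3ff60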

View file

@ -1,83 +0,0 @@
import io
from os import name as os_name
from enum import Enum
from typing import List
import subprocess
from isledecomp.lib import lib_path_join
from isledecomp.dir import winepath_unix_to_win
from .parser import CvdumpParser
class DumpOpt(Enum):
LINES = 0
SYMBOLS = 1
GLOBALS = 2
PUBLICS = 3
SECTION_CONTRIB = 4
MODULES = 5
TYPES = 6
cvdump_opt_map = {
DumpOpt.LINES: "-l",
DumpOpt.SYMBOLS: "-s",
DumpOpt.GLOBALS: "-g",
DumpOpt.PUBLICS: "-p",
DumpOpt.SECTION_CONTRIB: "-seccontrib",
DumpOpt.MODULES: "-m",
DumpOpt.TYPES: "-t",
}
class Cvdump:
def __init__(self, pdb: str) -> None:
self._pdb: str = pdb
self._options = set()
def lines(self):
self._options.add(DumpOpt.LINES)
return self
def symbols(self):
self._options.add(DumpOpt.SYMBOLS)
return self
def globals(self):
self._options.add(DumpOpt.GLOBALS)
return self
def publics(self):
self._options.add(DumpOpt.PUBLICS)
return self
def section_contributions(self):
self._options.add(DumpOpt.SECTION_CONTRIB)
return self
def modules(self):
self._options.add(DumpOpt.MODULES)
return self
def types(self):
self._options.add(DumpOpt.TYPES)
return self
def cmd_line(self) -> List[str]:
cvdump_exe = lib_path_join("cvdump.exe")
flags = [cvdump_opt_map[opt] for opt in self._options]
if os_name == "nt":
return [cvdump_exe, *flags, self._pdb]
return ["wine", cvdump_exe, *flags, winepath_unix_to_win(self._pdb)]
def run(self) -> CvdumpParser:
parser = CvdumpParser()
call = self.cmd_line()
with subprocess.Popen(call, stdout=subprocess.PIPE) as proc:
for line in io.TextIOWrapper(proc.stdout, encoding="utf-8"):
# Blank lines are only there to help the reader; they carry no information.
if line != "\n":
parser.read_line(line)
return parser
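# Illustrative sketch (requires the bundled cvdump.exe, plus wine on
# non-Windows hosts): the builder methods select sections, run() parses.
parser = Cvdump("build/ISLE.PDB").lines().publics().types().run()
print(len(parser.publics), "public symbols")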

View file

@ -1,162 +0,0 @@
from dataclasses import dataclass, field
import logging
import re
from re import Match
from typing import NamedTuple, Optional
logger = logging.getLogger(__name__)
class StackOrRegisterSymbol(NamedTuple):
symbol_type: str
location: str
"""Should always be set/converted to lowercase."""
data_type: str
name: str
# S_GPROC32 = functions
@dataclass
class SymbolsEntry:
# pylint: disable=too-many-instance-attributes
type: str
section: int
offset: int
size: int
func_type: str
name: str
stack_symbols: list[StackOrRegisterSymbol] = field(default_factory=list)
frame_pointer_present: bool = False
addr: Optional[int] = None # Absolute address. Will be set later, if at all
class CvdumpSymbolsParser:
"""Parser for cvdump output, SYMBOLS section."""
_symbol_line_generic_regex = re.compile(
r"\(\w+\)\s+(?P<symbol_type>[^\s:]+)(?::\s+(?P<second_part>\S.*))?|(?::)$"
)
"""
Parses the first part, e.g. `(00008C) S_GPROC32`, and splits off the second part after the colon (if it exists).
There are three cases:
- no colon, e.g. `(000350) S_END`
- colon but no data, e.g. `(000370) S_COMPILE:`
- colon and data, e.g. `(000304) S_REGISTER: esi, Type: 0x1E14, this`
"""
_symbol_line_function_regex = re.compile(
r"\[(?P<section>\w{4}):(?P<offset>\w{8})\], Cb: (?P<size>\w+), Type:\s+(?P<func_type>[^\s,]+), (?P<name>.+)"
)
"""
Parses the second part of a function symbol, e.g.
`[0001:00034E90], Cb: 00000007, Type: 0x1024, ViewROI::IntrinsicImportance`
"""
# the second part of e.g.
_stack_register_symbol_regex = re.compile(
r"(?P<location>\S+), Type:\s+(?P<data_type>[\w()]+), (?P<name>.+)$"
)
"""
Parses the second part of a stack or register symbol, e.g.
`esi, Type: 0x1E14, this`
"""
_debug_start_end_regex = re.compile(
r"^\s*Debug start: (?P<debug_start>\w+), Debug end: (?P<debug_end>\w+)$"
)
_parent_end_next_regex = re.compile(
r"\s*Parent: (?P<parent_addr>\w+), End: (?P<end_addr>\w+), Next: (?P<next_addr>\w+)$"
)
_flags_frame_pointer_regex = re.compile(r"\s*Flags: Frame Ptr Present$")
_register_stack_symbols = ["S_BPREL32", "S_REGISTER"]
# List the unhandled types so we can check exhaustiveness
_unhandled_symbols = [
"S_COMPILE",
"S_OBJNAME",
"S_THUNK32",
"S_LABEL32",
"S_LDATA32",
"S_UDT",
]
"""Parser for cvdump output, SYMBOLS section."""
def __init__(self):
self.symbols: list[SymbolsEntry] = []
self.current_function: Optional[SymbolsEntry] = None
# If we read an S_BLOCK32 node, increment this level.
# This is so we do not end the proc early by reading an S_END
# that indicates the end of the block.
self.block_level: int = 0
def read_line(self, line: str):
if (match := self._symbol_line_generic_regex.match(line)) is not None:
self._parse_generic_case(line, match)
elif (match := self._parent_end_next_regex.match(line)) is not None:
# We do not need this info at the moment, might be useful in the future
pass
elif (match := self._debug_start_end_regex.match(line)) is not None:
# We do not need this info at the moment, might be useful in the future
pass
elif (match := self._flags_frame_pointer_regex.match(line)) is not None:
if self.current_function is None:
logger.error(
"Found a `Flags: Frame Ptr Present` but self.current_function is None"
)
return
self.current_function.frame_pointer_present = True
else:
# Most of these are either `** Module: [...]` or data we do not care about
logger.debug("Unhandled line: %s", line[:-1])
def _parse_generic_case(self, line: str, line_match: Match[str]):
symbol_type: str = line_match.group("symbol_type")
second_part: Optional[str] = line_match.group("second_part")
if symbol_type in ["S_GPROC32", "S_LPROC32"]:
assert second_part is not None
if (match := self._symbol_line_function_regex.match(second_part)) is None:
logger.error("Invalid function symbol: %s", line[:-1])
return
self.current_function = SymbolsEntry(
type=symbol_type,
section=int(match.group("section"), 16),
offset=int(match.group("offset"), 16),
size=int(match.group("size"), 16),
func_type=match.group("func_type"),
name=match.group("name"),
)
self.symbols.append(self.current_function)
elif symbol_type in self._register_stack_symbols:
assert second_part is not None
if self.current_function is None:
logger.error("Found stack/register outside of function: %s", line[:-1])
return
if (match := self._stack_register_symbol_regex.match(second_part)) is None:
logger.error("Invalid stack/register symbol: %s", line[:-1])
return
new_symbol = StackOrRegisterSymbol(
symbol_type=symbol_type,
location=match.group("location").lower(),
data_type=match.group("data_type"),
name=match.group("name"),
)
self.current_function.stack_symbols.append(new_symbol)
elif symbol_type == "S_BLOCK32":
self.block_level += 1
elif symbol_type == "S_END":
if self.block_level > 0:
self.block_level -= 1
assert self.block_level >= 0
else:
self.current_function = None
elif symbol_type in self._unhandled_symbols:
return
else:
logger.error("Unhandled symbol type: %s", line)

View file

@ -1,737 +0,0 @@
from dataclasses import dataclass
import re
import logging
from typing import Any, Dict, List, NamedTuple, Optional
logger = logging.getLogger(__name__)
class CvdumpTypeError(Exception):
pass
class CvdumpKeyError(KeyError):
pass
class CvdumpIntegrityError(Exception):
pass
class FieldListItem(NamedTuple):
"""Member of a class or structure"""
offset: int
name: str
type: str
@dataclass
class VirtualBaseClass:
type: str
index: int
direct: bool
@dataclass
class VirtualBasePointer:
vboffset: int
bases: list[VirtualBaseClass]
class ScalarType(NamedTuple):
offset: int
name: Optional[str]
type: str
@property
def size(self) -> int:
return scalar_type_size(self.type)
@property
def format_char(self) -> str:
return scalar_type_format_char(self.type)
@property
def is_pointer(self) -> bool:
return scalar_type_pointer(self.type)
class TypeInfo(NamedTuple):
key: str
size: Optional[int]
name: Optional[str] = None
members: Optional[List[FieldListItem]] = None
def is_scalar(self) -> bool:
# TODO: distinction between a class with zero members and no vtable?
return self.members is None
def normalize_type_id(key: str) -> str:
"""Helper for TYPES parsing to ensure a consistent format.
If key begins with "T_" it is a built-in type.
Else it is a hex string. We prefer lower case letters and
no leading zeroes. (UDT identifier pads to 8 characters.)"""
if key[0] == "0":
return f"0x{key[-4:].lower()}"
# Remove numeric value for "T_" type. We don't use this.
return key.partition("(")[0]
def scalar_type_pointer(type_name: str) -> bool:
return type_name.startswith("T_32P")
def scalar_type_size(type_name: str) -> int:
if scalar_type_pointer(type_name):
return 4
if "CHAR" in type_name:
return 2 if "WCHAR" in type_name else 1
if "SHORT" in type_name:
return 2
if "QUAD" in type_name or "64" in type_name:
return 8
return 4
def scalar_type_signed(type_name: str) -> bool:
if scalar_type_pointer(type_name):
return False
# According to cvinfo.h, T_WCHAR is unsigned
return not type_name.startswith("T_U") and not type_name.startswith("T_W")
def scalar_type_format_char(type_name: str) -> str:
if scalar_type_pointer(type_name):
return "L"
# "Really a char"
if type_name.startswith("T_RCHAR"):
return "c"
# floats
if type_name.startswith("T_REAL"):
return "d" if "64" in type_name else "f"
size = scalar_type_size(type_name)
char = ({1: "b", 2: "h", 4: "l", 8: "q"}).get(size, "l")
return char if scalar_type_signed(type_name) else char.upper()
def member_list_to_struct_string(members: List[ScalarType]) -> str:
"""Create a string for use with struct.unpack"""
format_string = "".join(m.format_char for m in members)
if len(format_string) > 0:
return "<" + format_string
return ""
def join_member_names(parent: str, child: Optional[str]) -> str:
"""Helper method to combine parent/child member names.
Child member name is None if the child is a scalar type."""
if child is None:
return parent
# If the child is an array index, join without the dot
if child.startswith("["):
return f"{parent}{child}"
return f"{parent}.{child}"
class CvdumpTypesParser:
"""Parser for cvdump output, TYPES section.
Tricky enough that it demands its own parser."""
# Marks the start of a new type
INDEX_RE = re.compile(r"(?P<key>0x\w+) : .* (?P<type>LF_\w+)")
# LF_FIELDLIST class/struct member (1/2)
LIST_RE = re.compile(
r"\s+list\[\d+\] = LF_MEMBER, (?P<scope>\w+), type = (?P<type>.*), offset = (?P<offset>\d+)"
)
# LF_FIELDLIST vtable indicator
VTABLE_RE = re.compile(r"^\s+list\[\d+\] = LF_VFUNCTAB")
# LF_FIELDLIST superclass indicator
SUPERCLASS_RE = re.compile(
r"^\s+list\[\d+\] = LF_BCLASS, (?P<scope>\w+), type = (?P<type>.*), offset = (?P<offset>\d+)"
)
# LF_FIELDLIST virtual direct/indirect base pointer, line 1/2
VBCLASS_RE = re.compile(
r"^\s+list\[\d+\] = LF_(?P<indirect>I?)VBCLASS, .* base type = (?P<type>.*)$"
)
# LF_FIELDLIST virtual direct/indirect base pointer, line 2/2
VBCLASS_LINE_2_RE = re.compile(
r"^\s+virtual base ptr = .+, vbpoff = (?P<vboffset>\d+), vbind = (?P<vbindex>\d+)$"
)
# LF_FIELDLIST member name (2/2)
MEMBER_RE = re.compile(r"^\s+member name = '(?P<name>.*)'$")
LF_FIELDLIST_ENUMERATE = re.compile(
r"^\s+list\[\d+\] = LF_ENUMERATE,.*value = (?P<value>\d+), name = '(?P<name>[^']+)'$"
)
# LF_ARRAY element type
ARRAY_ELEMENT_RE = re.compile(r"^\s+Element type = (?P<type>.*)")
# LF_ARRAY total array size
ARRAY_LENGTH_RE = re.compile(r"^\s+length = (?P<length>\d+)")
# LF_CLASS/LF_STRUCTURE field list reference
CLASS_FIELD_RE = re.compile(
r"^\s+# members = \d+, field list type (?P<field_type>0x\w+),"
)
# LF_CLASS/LF_STRUCTURE name and other info
CLASS_NAME_RE = re.compile(
r"^\s+Size = (?P<size>\d+), class name = (?P<name>(?:[^,]|,\S)+)(?:, UDT\((?P<udt>0x\w+)\))?"
)
# LF_MODIFIER, type being modified
MODIFIES_RE = re.compile(r".*modifies type (?P<type>.*)$")
# LF_ARGLIST number of entries
LF_ARGLIST_ARGCOUNT = re.compile(r".*argument count = (?P<argcount>\d+)$")
# LF_ARGLIST list entry
LF_ARGLIST_ENTRY = re.compile(
r"^\s+list\[(?P<index>\d+)\] = (?P<arg_type>[\w()]+)$"
)
# LF_POINTER element
LF_POINTER_ELEMENT = re.compile(r"^\s+Element type : (?P<element_type>.+)$")
# LF_MFUNCTION attribute key-value pairs
LF_MFUNCTION_ATTRIBUTES = [
re.compile(r"\s*Return type = (?P<return_type>[\w()]+)$"),
re.compile(r"\s*Class type = (?P<class_type>[\w()]+)$"),
re.compile(r"\s*This type = (?P<this_type>[\w()]+)$"),
# Call type may contain whitespace
re.compile(r"\s*Call type = (?P<call_type>[\w()\s]+)$"),
re.compile(r"\s*Parms = (?P<num_params>[\w()]+)$"), # LF_MFUNCTION only
re.compile(r"\s*# Parms = (?P<num_params>[\w()]+)$"), # LF_PROCEDURE only
re.compile(r"\s*Arg list type = (?P<arg_list_type>[\w()]+)$"),
re.compile(
r"\s*This adjust = (?P<this_adjust>[\w()]+)$"
), # By how much the incoming pointers are shifted in virtual inheritance; hex value without `0x` prefix
re.compile(
r"\s*Func attr = (?P<func_attr>[\w()]+)$"
), # Only for completeness, is always `none`
]
LF_ENUM_ATTRIBUTES = [
re.compile(r"^\s*# members = (?P<num_members>\d+)$"),
re.compile(r"^\s*enum name = (?P<name>.+)$"),
]
LF_ENUM_TYPES = re.compile(
r"^\s*type = (?P<underlying_type>\S+) field list type (?P<field_type>0x\w{4})$"
)
LF_ENUM_UDT = re.compile(r"^\s*UDT\((?P<udt>0x\w+)\)$")
LF_UNION_LINE = re.compile(
r"^.*field list type (?P<field_type>0x\w+),.*Size = (?P<size>\d+)\s*,class name = (?P<name>(?:[^,]|,\S)+)(?:,\s.*UDT\((?P<udt>0x\w+)\))?$"
)
MODES_OF_INTEREST = {
"LF_ARRAY",
"LF_CLASS",
"LF_ENUM",
"LF_FIELDLIST",
"LF_MODIFIER",
"LF_POINTER",
"LF_STRUCTURE",
"LF_ARGLIST",
"LF_MFUNCTION",
"LF_PROCEDURE",
"LF_UNION",
}
def __init__(self) -> None:
self.mode: Optional[str] = None
self.last_key = ""
self.keys: Dict[str, Dict[str, Any]] = {}
def _new_type(self):
"""Prepare a new dict for the type we just parsed.
The id is self.last_key and the "type" of type is self.mode.
e.g. LF_CLASS"""
self.keys[self.last_key] = {"type": self.mode}
def _set(self, key: str, value):
self.keys[self.last_key][key] = value
def _add_member(self, offset: int, type_: str):
obj = self.keys[self.last_key]
if "members" not in obj:
obj["members"] = []
obj["members"].append({"offset": offset, "type": type_})
def _set_member_name(self, name: str):
"""Set name for most recently added member."""
obj = self.keys[self.last_key]
obj["members"][-1]["name"] = name
def _add_variant(self, name: str, value: int):
obj = self.keys[self.last_key]
if "variants" not in obj:
obj["variants"] = []
variants: list[dict[str, Any]] = obj["variants"]
variants.append({"name": name, "value": value})
def _get_field_list(self, type_obj: Dict[str, Any]) -> List[FieldListItem]:
"""Return the field list for the given LF_CLASS/LF_STRUCTURE reference"""
if type_obj.get("type") == "LF_FIELDLIST":
field_obj = type_obj
else:
field_list_type = type_obj["field_list_type"]
field_obj = self.keys[field_list_type]
members: List[FieldListItem] = []
super_ids = field_obj.get("super", [])
for super_id in super_ids:
# May need to resolve forward ref.
superclass = self.get(super_id)
if superclass.members is not None:
members += superclass.members
raw_members = field_obj.get("members", [])
members += [
FieldListItem(
offset=m["offset"],
type=m["type"],
name=m["name"],
)
for m in raw_members
]
return sorted(members, key=lambda m: m.offset)
def _mock_array_members(self, type_obj: Dict) -> List[FieldListItem]:
"""LF_ARRAY elements provide the element type and the total size.
We want the list of "members" as if this was a struct."""
if type_obj.get("type") != "LF_ARRAY":
raise CvdumpTypeError("Type is not an LF_ARRAY")
array_type = type_obj.get("array_type")
if array_type is None:
raise CvdumpIntegrityError("No array element type")
array_element_size = self.get(array_type).size
assert (
array_element_size is not None
), "Encountered an array whose type has no size"
n_elements = type_obj["size"] // array_element_size
return [
FieldListItem(
offset=i * array_element_size,
type=array_type,
name=f"[{i}]",
)
for i in range(n_elements)
]
def get(self, type_key: str) -> TypeInfo:
"""Convert our dictionary values read from the cvdump output
into a consistent format for the given type."""
# Scalar type. Handled here because it makes the recursive steps
# much simpler.
if type_key.startswith("T_"):
size = scalar_type_size(type_key)
return TypeInfo(
key=type_key,
size=size,
)
# Go to our dictionary to find it.
obj = self.keys.get(type_key.lower())
if obj is None:
raise CvdumpKeyError(type_key)
# These type references are just a wrapper around a scalar
if obj.get("type") == "LF_ENUM":
underlying_type = obj.get("underlying_type")
if underlying_type is None:
raise CvdumpKeyError(f"Missing 'underlying_type' in {obj}")
return self.get(underlying_type)
if obj.get("type") == "LF_POINTER":
return self.get("T_32PVOID")
if obj.get("is_forward_ref", False):
# Get the forward reference to follow.
# If this is LF_CLASS/LF_STRUCTURE, it is the UDT value.
# For LF_MODIFIER, it is the type being modified.
forward_ref = obj.get("udt", None) or obj.get("modifies", None)
if forward_ref is None:
raise CvdumpIntegrityError(f"Null forward ref for type {type_key}")
return self.get(forward_ref)
# Else it is not a forward reference, so build out the object here.
if obj.get("type") == "LF_ARRAY":
members = self._mock_array_members(obj)
else:
members = self._get_field_list(obj)
return TypeInfo(
key=type_key,
size=obj.get("size"),
name=obj.get("name"),
members=members,
)
def get_by_name(self, name: str) -> TypeInfo:
"""Find the complex type with the given name."""
# TODO
raise NotImplementedError
def get_scalars(self, type_key: str) -> List[ScalarType]:
"""Reduce the given type to a list of scalars so we can
compare each component value."""
obj = self.get(type_key)
if obj.is_scalar():
# Use obj.key here for alias types like LF_POINTER
return [ScalarType(offset=0, type=obj.key, name=None)]
# Reassure mypy: non-scalar types always have a members list.
assert obj.members is not None
# Dedupe repeated offsets if this is a union type
unique_offsets = {m.offset: m for m in obj.members}
unique_members = [m for _, m in unique_offsets.items()]
return [
ScalarType(
offset=m.offset + cm.offset,
type=cm.type,
name=join_member_names(m.name, cm.name),
)
for m in unique_members
for cm in self.get_scalars(m.type)
]
def get_scalars_gapless(self, type_key: str) -> List[ScalarType]:
"""Reduce the given type to a list of scalars so we can
compare each component value."""
obj = self.get(type_key)
total_size = obj.size
assert (
total_size is not None
), "Called get_scalar_gapless() on a type without size"
scalars = self.get_scalars(type_key)
output = []
last_extent = total_size
# Walk the scalar list in reverse; we assume a gap could not
# come at the start of the struct.
for scalar in scalars[::-1]:
this_extent = scalar.offset + scalar_type_size(scalar.type)
size_diff = last_extent - this_extent
# We need to add the gap fillers in reverse here
for i in range(size_diff - 1, -1, -1):
# Push to front
output.insert(
0,
ScalarType(
offset=this_extent + i,
name="(padding)",
type="T_UCHAR",
),
)
output.insert(0, scalar)
last_extent = scalar.offset
return output
def get_format_string(self, type_key: str) -> str:
members = self.get_scalars_gapless(type_key)
return member_list_to_struct_string(members)
def read_line(self, line: str):
if line.endswith("\n"):
line = line[:-1]
if len(line) == 0:
return
if (match := self.INDEX_RE.match(line)) is not None:
type_ = match.group(2)
if type_ not in self.MODES_OF_INTEREST:
self.mode = None
return
# Don't need to normalize, it's already in the format we want
self.last_key = match.group(1)
self.mode = type_
self._new_type()
if type_ == "LF_ARGLIST":
submatch = self.LF_ARGLIST_ARGCOUNT.match(line)
assert submatch is not None
self.keys[self.last_key]["argcount"] = int(submatch.group("argcount"))
# TODO: This should be validated in another pass
return
if self.mode is None:
return
if self.mode == "LF_MODIFIER":
if (match := self.MODIFIES_RE.match(line)) is not None:
# For convenience, because this is essentially the same thing
# as an LF_CLASS forward ref.
self._set("is_forward_ref", True)
self._set("modifies", normalize_type_id(match.group("type")))
elif self.mode == "LF_ARRAY":
if (match := self.ARRAY_ELEMENT_RE.match(line)) is not None:
self._set("array_type", normalize_type_id(match.group("type")))
elif (match := self.ARRAY_LENGTH_RE.match(line)) is not None:
self._set("size", int(match.group("length")))
elif self.mode == "LF_FIELDLIST":
self.read_fieldlist_line(line)
elif self.mode == "LF_ARGLIST":
self.read_arglist_line(line)
elif self.mode in ["LF_MFUNCTION", "LF_PROCEDURE"]:
self.read_mfunction_line(line)
elif self.mode in ["LF_CLASS", "LF_STRUCTURE"]:
self.read_class_or_struct_line(line)
elif self.mode == "LF_POINTER":
self.read_pointer_line(line)
elif self.mode == "LF_ENUM":
self.read_enum_line(line)
elif self.mode == "LF_UNION":
self.read_union_line(line)
else:
# Check for exhaustiveness
logger.error("Unhandled data in mode: %s", self.mode)
def read_fieldlist_line(self, line: str):
# If this class has a vtable, create a mock member at offset 0
if (match := self.VTABLE_RE.match(line)) is not None:
# For our purposes, any pointer type will do
self._add_member(0, "T_32PVOID")
self._set_member_name("vftable")
# Superclass is set here in the fieldlist rather than in LF_CLASS
elif (match := self.SUPERCLASS_RE.match(line)) is not None:
superclass_list: dict[str, int] = self.keys[self.last_key].setdefault(
"super", {}
)
superclass_list[normalize_type_id(match.group("type"))] = int(
match.group("offset")
)
# virtual base class (direct or indirect)
elif (match := self.VBCLASS_RE.match(line)) is not None:
virtual_base_pointer = self.keys[self.last_key].setdefault(
"vbase",
VirtualBasePointer(
vboffset=-1, # default to -1 until we parse the correct value
bases=[],
),
)
assert isinstance(
virtual_base_pointer, VirtualBasePointer
) # type checker only
virtual_base_pointer.bases.append(
VirtualBaseClass(
type=match.group("type"),
index=-1, # default to -1 until we parse the correct value
direct=match.group("indirect") != "I",
)
)
elif (match := self.VBCLASS_LINE_2_RE.match(line)) is not None:
virtual_base_pointer = self.keys[self.last_key].get("vbase", None)
assert isinstance(
virtual_base_pointer, VirtualBasePointer
), "Parsed the second line of an (I)VBCLASS without the first one"
vboffset = int(match.group("vboffset"))
if virtual_base_pointer.vboffset == -1:
# default value
virtual_base_pointer.vboffset = vboffset
elif virtual_base_pointer.vboffset != vboffset:
# vboffset is always equal to 4 in our examples. We are not sure if there can be multiple
# virtual base pointers, and if so, how the layout is supposed to look.
# We therefore assume that there is always only one virtual base pointer.
logger.error(
"Unhandled: Found multiple virtual base pointers at offsets %d and %d",
virtual_base_pointer.vboffset,
vboffset,
)
virtual_base_pointer.bases[-1].index = int(match.group("vbindex"))
# these come out of order, and the lists are so short that it's fine to sort them every time
virtual_base_pointer.bases.sort(key=lambda x: x.index)
# Member offset and type given on the first of two lines.
elif (match := self.LIST_RE.match(line)) is not None:
self._add_member(
int(match.group("offset")), normalize_type_id(match.group("type"))
)
# Name of the member read on the second of two lines.
elif (match := self.MEMBER_RE.match(line)) is not None:
self._set_member_name(match.group("name"))
elif (match := self.LF_FIELDLIST_ENUMERATE.match(line)) is not None:
self._add_variant(match.group("name"), int(match.group("value")))
def read_class_or_struct_line(self, line: str):
# Match the reference to the associated LF_FIELDLIST
if (match := self.CLASS_FIELD_RE.match(line)) is not None:
if match.group("field_type") == "0x0000":
# Not redundant. UDT might not match the key.
# These cases get reported as UDT mismatch.
self._set("is_forward_ref", True)
else:
field_list_type = normalize_type_id(match.group("field_type"))
self._set("field_list_type", field_list_type)
elif line.lstrip().startswith("Derivation list type"):
# We do not care about the second line, but we still match it so we see an error
# when another line fails to match
pass
elif (match := self.CLASS_NAME_RE.match(line)) is not None:
# Last line has the vital information.
# If this is a FORWARD REF, we need to follow the UDT pointer
# to get the actual class details.
self._set("name", match.group("name"))
udt = match.group("udt")
if udt is not None:
self._set("udt", normalize_type_id(udt))
self._set("size", int(match.group("size")))
else:
logger.error("Unmatched line in class: %s", line[:-1])
def read_arglist_line(self, line: str):
if (match := self.LF_ARGLIST_ENTRY.match(line)) is not None:
obj = self.keys[self.last_key]
arglist: list = obj.setdefault("args", [])
assert int(match.group("index")) == len(
arglist
), "Argument list out of sync"
arglist.append(match.group("arg_type"))
else:
logger.error("Unmatched line in arglist: %s", line[:-1])
def read_pointer_line(self, line: str):
if (match := self.LF_POINTER_ELEMENT.match(line)) is not None:
self._set("element_type", match.group("element_type"))
else:
stripped_line = line.strip()
# We don't parse these lines, but we still want to check for exhaustiveness
# in case we missed some relevant data
if not any(
stripped_line.startswith(prefix)
for prefix in ["Pointer", "const Pointer", "L-value", "volatile"]
):
logger.error("Unrecognized pointer attribute: %s", line[:-1])
def read_mfunction_line(self, line: str):
"""
The layout is not consistent, so we want to be as robust as possible here.
- Example 1:
Return type = T_LONG(0012), Call type = C Near
Func attr = none
- Example 2:
Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B,
Call type = ThisCall, Func attr = none
"""
obj = self.keys[self.last_key]
key_value_pairs = line.split(",")
for pair in key_value_pairs:
if pair.isspace():
continue
obj |= self.parse_function_attribute(pair)
def parse_function_attribute(self, pair: str) -> dict[str, str]:
for attribute_regex in self.LF_MFUNCTION_ATTRIBUTES:
if (match := attribute_regex.match(pair)) is not None:
return match.groupdict()
logger.error("Unknown attribute in function: %s", pair)
return {}
def read_enum_line(self, line: str):
obj = self.keys[self.last_key]
# We need special comma handling because commas may appear in the name.
# Splitting by "," yields the wrong result.
enum_attributes = line.split(", ")
for pair in enum_attributes:
if pair.endswith(","):
pair = pair[:-1]
if pair.isspace():
continue
obj |= self.parse_enum_attribute(pair)
def parse_enum_attribute(self, attribute: str) -> dict[str, Any]:
for attribute_regex in self.LF_ENUM_ATTRIBUTES:
if (match := attribute_regex.match(attribute)) is not None:
return match.groupdict()
if attribute == "NESTED":
return {"is_nested": True}
if attribute == "FORWARD REF":
return {"is_forward_ref": True}
if attribute.startswith("UDT"):
match = self.LF_ENUM_UDT.match(attribute)
assert match is not None
return {"udt": normalize_type_id(match.group("udt"))}
if (match := self.LF_ENUM_TYPES.match(attribute)) is not None:
result = match.groupdict()
result["underlying_type"] = normalize_type_id(result["underlying_type"])
return result
logger.error("Unknown attribute in enum: %s", attribute)
return {}
def read_union_line(self, line: str):
"""This is a rather barebones handler, only parsing the size"""
if (match := self.LF_UNION_LINE.match(line)) is None:
raise AssertionError(f"Unhandled in union: {line}")
self._set("name", match.group("name"))
if match.group("field_type") == "0x0000":
self._set("is_forward_ref", True)
self._set("size", int(match.group("size")))
if match.group("udt") is not None:
self._set("udt", normalize_type_id(match.group("udt")))

View file

@ -1,103 +0,0 @@
import os
import subprocess
import sys
import pathlib
from typing import Iterator
def winepath_win_to_unix(path: str) -> str:
return subprocess.check_output(["winepath", path], text=True).strip()
def winepath_unix_to_win(path: str) -> str:
return subprocess.check_output(["winepath", "-w", path], text=True).strip()
class PathResolver:
"""Intended to resolve Windows/Wine paths used in the PDB (cvdump) output
into a "canonical" format to be matched against code file paths from os.walk.
MSVC may include files from the parent dir using `..`. We eliminate those and create
an absolute path so that information about the same file under different names
will be combined into the same record. (i.e. line_no/addr pairs from LINES section.)
"""
def __init__(self, basedir) -> None:
"""basedir is the root path of the code directory in the format for your OS.
We will convert it to a PureWindowsPath to be platform-independent
and match that to the paths from the PDB."""
# Memoize the converted paths. We will need to do this for each path
# in the PDB, for each function in that file. (i.e. lots of repeated work)
self._memo = {}
# Convert basedir to an absolute path if it is not already.
# If it is not absolute, we cannot do the path swap on unix.
self._realdir = pathlib.Path(basedir).resolve()
self._is_unix = os.name != "nt"
if self._is_unix:
self._basedir = pathlib.PureWindowsPath(
winepath_unix_to_win(str(self._realdir))
)
else:
self._basedir = self._realdir
def _memo_wrapper(self, path_str: str) -> str:
"""Wrapper so we can memoize from the public caller method"""
path = pathlib.PureWindowsPath(path_str)
if not path.is_absolute():
# pathlib syntactic sugar for path concat
path = self._basedir / path
if self._is_unix:
# If the given path is relative to the basedir, deconstruct the path
# and swap in our unix path to avoid an expensive call to winepath.
try:
# Will raise ValueError if we are not relative to the base.
section = path.relative_to(self._basedir)
# Should combine to pathlib.PosixPath
mockpath = (self._realdir / section).resolve()
if mockpath.is_file():
return str(mockpath)
except ValueError:
pass
# We are not relative to the basedir, or our path swap attempt
# did not point at an actual file. Either way, we are forced
# to call winepath using our original path.
return winepath_win_to_unix(str(path))
# We must be on Windows. Convert back to WindowsPath.
# The resolve() call will eliminate intermediate backdir references.
return str(pathlib.Path(path).resolve())
def resolve_cvdump(self, path_str: str) -> str:
"""path_str is in Windows/Wine path format.
We will return a path in the format for the host OS."""
if path_str not in self._memo:
self._memo[path_str] = self._memo_wrapper(path_str)
return self._memo[path_str]
def is_file_cpp(filename: str) -> bool:
(_, ext) = os.path.splitext(filename)
return ext.lower() in (".h", ".cpp")
def walk_source_dir(source: str, recursive: bool = True) -> Iterator[str]:
"""Generator to walk the given directory recursively and return
any C++ files found."""
source = os.path.abspath(source)
for subdir, _, files in os.walk(source):
for file in files:
if is_file_cpp(file):
yield os.path.join(subdir, file)
if not recursive:
break
def get_file_in_script_dir(fn):
return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn)
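# Illustrative sketch (paths are made up; PathResolver shells out to the
# `winepath` tool on non-Windows hosts):
resolver = PathResolver("/home/user/isle")
print(resolver.resolve_cvdump(r"Z:\home\user\isle\LEGO1\main.cpp"))
# -> /home/user/isle/LEGO1/main.cpp (when the path swap finds a real file)
for path in walk_source_dir("LEGO1", recursive=True):
    print(path)  # every .h/.cpp file under LEGO1/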

View file

@ -1,13 +0,0 @@
"""Provides a reference point for redistributed tools found in this directory.
This allows you to get the path for these tools from a script run anywhere."""
from os.path import join, dirname
def lib_path() -> str:
"""Returns the directory for this module."""
return dirname(__file__)
def lib_path_join(name: str) -> str:
"""Convenience wrapper for os.path.join."""
return join(lib_path(), name)

View file

@ -1,3 +0,0 @@
from .codebase import DecompCodebase
from .parser import DecompParser
from .linter import DecompLinter

View file

@ -1,57 +0,0 @@
"""For aggregating decomp markers read from an entire directory and for a single module."""
from typing import Callable, Iterable, Iterator, List
from .parser import DecompParser
from .node import (
ParserSymbol,
ParserFunction,
ParserVtable,
ParserVariable,
ParserString,
)
class DecompCodebase:
def __init__(self, filenames: Iterable[str], module: str) -> None:
self._symbols: List[ParserSymbol] = []
parser = DecompParser()
for filename in filenames:
parser.reset()
with open(filename, "r", encoding="utf-8") as f:
parser.read_lines(f)
for sym in parser.iter_symbols(module):
sym.filename = filename
self._symbols.append(sym)
def prune_invalid_addrs(self, is_valid: Callable[[int], bool]) -> List[ParserSymbol]:
"""Some decomp annotations might have an invalid address.
Return the list of addresses where we fail the is_valid check,
and remove those from our list of symbols."""
invalid_symbols = [sym for sym in self._symbols if not is_valid(sym.offset)]
self._symbols = [sym for sym in self._symbols if is_valid(sym.offset)]
return invalid_symbols
def iter_line_functions(self) -> Iterator[ParserFunction]:
"""Return lineref functions separately from nameref. Assuming the PDB matches
the state of the source code, a line reference is a guaranteed match, even if
multiple functions share the same name. (i.e. polymorphism)"""
return filter(
lambda s: isinstance(s, ParserFunction) and not s.is_nameref(),
self._symbols,
)
def iter_name_functions(self) -> Iterator[ParserFunction]:
return filter(
lambda s: isinstance(s, ParserFunction) and s.is_nameref(), self._symbols
)
def iter_vtables(self) -> Iterator[ParserVtable]:
return filter(lambda s: isinstance(s, ParserVtable), self._symbols)
def iter_variables(self) -> Iterator[ParserVariable]:
return filter(lambda s: isinstance(s, ParserVariable), self._symbols)
def iter_strings(self) -> Iterator[ParserString]:
return filter(lambda s: isinstance(s, ParserString), self._symbols)

View file

@ -1,97 +0,0 @@
from enum import Enum
from typing import Optional
from dataclasses import dataclass
# TODO: poorly chosen name, should be AlertType or AlertCode or something
class ParserError(Enum):
# WARN: Stub function exceeds some line number threshold
UNLIKELY_STUB = 100
# WARN: Decomp marker is close enough to be recognized, but does not follow syntax exactly
BAD_DECOMP_MARKER = 101
# WARN: Multiple markers in sequence do not have distinct modules
DUPLICATE_MODULE = 102
# WARN: Detected a duplicate module/offset pair in the current file
DUPLICATE_OFFSET = 103
# WARN: We read a line that matches the decomp marker pattern, but we are not set up
# to handle it
BOGUS_MARKER = 104
# WARN: New function marker appeared while we were inside a function
MISSED_END_OF_FUNCTION = 105
# WARN: We found a curly brace right after the function declaration.
# This is wrong, but we still have enough to make a match with reccmp
MISSED_START_OF_FUNCTION = 106
# WARN: A blank line appeared between the end of FUNCTION markers
# and the start of the function. We can ignore it, but the line shouldn't be there
UNEXPECTED_BLANK_LINE = 107
# WARN: We called the finish() method for the parser but had not reached the starting
# state of SEARCH
UNEXPECTED_END_OF_FILE = 108
# WARN: We found a marker to be referenced by name outside of a header file.
BYNAME_FUNCTION_IN_CPP = 109
# WARN: A GLOBAL marker appeared over a variable without the g_ prefix
GLOBAL_MISSING_PREFIX = 110
# WARN: GLOBAL marker points at something other than variable declaration.
# We can't match global variables based on position, but the goal here is
# to ignore things like string literal that are not variables.
GLOBAL_NOT_VARIABLE = 111
# WARN: A marked static variable inside a function needs to have its
# function marked too, and in the same module.
ORPHANED_STATIC_VARIABLE = 112
# This code or higher is an error, not a warning
DECOMP_ERROR_START = 200
# ERROR: We found a marker unexpectedly
UNEXPECTED_MARKER = 200
# ERROR: We found a marker where we expected to find one, but it is incompatible
# with the preceding markers.
# For example, a GLOBAL cannot follow FUNCTION/STUB
INCOMPATIBLE_MARKER = 201
# ERROR: The line following an explicit by-name marker was not a comment
# We assume a syntax error here rather than try to use the next line
BAD_NAMEREF = 202
# ERROR: This function offset comes before the previous offset from the same module
# This hopefully gives some hint about which functions need to be rearranged.
FUNCTION_OUT_OF_ORDER = 203
# ERROR: The line following an explicit by-name marker that does _not_ expect
# a comment -- i.e. VTABLE or GLOBAL -- could not extract the name
NO_SUITABLE_NAME = 204
# ERROR: Two STRING markers have the same module and offset, but the strings
# they annotate are different.
WRONG_STRING = 205
# ERROR: This lineref FUNCTION marker is next to a function declaration or
# forward reference. The correct place for the marker is where the function
# is implemented so we can match with the PDB.
NO_IMPLEMENTATION = 206
@dataclass
class ParserAlert:
code: ParserError
line_number: int
line: Optional[str] = None
def is_warning(self) -> bool:
return self.code.value < ParserError.DECOMP_ERROR_START.value
def is_error(self) -> bool:
return self.code.value >= ParserError.DECOMP_ERROR_START.value
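# Illustrative sketch: codes below DECOMP_ERROR_START are warnings.
alert = ParserAlert(code=ParserError.GLOBAL_MISSING_PREFIX, line_number=42)
print(alert.is_warning(), alert.is_error())  # True False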

View file

@ -1,144 +0,0 @@
from typing import List, Optional
from .parser import DecompParser
from .error import ParserAlert, ParserError
from .node import ParserSymbol, ParserString
def get_checkorder_filter(module):
"""Return a filter function on implemented functions in the given module"""
return lambda fun: fun.module == module and not fun.lookup_by_name
class DecompLinter:
def __init__(self) -> None:
self.alerts: List[ParserAlert] = []
self._parser = DecompParser()
self._filename: str = ""
self._module: Optional[str] = None
# Set of (str, int) tuples for each module/offset pair seen while scanning.
# This is _not_ reset between files and is intended to report offset reuse
# when scanning the entire directory.
self._offsets_used = set()
# Keep track of strings we have seen. Persists across files.
# Module/offset can be repeated for string markers but the strings must match.
self._strings = {}
def reset(self, full_reset: bool = False):
self.alerts = []
self._parser.reset()
self._filename = ""
self._module = None
if full_reset:
self._offsets_used.clear()
self._strings = {}
def file_is_header(self):
return self._filename.lower().endswith(".h")
def _load_offsets_from_list(self, marker_list: List[ParserSymbol]):
"""Helper for loading (module, offset) tuples while the DecompParser
has them broken up into separate lists."""
for marker in marker_list:
is_string = isinstance(marker, ParserString)
value = (marker.module, marker.offset)
if value in self._offsets_used:
if is_string:
if self._strings[value] != marker.name:
self.alerts.append(
ParserAlert(
code=ParserError.WRONG_STRING,
line_number=marker.line_number,
line=f"0x{marker.offset:08x}, {repr(self._strings[value])} vs. {repr(marker.name)}",
)
)
else:
self.alerts.append(
ParserAlert(
code=ParserError.DUPLICATE_OFFSET,
line_number=marker.line_number,
line=f"0x{marker.offset:08x}",
)
)
else:
self._offsets_used.add(value)
if is_string:
self._strings[value] = marker.name
def _check_function_order(self):
"""Rules:
1. Only markers that are implemented in the file are considered. This means we
only look at markers that are cross-referenced with cvdump output by their line
number. Markers with the lookup_by_name flag set are ignored because we cannot
directly influence their order.
2. Order should be considered for a single module only. If we have multiple
markers for a single function (i.e. for LEGO1 functions linked statically to
ISLE) then the virtual address space will be very different. If we don't check
for one module only, we would incorrectly report that the file is out of order.
"""
if self._module is None:
return
checkorder_filter = get_checkorder_filter(self._module)
last_offset = None
for fun in filter(checkorder_filter, self._parser.functions):
if last_offset is not None:
if fun.offset < last_offset:
self.alerts.append(
ParserAlert(
code=ParserError.FUNCTION_OUT_OF_ORDER,
line_number=fun.line_number,
)
)
last_offset = fun.offset
def _check_offset_uniqueness(self):
self._load_offsets_from_list(self._parser.functions)
self._load_offsets_from_list(self._parser.vtables)
self._load_offsets_from_list(self._parser.variables)
self._load_offsets_from_list(self._parser.strings)
def _check_byname_allowed(self):
if self.file_is_header():
return
for fun in self._parser.functions:
if fun.lookup_by_name:
self.alerts.append(
ParserAlert(
code=ParserError.BYNAME_FUNCTION_IN_CPP,
line_number=fun.line_number,
)
)
def check_lines(self, lines, filename, module=None):
"""`lines` is a generic iterable to allow for testing with a list of strings.
We assume `lines` contains the entire contents of the compilation unit."""
self.reset(False)
self._filename = filename
self._module = module
self._parser.read_lines(lines)
self._parser.finish()
self.alerts = self._parser.alerts[::]
self._check_offset_uniqueness()
if self._module is not None:
self._check_byname_allowed()
if not self.file_is_header():
self._check_function_order()
return len(self.alerts) == 0
def check_file(self, filename, module=None):
"""Convenience method for decomplint cli tool"""
with open(filename, "r", encoding="utf-8") as f:
return self.check_lines(f, filename, module)
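A minimal usage sketch for the linter (the file path and module are hypothetical, and the exact import path of DecompLinter is not shown in this diff):

linter = DecompLinter()
if not linter.check_file("LEGO1/mxcore.cpp", module="LEGO1"):
    for alert in linter.alerts:
        # ParserAlert fields: code, line_number, and (optionally) the offending line
        print(alert.code.name, alert.line_number, alert.line)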

View file

@ -1,146 +0,0 @@
import re
from typing import Optional, Tuple
from enum import Enum
class MarkerCategory(Enum):
"""For the purposes of grouping multiple different DecompMarkers together,
assign a rough "category" for the MarkerType values below.
It's really only the function types that have to get folded down, but
we'll do that in a structured way to permit future expansion."""
FUNCTION = 1
VARIABLE = 2
STRING = 3
VTABLE = 4
ADDRESS = 100 # i.e. no comparison required or possible
class MarkerType(Enum):
UNKNOWN = -100
FUNCTION = 1
STUB = 2
SYNTHETIC = 3
TEMPLATE = 4
GLOBAL = 5
VTABLE = 6
STRING = 7
LIBRARY = 8
markerRegex = re.compile(
r"\s*//\s*(?P<type>\w+):\s*(?P<module>\w+)\s+(?P<offset>0x[a-f0-9]+) *(?P<extra>\S.+\S)?",
flags=re.I,
)
markerExactRegex = re.compile(
r"\s*// (?P<type>[A-Z]+): (?P<module>[A-Z0-9]+) (?P<offset>0x[a-f0-9]+)(?: (?P<extra>\S.+\S))?\n?$"
)
class DecompMarker:
def __init__(
self, marker_type: str, module: str, offset: int, extra: Optional[str] = None
) -> None:
try:
self._type = MarkerType[marker_type.upper()]
except KeyError:
self._type = MarkerType.UNKNOWN
# Convert to upper here. A lot of other analysis depends on this name
# being consistent and predictable. If the name is _not_ capitalized
# we will emit a syntax error.
self._module: str = module.upper()
self._offset: int = offset
self._extra: Optional[str] = extra
@property
def type(self) -> MarkerType:
return self._type
@property
def module(self) -> str:
return self._module
@property
def offset(self) -> int:
return self._offset
@property
def extra(self) -> Optional[str]:
return self._extra
@property
def category(self) -> MarkerCategory:
if self.is_vtable():
return MarkerCategory.VTABLE
if self.is_variable():
return MarkerCategory.VARIABLE
if self.is_string():
return MarkerCategory.STRING
# TODO: worth another look if we add more types, but this covers it
if self.is_regular_function() or self.is_explicit_byname():
return MarkerCategory.FUNCTION
return MarkerCategory.ADDRESS
@property
def key(self) -> Tuple[MarkerCategory, str, Optional[str]]:
"""For use with the MarkerDict. To detect/avoid marker collision."""
return (self.category, self.module, self.extra)
def is_regular_function(self) -> bool:
"""Regular function, meaning: not an explicit byname lookup. FUNCTION
markers can be _implicit_ byname.
FUNCTION and STUB markers are (currently) the only heterogeneous marker types that
can be lumped together, although the reasons for doing so are a little vague."""
return self._type in (MarkerType.FUNCTION, MarkerType.STUB)
def is_explicit_byname(self) -> bool:
return self._type in (
MarkerType.SYNTHETIC,
MarkerType.TEMPLATE,
MarkerType.LIBRARY,
)
def is_variable(self) -> bool:
return self._type == MarkerType.GLOBAL
def is_synthetic(self) -> bool:
return self._type == MarkerType.SYNTHETIC
def is_template(self) -> bool:
return self._type == MarkerType.TEMPLATE
def is_vtable(self) -> bool:
return self._type == MarkerType.VTABLE
def is_library(self) -> bool:
return self._type == MarkerType.LIBRARY
def is_string(self) -> bool:
return self._type == MarkerType.STRING
def allowed_in_func(self) -> bool:
return self._type in (MarkerType.GLOBAL, MarkerType.STRING)
def match_marker(line: str) -> Optional[DecompMarker]:
match = markerRegex.match(line)
if match is None:
return None
return DecompMarker(
marker_type=match.group("type"),
module=match.group("module"),
offset=int(match.group("offset"), 16),
extra=match.group("extra"),
)
def is_marker_exact(line: str) -> bool:
return markerExactRegex.match(line) is not None
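A quick sketch of the two-regex scheme (assuming the module is importable as isledecomp.parser.marker, consistent with the relative imports above):

from isledecomp.parser.marker import match_marker, is_marker_exact, MarkerType

line = "// FUNCTION: LEGO1 0x100b12c0"
marker = match_marker(line)
assert marker is not None and marker.type is MarkerType.FUNCTION
assert marker.module == "LEGO1" and marker.offset == 0x100B12C0
assert is_marker_exact(line)

# A sloppy variant still hits the loose regex but fails the exact one,
# which is what drives the BAD_DECOMP_MARKER warning:
sloppy = "//FUNCTION: lego1 0x100b12c0"
assert match_marker(sloppy) is not None
assert not is_marker_exact(sloppy)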

View file

@ -1,63 +0,0 @@
from typing import Optional
from dataclasses import dataclass
from .marker import MarkerType
@dataclass
class ParserSymbol:
"""Exported decomp marker with all information (except the code filename) required to
cross-reference with cvdump data."""
type: MarkerType
line_number: int
module: str
offset: int
name: str
# The parser doesn't (currently) know about the code filename, but if you
# wanted to set it here after the fact, here's the spot.
filename: Optional[str] = None
def should_skip(self) -> bool:
"""The default is to compare any symbols we have"""
return False
def is_nameref(self) -> bool:
"""All symbols default to name lookup"""
return True
@dataclass
class ParserFunction(ParserSymbol):
# We are able to detect the closing line of a function with some reliability.
# This isn't used for anything right now, but perhaps later it will be.
end_line: Optional[int] = None
# All marker types are referenced by name except FUNCTION/STUB. These can also be
# referenced by name, but only if this flag is true.
lookup_by_name: bool = False
def should_skip(self) -> bool:
return self.type == MarkerType.STUB
def is_nameref(self) -> bool:
return (
self.type in (MarkerType.SYNTHETIC, MarkerType.TEMPLATE, MarkerType.LIBRARY)
or self.lookup_by_name
)
@dataclass
class ParserVariable(ParserSymbol):
is_static: bool = False
parent_function: Optional[int] = None
@dataclass
class ParserVtable(ParserSymbol):
base_class: Optional[str] = None
@dataclass
class ParserString(ParserSymbol):
pass
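A short sketch of how these dataclasses behave (imports assumed from the same package):

from isledecomp.parser.node import ParserFunction
from isledecomp.parser.marker import MarkerType

fun = ParserFunction(
    type=MarkerType.STUB,
    line_number=10,
    module="LEGO1",
    offset=0x100B12C0,
    name="MxCore::Tickle",
)
assert fun.should_skip()     # STUB bodies are not compared
assert not fun.is_nameref()  # matched by line number, not by name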

View file

@ -1,556 +0,0 @@
# C++ file parser
from typing import List, Iterable, Iterator, Optional
from enum import Enum
from .util import (
get_class_name,
get_variable_name,
get_synthetic_name,
remove_trailing_comment,
get_string_contents,
sanitize_code_line,
scopeDetectRegex,
)
from .marker import (
DecompMarker,
MarkerCategory,
match_marker,
is_marker_exact,
)
from .node import (
ParserSymbol,
ParserFunction,
ParserVariable,
ParserVtable,
ParserString,
)
from .error import ParserAlert, ParserError
class ReaderState(Enum):
SEARCH = 0
WANT_SIG = 1
IN_FUNC = 2
IN_TEMPLATE = 3
WANT_CURLY = 4
IN_GLOBAL = 5
IN_FUNC_GLOBAL = 6
IN_VTABLE = 7
IN_SYNTHETIC = 8
IN_LIBRARY = 9
DONE = 100
class MarkerDict:
def __init__(self) -> None:
self.markers: dict = {}
def insert(self, marker: DecompMarker) -> bool:
"""Return True if this insert would overwrite"""
if marker.key in self.markers:
return True
self.markers[marker.key] = marker
return False
def query(
self, category: MarkerCategory, module: str, extra: Optional[str] = None
) -> Optional[DecompMarker]:
return self.markers.get((category, module, extra))
def iter(self) -> Iterator[DecompMarker]:
for _, marker in self.markers.items():
yield marker
def empty(self):
self.markers = {}
class CurlyManager:
"""Overly simplified scope manager"""
def __init__(self):
self._stack = []
def reset(self):
self._stack = []
def _pop(self):
"""Pop stack safely"""
try:
self._stack.pop()
except IndexError:
pass
def get_prefix(self, name: Optional[str] = None) -> str:
"""Return the prefix for where we are."""
scopes = [t for t in self._stack if t != "{"]
if len(scopes) == 0:
return name if name is not None else ""
if name is not None and name not in scopes:
scopes.append(name)
return "::".join(scopes)
def read_line(self, raw_line: str):
"""Read a line of code and update the stack."""
line = sanitize_code_line(raw_line)
if (match := scopeDetectRegex.match(line)) is not None:
if not line.endswith(";"):
self._stack.append(match.group("name"))
change = line.count("{") - line.count("}")
if change > 0:
for _ in range(change):
self._stack.append("{")
elif change < 0:
for _ in range(-change):
self._pop()
if len(self._stack) == 0:
return
last = self._stack[-1]
if last != "{":
self._pop()
class DecompParser:
# pylint: disable=too-many-instance-attributes
# Could combine output lists into a single list to get under the limit,
# but not right now
def __init__(self) -> None:
# The lists to be populated as we parse
self._symbols: List[ParserSymbol] = []
self.alerts: List[ParserAlert] = []
self.line_number: int = 0
self.state: ReaderState = ReaderState.SEARCH
self.last_line: str = ""
self.curly = CurlyManager()
# To allow for multiple markers where code is shared across different
# modules, save lists of compatible markers that appear in sequence
self.fun_markers = MarkerDict()
self.var_markers = MarkerDict()
self.tbl_markers = MarkerDict()
# To handle functions that are entirely indented (i.e. those defined
# in class declarations), remember how many whitespace characters
# came before the opening curly brace and match that up at the end.
# This should give us the same or better accuracy for a well-formed file.
# The alternative is counting the curly braces on each line
# but that's probably too cumbersome.
self.curly_indent_stops: int = 0
# For non-synthetic functions, save the line number where the function begins
# (i.e. where we see the curly brace) along with the function signature.
# We will need both when we reach the end of the function.
self.function_start: int = 0
self.function_sig: str = ""
def reset(self):
self._symbols = []
self.alerts = []
self.line_number = 0
self.state = ReaderState.SEARCH
self.last_line = ""
self.fun_markers.empty()
self.var_markers.empty()
self.tbl_markers.empty()
self.curly_indent_stops = 0
self.function_start = 0
self.function_sig = ""
self.curly.reset()
@property
def functions(self) -> List[ParserFunction]:
return [s for s in self._symbols if isinstance(s, ParserFunction)]
@property
def vtables(self) -> List[ParserVtable]:
return [s for s in self._symbols if isinstance(s, ParserVtable)]
@property
def variables(self) -> List[ParserVariable]:
return [s for s in self._symbols if isinstance(s, ParserVariable)]
@property
def strings(self) -> List[ParserString]:
return [s for s in self._symbols if isinstance(s, ParserString)]
def iter_symbols(self, module: Optional[str] = None) -> Iterator[ParserSymbol]:
for s in self._symbols:
if module is None or s.module == module:
yield s
def _recover(self):
"""We hit a syntax error and need to reset temp structures"""
self.state = ReaderState.SEARCH
self.fun_markers.empty()
self.var_markers.empty()
self.tbl_markers.empty()
def _syntax_warning(self, code):
self.alerts.append(
ParserAlert(
line_number=self.line_number,
code=code,
line=self.last_line.strip(),
)
)
def _syntax_error(self, code):
self._syntax_warning(code)
self._recover()
def _function_starts_here(self):
self.function_start = self.line_number
def _function_marker(self, marker: DecompMarker):
if self.fun_markers.insert(marker):
self._syntax_warning(ParserError.DUPLICATE_MODULE)
self.state = ReaderState.WANT_SIG
def _nameref_marker(self, marker: DecompMarker):
"""Functions explicitly referenced by name are set here"""
if self.fun_markers.insert(marker):
self._syntax_warning(ParserError.DUPLICATE_MODULE)
if marker.is_template():
self.state = ReaderState.IN_TEMPLATE
elif marker.is_synthetic():
self.state = ReaderState.IN_SYNTHETIC
else:
self.state = ReaderState.IN_LIBRARY
def _function_done(self, lookup_by_name: bool = False, unexpected: bool = False):
end_line = self.line_number
if unexpected:
# If we missed the end of the previous function, assume it ended
# on the previous line and that whatever we are tracking next
# begins on the current line.
end_line -= 1
for marker in self.fun_markers.iter():
self._symbols.append(
ParserFunction(
type=marker.type,
line_number=self.function_start,
module=marker.module,
offset=marker.offset,
name=self.function_sig,
lookup_by_name=lookup_by_name,
end_line=end_line,
)
)
self.fun_markers.empty()
self.curly_indent_stops = 0
self.state = ReaderState.SEARCH
def _vtable_marker(self, marker: DecompMarker):
if self.tbl_markers.insert(marker):
self._syntax_warning(ParserError.DUPLICATE_MODULE)
self.state = ReaderState.IN_VTABLE
def _vtable_done(self, class_name: Optional[str] = None):
if class_name is None:
# Best we can do
class_name = self.last_line.strip()
for marker in self.tbl_markers.iter():
self._symbols.append(
ParserVtable(
type=marker.type,
line_number=self.line_number,
module=marker.module,
offset=marker.offset,
name=self.curly.get_prefix(class_name),
base_class=marker.extra,
)
)
self.tbl_markers.empty()
self.state = ReaderState.SEARCH
def _variable_marker(self, marker: DecompMarker):
if self.var_markers.insert(marker):
self._syntax_warning(ParserError.DUPLICATE_MODULE)
if self.state in (ReaderState.IN_FUNC, ReaderState.IN_FUNC_GLOBAL):
self.state = ReaderState.IN_FUNC_GLOBAL
else:
self.state = ReaderState.IN_GLOBAL
def _variable_done(
self, variable_name: Optional[str] = None, string_value: Optional[str] = None
):
if variable_name is None and string_value is None:
self._syntax_error(ParserError.NO_SUITABLE_NAME)
return
for marker in self.var_markers.iter():
if marker.is_string():
self._symbols.append(
ParserString(
type=marker.type,
line_number=self.line_number,
module=marker.module,
offset=marker.offset,
name=string_value,
)
)
else:
parent_function = None
is_static = self.state == ReaderState.IN_FUNC_GLOBAL
# If this is a static variable, we need to get the function
# where it resides so that we can match it up later with the
# mangled names of both variable and function from cvdump.
if is_static:
fun_marker = self.fun_markers.query(
MarkerCategory.FUNCTION, marker.module
)
if fun_marker is None:
self._syntax_warning(ParserError.ORPHANED_STATIC_VARIABLE)
continue
parent_function = fun_marker.offset
self._symbols.append(
ParserVariable(
type=marker.type,
line_number=self.line_number,
module=marker.module,
offset=marker.offset,
name=self.curly.get_prefix(variable_name),
is_static=is_static,
parent_function=parent_function,
)
)
self.var_markers.empty()
if self.state == ReaderState.IN_FUNC_GLOBAL:
self.state = ReaderState.IN_FUNC
else:
self.state = ReaderState.SEARCH
def _handle_marker(self, marker: DecompMarker):
# Cannot handle any markers between function sig and opening curly brace
if self.state == ReaderState.WANT_CURLY:
self._syntax_error(ParserError.UNEXPECTED_MARKER)
return
# If we are inside a function, the only markers we accept are:
# GLOBAL, indicating a static variable
# STRING, indicating a literal string.
# Otherwise we assume that the parser missed the end of the function
# and we have moved on to something else.
# This is unlikely to occur with well-formed code, but
# we can recover easily by just ending the function here.
if self.state == ReaderState.IN_FUNC and not marker.allowed_in_func():
self._syntax_warning(ParserError.MISSED_END_OF_FUNCTION)
self._function_done(unexpected=True)
# TODO: How uncertain are we of detecting the end of a function
# in a clang-formatted file? For now we assume we have missed the
# end if we detect a non-GLOBAL marker while state is IN_FUNC.
# Maybe these cases should be syntax errors instead
if marker.is_regular_function():
if self.state in (
ReaderState.SEARCH,
ReaderState.WANT_SIG,
):
# We will allow multiple offsets if we have just begun
# the code block, but not after we hit the curly brace.
self._function_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
elif marker.is_template():
if self.state in (ReaderState.SEARCH, ReaderState.IN_TEMPLATE):
self._nameref_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
elif marker.is_synthetic():
if self.state in (ReaderState.SEARCH, ReaderState.IN_SYNTHETIC):
self._nameref_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
elif marker.is_library():
if self.state in (ReaderState.SEARCH, ReaderState.IN_LIBRARY):
self._nameref_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
# Strings and variables are almost the same thing
elif marker.is_string() or marker.is_variable():
if self.state in (
ReaderState.SEARCH,
ReaderState.IN_GLOBAL,
ReaderState.IN_FUNC,
ReaderState.IN_FUNC_GLOBAL,
):
self._variable_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
elif marker.is_vtable():
if self.state in (ReaderState.SEARCH, ReaderState.IN_VTABLE):
self._vtable_marker(marker)
else:
self._syntax_error(ParserError.INCOMPATIBLE_MARKER)
else:
self._syntax_warning(ParserError.BOGUS_MARKER)
def read_line(self, line: str):
if self.state == ReaderState.DONE:
return
self.last_line = line # TODO: Useful or hack for error reporting?
self.line_number += 1
marker = match_marker(line)
if marker is not None:
# TODO: what's the best place for this?
# Does it belong with reading or marker handling?
if not is_marker_exact(self.last_line):
self._syntax_warning(ParserError.BAD_DECOMP_MARKER)
self._handle_marker(marker)
return
self.curly.read_line(line)
line_strip = line.strip()
if self.state in (
ReaderState.IN_SYNTHETIC,
ReaderState.IN_TEMPLATE,
ReaderState.IN_LIBRARY,
):
# Explicit nameref functions provide the function name
# on the next line (in a // comment)
name = get_synthetic_name(line)
if name is None:
self._syntax_error(ParserError.BAD_NAMEREF)
else:
self.function_sig = name
self._function_starts_here()
self._function_done(lookup_by_name=True)
elif self.state == ReaderState.WANT_SIG:
# Ignore blanks on the way to function start or function name
if len(line_strip) == 0:
self._syntax_warning(ParserError.UNEXPECTED_BLANK_LINE)
elif line_strip.startswith("//"):
# If we found a comment, assume implicit lookup-by-name
# function and end here. We know this is not a decomp marker
# because it would have been handled already.
self.function_sig = get_synthetic_name(line)
self._function_starts_here()
self._function_done(lookup_by_name=True)
elif line_strip == "{":
# We missed the function signature but we can recover from this
self.function_sig = "(unknown)"
self._function_starts_here()
self._syntax_warning(ParserError.MISSED_START_OF_FUNCTION)
self.state = ReaderState.IN_FUNC
else:
# Inline functions may end with a comment. Strip that out
# to help parsing.
self.function_sig = remove_trailing_comment(line_strip)
# Now check to see if the opening curly bracket is on the
# same line. clang-format should prevent this (BraceWrapping)
# but it is easy to detect.
# If the entire function is on one line, handle that too.
if self.function_sig.endswith("{"):
self._function_starts_here()
self.state = ReaderState.IN_FUNC
elif self.function_sig.endswith("}") or self.function_sig.endswith(
"};"
):
self._function_starts_here()
self._function_done()
elif self.function_sig.endswith(");"):
# Detect forward reference or declaration
self._syntax_error(ParserError.NO_IMPLEMENTATION)
else:
self.state = ReaderState.WANT_CURLY
elif self.state == ReaderState.WANT_CURLY:
if line_strip == "{":
self.curly_indent_stops = line.index("{")
self._function_starts_here()
self.state = ReaderState.IN_FUNC
elif self.state == ReaderState.IN_FUNC:
if line_strip.startswith("}") and line[self.curly_indent_stops] == "}":
self._function_done()
elif self.state in (ReaderState.IN_GLOBAL, ReaderState.IN_FUNC_GLOBAL):
# TODO: Known problem that an error here will cause us to abandon a
# function we have already parsed if state == IN_FUNC_GLOBAL.
# However, we are not tolerant of _any_ syntax problems in our
# CI actions, so the solution is to just fix the invalid marker.
variable_name = None
global_markers_queued = any(
m.is_variable() for m in self.var_markers.iter()
)
if len(line_strip) == 0:
self._syntax_warning(ParserError.UNEXPECTED_BLANK_LINE)
return
if global_markers_queued:
# Not the greatest solution, but a consequence of combining GLOBAL and
# STRING markers together. If the marker precedes a return statement, it is
# valid for a STRING marker to be here, but not a GLOBAL. We need to look
# ahead and tell whether this *would* fail.
if line_strip.startswith("return"):
self._syntax_error(ParserError.GLOBAL_NOT_VARIABLE)
return
if line_strip.startswith("//"):
# If we found a comment, take the variable name from it.
# We know this is not a decomp marker
# because it would have been handled already.
variable_name = get_synthetic_name(line)
else:
variable_name = get_variable_name(line)
string_name = get_string_contents(line)
self._variable_done(variable_name, string_name)
elif self.state == ReaderState.IN_VTABLE:
vtable_class = get_class_name(line)
if vtable_class is not None:
self._vtable_done(class_name=vtable_class)
def read_lines(self, lines: Iterable):
for line in lines:
self.read_line(line)
def finish(self):
if self.state != ReaderState.SEARCH:
self._syntax_warning(ParserError.UNEXPECTED_END_OF_FILE)
self.state = ReaderState.DONE
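To illustrate the state machine end-to-end, a minimal sketch that feeds the parser a well-formed marker block (the same shape as the test samples further below):

from isledecomp.parser.parser import DecompParser

parser = DecompParser()
parser.read_lines([
    "// FUNCTION: TEST 0x1234",
    "void function01()",
    "{",
    "}",
])
parser.finish()
assert len(parser.functions) == 1
assert parser.functions[0].offset == 0x1234
assert parser.alerts == []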

View file

@ -1,141 +0,0 @@
# C++ Parser utility functions and data structures
import re
from typing import Optional
from ast import literal_eval
# The goal here is to just read whatever is on the next line, so some
# flexibility in the formatting seems OK
templateCommentRegex = re.compile(r"\s*//\s+(.*)")
# To remove any comment (//) or block comment (/*) and its leading spaces
# from the end of a code line
trailingCommentRegex = re.compile(r"(\s*(?://|/\*).*)$")
# Get char contents, ignore escape characters
singleQuoteRegex = re.compile(r"('(?:[^\'\\]|\\.)')")
# Match contents of block comment on one line
blockCommentRegex = re.compile(r"(/\*.*?\*/)")
# Match contents of single comment on one line
regularCommentRegex = re.compile(r"(//.*)")
# Get string contents, ignore escape characters that might interfere
doubleQuoteRegex = re.compile(r"(\"(?:[^\"\\]|\\.)*\")")
# Detect a line that would cause us to enter a new scope
scopeDetectRegex = re.compile(r"(?:class|struct|namespace) (?P<name>\w+).*(?:{)?")
def get_synthetic_name(line: str) -> Optional[str]:
"""Synthetic names appear on a single line comment on the line after the marker.
If that's not what we have, return None"""
template_match = templateCommentRegex.match(line)
if template_match is not None:
return template_match.group(1)
return None
def sanitize_code_line(line: str) -> str:
"""Helper for scope manager. Removes sections from a code line
that would cause us to incorrectly detect curly brackets.
This is a very naive implementation and fails entirely on multi-line
strings or comments."""
line = singleQuoteRegex.sub("''", line)
line = doubleQuoteRegex.sub('""', line)
line = blockCommentRegex.sub("", line)
line = regularCommentRegex.sub("", line)
return line.strip()
def remove_trailing_comment(line: str) -> str:
return trailingCommentRegex.sub("", line)
def is_blank_or_comment(line: str) -> bool:
"""Helper to read ahead after the offset comment is matched.
There could be blank lines or other comments before the
function signature, and we want to skip those."""
line_strip = line.strip()
return (
len(line_strip) == 0
or line_strip.startswith("//")
or line_strip.startswith("/*")
or line_strip.endswith("*/")
)
template_regex = re.compile(r"<(?P<type>[\w]+)\s*(?P<asterisks>\*+)?\s*>")
class_decl_regex = re.compile(
r"\s*(?:\/\/)?\s*(?:class|struct) ((?:\w+(?:<.+>)?(?:::)?)+)"
)
def template_replace(match: re.Match) -> str:
(type_name, asterisks) = match.groups()
if asterisks is None:
return f"<{type_name}>"
return f"<{type_name} {asterisks}>"
def fix_template_type(class_name: str) -> str:
"""For template classes, we should reformat the class name so it matches
the output from cvdump: one space between the template type and any asterisks
if it is a pointer type."""
if "<" not in class_name:
return class_name
return template_regex.sub(template_replace, class_name)
def get_class_name(line: str) -> Optional[str]:
"""For VTABLE markers, extract the class name from the code line or comment
where it appears."""
match = class_decl_regex.match(line)
if match is not None:
return fix_template_type(match.group(1))
return None
global_regex = re.compile(r"(?P<name>(?:\w+::)*g_\w+)")
less_strict_global_regex = re.compile(r"(?P<name>(?:\w+::)*\w+)(?:\)\(|\[.*|\s*=.*|;)")
def get_variable_name(line: str) -> Optional[str]:
"""Grab the name of the variable annotated with the GLOBAL marker.
Correct syntax would have the variable start with the prefix "g_"
but we will try to match regardless."""
if (match := global_regex.search(line)) is not None:
return match.group("name")
if (match := less_strict_global_regex.search(line)) is not None:
return match.group("name")
return None
def get_string_contents(line: str) -> Optional[str]:
"""Return the first C string seen on this line.
We have to unescape the string, and a simple way to do that is to use
Python's ast.literal_eval. I'm sure there are many pitfalls to doing
it this way, but hopefully the regex will ensure reasonably sane input."""
try:
if (match := doubleQuoteRegex.search(line)) is not None:
return literal_eval(match.group(1))
# pylint: disable=broad-exception-caught
# No way to predict what kind of exception could occur.
except Exception:
pass
return None
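A few concrete examples of these helpers (imports as in the tests below):

from isledecomp.parser.util import (
    fix_template_type,
    get_string_contents,
    get_variable_name,
    sanitize_code_line,
)

assert fix_template_type("MxList<MxCore*>") == "MxList<MxCore *>"
assert get_string_contents('return "hello";') == "hello"
assert get_string_contents(r'Log("line\nbreak");') == "line\nbreak"
assert get_variable_name("int g_count = 0;") == "g_count"
assert sanitize_code_line('printf("{");  // {') == 'printf("");'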

View file

@ -1,13 +0,0 @@
"""Types shared by other modules"""
from enum import Enum
class SymbolType(Enum):
"""Broadly tells us what kind of comparison is required for this symbol."""
FUNCTION = 1
DATA = 2
POINTER = 3
STRING = 4
VTABLE = 5
FLOAT = 6

View file

@ -1,308 +0,0 @@
import os
import sys
from datetime import datetime
import logging
import colorama
def print_combined_diff(udiff, plain: bool = False, show_both: bool = False):
if udiff is None:
return
# We don't know how long the address string will be ahead of time.
# Set this value for each address to try to line things up.
padding_size = 0
for slug, subgroups in udiff:
if plain:
print("---")
print("+++")
print(slug)
else:
print(f"{colorama.Fore.RED}---")
print(f"{colorama.Fore.GREEN}+++")
print(f"{colorama.Fore.BLUE}{slug}")
print(colorama.Style.RESET_ALL, end="")
for subgroup in subgroups:
equal = subgroup.get("both") is not None
if equal:
for orig_addr, line, recomp_addr in subgroup["both"]:
padding_size = max(padding_size, len(orig_addr))
if show_both:
print(f"{orig_addr} / {recomp_addr} : {line}")
else:
print(f"{orig_addr} : {line}")
else:
for orig_addr, line in subgroup["orig"]:
padding_size = max(padding_size, len(orig_addr))
addr_prefix = (
f"{orig_addr} / {'':{padding_size}}" if show_both else orig_addr
)
if plain:
print(f"{addr_prefix} : -{line}")
else:
print(
f"{addr_prefix} : {colorama.Fore.RED}-{line}{colorama.Style.RESET_ALL}"
)
for recomp_addr, line in subgroup["recomp"]:
padding_size = max(padding_size, len(recomp_addr))
addr_prefix = (
f"{'':{padding_size}} / {recomp_addr}"
if show_both
else " " * padding_size
)
if plain:
print(f"{addr_prefix} : +{line}")
else:
print(
f"{addr_prefix} : {colorama.Fore.GREEN}+{line}{colorama.Style.RESET_ALL}"
)
# Newline between each diff subgroup.
print()
def print_diff(udiff, plain):
"""Print diff in difflib.unified_diff format."""
if udiff is None:
return False
has_diff = False
for line in udiff:
has_diff = True
color = ""
if line.startswith("++") or line.startswith("@@") or line.startswith("--"):
# Skip unneeded parts of the diff for the brief view
continue
# Work out color if we are printing color
if not plain:
if line.startswith("+"):
color = colorama.Fore.GREEN
elif line.startswith("-"):
color = colorama.Fore.RED
print(color + line)
# Reset color if we're printing in color
if not plain:
print(colorama.Style.RESET_ALL, end="")
return has_diff
def get_percent_color(value: float) -> str:
"""Return colorama ANSI escape character for the given decimal value."""
if value == 1.0:
return colorama.Fore.GREEN
if value > 0.8:
return colorama.Fore.YELLOW
return colorama.Fore.RED
def percent_string(
ratio: float, is_effective: bool = False, is_plain: bool = False
) -> str:
"""Helper to construct a percentage string from the given ratio.
If is_effective (i.e. effective match), indicate that with the asterisk.
If is_plain, don't use colorama ANSI codes."""
percenttext = f"{(ratio * 100):.2f}%"
effective_star = "*" if is_effective else ""
if is_plain:
return percenttext + effective_star
return "".join(
[
get_percent_color(ratio),
percenttext,
colorama.Fore.RED if is_effective else "",
effective_star,
colorama.Style.RESET_ALL,
]
)
def diff_json_display(show_both_addrs: bool = False, is_plain: bool = False):
"""Generate a function that will display the diff according to
the reccmp display preferences."""
def formatter(orig_addr, saved, new) -> str:
old_pct = "new"
new_pct = "gone"
name = ""
recomp_addr = "n/a"
if new is not None:
new_pct = (
"stub"
if new.get("stub", False)
else percent_string(
new["matching"], new.get("effective", False), is_plain
)
)
# Prefer the current name of this function if we have it.
# We are using the original address as the key.
# A function being renamed is not of interest here.
name = new.get("name", "")
recomp_addr = new.get("recomp", "n/a")
if saved is not None:
old_pct = (
"stub"
if saved.get("stub", False)
else percent_string(
saved["matching"], saved.get("effective", False), is_plain
)
)
if name == "":
name = saved.get("name", "")
if show_both_addrs:
addr_string = f"{orig_addr} / {recomp_addr:10}"
else:
addr_string = orig_addr
# The ANSI codes from colorama count towards string length,
# so displaying this as an ascii-like spreadsheet
# (using f-string formatting) would take some effort.
return f"{addr_string} - {name} ({old_pct} -> {new_pct})"
return formatter
def diff_json(
saved_data,
new_data,
orig_file: str,
show_both_addrs: bool = False,
is_plain: bool = False,
):
"""Using a saved copy of the diff summary and the current data, print a
report showing which functions/symbols have changed match percentage."""
# Don't try to diff a report generated for a different binary file
base_file = os.path.basename(orig_file).lower()
if saved_data.get("file") != base_file:
logging.getLogger().error(
"Diff report for '%s' does not match current file '%s'",
saved_data.get("file"),
base_file,
)
return
if "timestamp" in saved_data:
now = datetime.now().replace(microsecond=0)
then = datetime.fromtimestamp(saved_data["timestamp"]).replace(microsecond=0)
print(
" ".join(
[
"Saved diff report generated",
then.strftime("%B %d %Y, %H:%M:%S"),
f"({str(now - then)} ago)",
]
)
)
print()
# Convert to dict, using orig_addr as key
saved_invert = {obj["address"]: obj for obj in saved_data["data"]}
new_invert = {obj["address"]: obj for obj in new_data}
all_addrs = set(saved_invert.keys()).union(new_invert.keys())
# Put all the information in one place so we can decide how each item changed.
combined = {
addr: (
saved_invert.get(addr),
new_invert.get(addr),
)
for addr in sorted(all_addrs)
}
# The criteria for diff judgement are in these dict comprehensions:
# Any function not in the saved file
new_functions = {
key: (saved, new) for key, (saved, new) in combined.items() if saved is None
}
# Any function from the saved file that is now missing,
# or a non-stub -> stub conversion
dropped_functions = {
key: (saved, new)
for key, (saved, new) in combined.items()
if new is None
or (
new is not None
and saved is not None
and new.get("stub", False)
and not saved.get("stub", False)
)
}
# TODO: move these two into functions if the assessment gets more complex
# Any function with increased match percentage
# or stub -> non-stub conversion
improved_functions = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
and (
new["matching"] > saved["matching"]
or (not new.get("stub", False) and saved.get("stub", False))
)
}
# Any non-stub function with decreased match percentage
degraded_functions = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
and new["matching"] < saved["matching"]
and not saved.get("stub")
and not new.get("stub")
}
# Any function with former or current "effective" match
entropy_functions = {
key: (saved, new)
for key, (saved, new) in combined.items()
if saved is not None
and new is not None
and new["matching"] == 1.0
and saved["matching"] == 1.0
and new.get("effective", False) != saved.get("effective", False)
}
get_diff_str = diff_json_display(show_both_addrs, is_plain)
for diff_name, diff_dict in [
("New", new_functions),
("Increased", improved_functions),
("Decreased", degraded_functions),
("Dropped", dropped_functions),
("Compiler entropy", entropy_functions),
]:
if len(diff_dict) == 0:
continue
print(f"{diff_name} ({len(diff_dict)}):")
for addr, (saved, new) in diff_dict.items():
print(get_diff_str(addr, saved, new))
print()
def get_file_in_script_dir(fn):
return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn)
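For reference, percent_string composes the color code and the effective-match asterisk; in plain mode the output is easy to pin down (sketch, module path aside):

assert percent_string(1.0, is_plain=True) == "100.00%"
assert percent_string(0.8765, is_effective=True, is_plain=True) == "87.65%*"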

View file

@ -1,11 +0,0 @@
from setuptools import setup, find_packages
setup(
name="isledecomp",
version="0.1.0",
description="Python tools for the isledecomp project",
packages=find_packages(),
tests_require=["pytest"],
include_package_data=True,
package_data={"isledecomp.lib": ["*.exe", "*.dll"]},
)

View file

@ -1,3 +0,0 @@
def pytest_addoption(parser):
"""Allow the option to run tests against the original LEGO1.DLL."""
parser.addoption("--lego1", action="store", help="Path to LEGO1.DLL")
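A hypothetical test elsewhere in the suite would consume this option like so (a sketch, not from the original tree):

import pytest

def test_against_original(request):
    lego1_path = request.config.getoption("--lego1")
    if lego1_path is None:
        pytest.skip("requires --lego1 with the path to the original LEGO1.DLL")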

View file

@ -1,30 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// A very simple class
// VTABLE: TEST 0x1001002
class TestClass {
public:
TestClass();
virtual ~TestClass() override;
virtual MxResult Tickle() override; // vtable+08
// FUNCTION: TEST 0x12345678
inline const char* ClassName() const // vtable+0c
{
// 0xabcd1234
return "TestClass";
}
// FUNCTION: TEST 0xdeadbeef
inline MxBool IsA(const char* name) const override // vtable+10
{
return !strcmp(name, TestClass::ClassName());
}
private:
int m_hello;
int m_hiThere;
};

View file

@ -1,22 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// A very simple well-formed code file
// FUNCTION: TEST 0x1234
void function01()
{
// TODO
}
// FUNCTION: TEST 0x2345
void function02()
{
// TODO
}
// FUNCTION: TEST 0x3456
void function03()
{
// TODO
}

View file

@ -1,14 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// Global variables inside and outside of functions
// GLOBAL: TEST 0x1000
const char *g_message = "test";
// FUNCTION: TEST 0x1234
void function01()
{
// GLOBAL: TEST 0x5555
static int g_hello = 123;
}

View file

@ -1,8 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// FUNCTION: TEST 0x10000001
inline const char* OneLineWithComment() const { return "MxDSObject"; }; // hi there
// FUNCTION: TEST 0x10000002
inline const char* OneLine() const { return "MxDSObject"; };

View file

@ -1,16 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
#include <stdio.h>
int no_offset_comment()
{
static int dummy = 123;
return -1;
}
// FUNCTION: TEST 0xdeadbeef
void regular_ole_function()
{
printf("hi there");
}

View file

@ -1,25 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// Handling multiple offset markers
// FUNCTION: TEST 0x1234
// FUNCTION: HELLO 0x5555
void different_modules()
{
// TODO
}
// FUNCTION: TEST 0x2345
// FUNCTION: TEST 0x1234
void same_module()
{
// TODO
}
// FUNCTION: TEST 0x2002
// FUNCTION: test 0x1001
void same_case_insensitive()
{
// TODO
}

View file

@ -1,12 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// FUNCTION: TEST 0x1234
void short_function() { static char* msg = "oneliner"; }
// FUNCTION: TEST 0x5555
void function_after_one_liner()
{
// This function comes after the previous that is on a single line.
// Do we report the offset for this one correctly?
}

View file

@ -1,20 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// FUNCTION: TEST 0x1001
void function_order01()
{
// TODO
}
// FUNCTION: TEST 0x1003
void function_order03()
{
// TODO
}
// FUNCTION: TEST 0x1002
void function_order02()
{
// TODO
}

View file

@ -1,23 +0,0 @@
// Sample for python unit tests
// Not part of the decomp
// While it's reasonable to expect a well-formed file (and clang-format
// will make sure we get one), this will put the parser through its paces.
// FUNCTION: TEST 0x1234
void curly_with_spaces()
{
static char* msg = "hello";
}
// FUNCTION: TEST 0x5555
void weird_closing_curly()
{
int x = 123; }
// FUNCTION: HELLO 0x5656
void bad_indenting() {
if (0)
{
int y = 5;
}}

View file

@ -1,82 +0,0 @@
"""Testing compare database behavior, particularly matching"""
import pytest
from isledecomp.compare.db import CompareDb
@pytest.fixture(name="db")
def fixture_db():
return CompareDb()
def test_ignore_recomp_collision(db):
"""Duplicate recomp addresses are ignored"""
db.set_recomp_symbol(0x1234, None, "hello", None, 100)
db.set_recomp_symbol(0x1234, None, "alias_for_hello", None, 100)
syms = db.get_all()
assert len(syms) == 1
def test_orig_collision(db):
"""Don't match if the original address is not unique"""
db.set_recomp_symbol(0x1234, None, "hello", None, 100)
assert db.match_function(0x5555, "hello") is True
# Second run on same address fails
assert db.match_function(0x5555, "hello") is False
# Call set_pair directly without wrapper
assert db.set_pair(0x5555, 0x1234) is False
def test_name_match(db):
db.set_recomp_symbol(0x1234, None, "hello", None, 100)
assert db.match_function(0x5555, "hello") is True
match = db.get_by_orig(0x5555)
assert match.name == "hello"
assert match.recomp_addr == 0x1234
def test_match_decorated(db):
"""Should match using decorated name even though regular name is null"""
db.set_recomp_symbol(0x1234, None, None, "?_hello", 100)
assert db.match_function(0x5555, "?_hello") is True
match = db.get_by_orig(0x5555)
assert match is not None
def test_duplicate_name(db):
"""If recomp name is not unique, match only one row"""
db.set_recomp_symbol(0x100, None, "_Construct", None, 100)
db.set_recomp_symbol(0x200, None, "_Construct", None, 100)
db.set_recomp_symbol(0x300, None, "_Construct", None, 100)
db.match_function(0x5555, "_Construct")
matches = db.get_matches()
# We aren't testing _which_ one would be matched, just that only one _was_ matched
assert len(matches) == 1
def test_static_variable_match(db):
"""Set up a situation where we can match a static function variable, then match it."""
# We need a matched function to start with.
db.set_recomp_symbol(0x1234, None, "Isle::Tick", "?Tick@IsleApp@@QAEXH@Z", 100)
db.match_function(0x5555, "Isle::Tick")
# Decorated variable name from PDB.
db.set_recomp_symbol(
0x2000, None, None, "?g_startupDelay@?1??Tick@IsleApp@@QAEXH@Z@4HA", 4
)
# Provide variable name and orig function address from decomp markers
assert db.match_static_variable(0xBEEF, "g_startupDelay", 0x5555) is True
def test_match_options_bool(db):
"""Test handling of boolean match options"""
# You don't actually need an existing orig addr for this.
assert db.get_match_options(0x1234) == {}
db.mark_stub(0x1234)
assert "stub" in db.get_match_options(0x1234)

View file

@ -1,73 +0,0 @@
# nyuk nyuk nyuk
import pytest
from isledecomp.parser.parser import CurlyManager
from isledecomp.parser.util import sanitize_code_line
@pytest.fixture(name="curly")
def fixture_curly():
return CurlyManager()
def test_simple(curly):
curly.read_line("namespace Test {")
assert curly.get_prefix() == "Test"
curly.read_line("}")
assert curly.get_prefix() == ""
def test_oneliner(curly):
"""Should not go down into a scope for a class forward reference"""
curly.read_line("class LegoEntity;")
assert curly.get_prefix() == ""
# Now make sure that we still would not consider that class name
# even after reading the opening curly brace
curly.read_line("if (true) {")
assert curly.get_prefix() == ""
def test_ignore_comments(curly):
curly.read_line("namespace Test {")
curly.read_line("// }")
assert curly.get_prefix() == "Test"
@pytest.mark.xfail(reason="todo: need a real lexer")
def test_ignore_multiline_comments(curly):
curly.read_line("namespace Test {")
curly.read_line("/*")
curly.read_line("}")
curly.read_line("*/")
assert curly.get_prefix() == "Test"
curly.read_line("}")
assert curly.get_prefix() == ""
def test_nested(curly):
curly.read_line("namespace Test {")
curly.read_line("namespace Foo {")
assert curly.get_prefix() == "Test::Foo"
curly.read_line("}")
assert curly.get_prefix() == "Test"
sanitize_cases = [
("", ""),
(" ", ""),
("{", "{"),
("// comments {", ""),
("{ // why comment here", "{"),
("/* comments */ {", "{"),
('"curly in a string {"', '""'),
('if (!strcmp("hello { there }", g_test)) {', 'if (!strcmp("", g_test)) {'),
("'{'", "''"),
("weird_function('\"', hello, '\"')", "weird_function('', hello, '')"),
]
@pytest.mark.parametrize("start, end", sanitize_cases)
def test_sanitize(start: str, end: str):
"""Make sure that we can remove curly braces in places where they should
not be considered as part of the semantic structure of the file.
i.e. inside strings or chars, and inside comments"""
assert sanitize_code_line(start) == end

View file

@ -1,59 +0,0 @@
import pytest
from isledecomp.cvdump.types import (
scalar_type_size,
scalar_type_pointer,
scalar_type_signed,
)
# These are all the types seen in the cvdump.
# We have char, short, int, long, long long, float, and double all represented
# in both signed and unsigned.
# We can also identify a 4 byte pointer with the T_32 prefix.
# The type T_VOID is used to designate a function's return type.
# T_NOTYPE is specified as the type of "this" for a static function in a class.
# For reference: https://github.com/microsoft/microsoft-pdb/blob/master/include/cvinfo.h
# fmt: off
# Fields are: type_name, size, is_signed, is_pointer
type_check_cases = (
("T_32PINT4", 4, False, True),
("T_32PLONG", 4, False, True),
("T_32PRCHAR", 4, False, True),
("T_32PREAL32", 4, False, True),
("T_32PUCHAR", 4, False, True),
("T_32PUINT4", 4, False, True),
("T_32PULONG", 4, False, True),
("T_32PUSHORT", 4, False, True),
("T_32PVOID", 4, False, True),
("T_CHAR", 1, True, False),
("T_INT4", 4, True, False),
("T_LONG", 4, True, False),
("T_QUAD", 8, True, False),
("T_RCHAR", 1, True, False),
("T_REAL32", 4, True, False),
("T_REAL64", 8, True, False),
("T_SHORT", 2, True, False),
("T_UCHAR", 1, False, False),
("T_UINT4", 4, False, False),
("T_ULONG", 4, False, False),
("T_UQUAD", 8, False, False),
("T_USHORT", 2, False, False),
("T_WCHAR", 2, False, False),
)
# fmt: on
@pytest.mark.parametrize("type_name, size, _, __", type_check_cases)
def test_scalar_size(type_name: str, size: int, _, __):
assert scalar_type_size(type_name) == size
@pytest.mark.parametrize("type_name, _, is_signed, __", type_check_cases)
def test_scalar_signed(type_name: str, _, is_signed: bool, __):
assert scalar_type_signed(type_name) == is_signed
@pytest.mark.parametrize("type_name, _, __, is_pointer", type_check_cases)
def test_scalar_pointer(type_name: str, _, __, is_pointer: bool):
assert scalar_type_pointer(type_name) == is_pointer

View file

@ -1,38 +0,0 @@
"""Test Cvdump SYMBOLS parser, reading function stack/params"""
from isledecomp.cvdump.symbols import CvdumpSymbolsParser
PROC_WITH_BLOC = """
(000638) S_GPROC32: [0001:000C6135], Cb: 00000361, Type: 0x10ED, RegistrationBook::ReadyWorld
Parent: 00000000, End: 00000760, Next: 00000000
Debug start: 0000000C, Debug end: 0000035C
Flags: Frame Ptr Present
(00067C) S_BPREL32: [FFFFFFD0], Type: 0x10EC, this
(000690) S_BPREL32: [FFFFFFDC], Type: 0x10F5, checkmarkBuffer
(0006AC) S_BPREL32: [FFFFFFE8], Type: 0x10F6, letterBuffer
(0006C8) S_BPREL32: [FFFFFFF4], Type: T_SHORT(0011), i
(0006D8) S_BPREL32: [FFFFFFF8], Type: 0x10F8, players
(0006EC) S_BPREL32: [FFFFFFFC], Type: 0x1044, gameState
(000704) S_BLOCK32: [0001:000C624F], Cb: 000001DA,
Parent: 00000638, End: 0000072C
(00071C) S_BPREL32: [FFFFFFD8], Type: T_SHORT(0011), j
(00072C) S_END
(000730) S_BLOCK32: [0001:000C6448], Cb: 00000032,
Parent: 00000638, End: 0000075C
(000748) S_BPREL32: [FFFFFFD4], Type: 0x10FA, infoman
(00075C) S_END
(000760) S_END
"""
def test_sblock32():
"""S_END has double duty as marking the end of a function (S_GPROC32)
and a scope block (S_BLOCK32). Make sure we can distinguish between
the two and not end a function early."""
parser = CvdumpSymbolsParser()
for line in PROC_WITH_BLOC.split("\n"):
parser.read_line(line)
# Make sure we can read the proc and all its stack references
assert len(parser.symbols) == 1
assert len(parser.symbols[0].stack_symbols) == 8

View file

@ -1,705 +0,0 @@
"""Specifically testing the Cvdump TYPES parser
and type dependency tree walker."""
import pytest
from isledecomp.cvdump.types import (
CvdumpTypesParser,
CvdumpKeyError,
CvdumpIntegrityError,
FieldListItem,
VirtualBaseClass,
VirtualBasePointer,
)
TEST_LINES = """
0x1018 : Length = 18, Leaf = 0x1201 LF_ARGLIST argument count = 3
list[0] = 0x100D
list[1] = 0x1016
list[2] = 0x1017
0x1019 : Length = 14, Leaf = 0x1008 LF_PROCEDURE
Return type = T_LONG(0012), Call type = C Near
Func attr = none
# Parms = 3, Arg list type = 0x1018
0x101e : Length = 26, Leaf = 0x1009 LF_MFUNCTION
Return type = T_CHAR(0010), Class type = 0x101A, This type = 0x101B,
Call type = ThisCall, Func attr = none
Parms = 2, Arg list type = 0x101d, This adjust = 0
0x1028 : Length = 10, Leaf = 0x1001 LF_MODIFIER
const, modifies type T_REAL32(0040)
0x103b : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = T_REAL32(0040)
Index type = T_SHORT(0011)
length = 16
Name =
0x103c : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = 0x103B
Index type = T_SHORT(0011)
length = 64
Name =
0x10e0 : Length = 86, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_MEMBER, public, type = T_REAL32(0040), offset = 0
member name = 'x'
list[1] = LF_MEMBER, public, type = T_REAL32(0040), offset = 0
member name = 'dvX'
list[2] = LF_MEMBER, public, type = T_REAL32(0040), offset = 4
member name = 'y'
list[3] = LF_MEMBER, public, type = T_REAL32(0040), offset = 4
member name = 'dvY'
list[4] = LF_MEMBER, public, type = T_REAL32(0040), offset = 8
member name = 'z'
list[5] = LF_MEMBER, public, type = T_REAL32(0040), offset = 8
member name = 'dvZ'
0x10e1 : Length = 34, Leaf = 0x1505 LF_STRUCTURE
# members = 6, field list type 0x10e0,
Derivation list type 0x0000, VT shape type 0x0000
Size = 12, class name = _D3DVECTOR, UDT(0x000010e1)
0x10e4 : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = T_UCHAR(0020)
Index type = T_SHORT(0011)
length = 8
Name =
0x10ea : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = 0x1028
Index type = T_SHORT(0011)
length = 12
Name =
0x11f0 : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 0, field list type 0x0000, FORWARD REF,
Derivation list type 0x0000, VT shape type 0x0000
Size = 0, class name = MxRect32, UDT(0x00001214)
0x11f2 : Length = 10, Leaf = 0x1001 LF_MODIFIER
const, modifies type 0x11F0
0x1213 : Length = 530, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_METHOD, count = 5, list = 0x1203, name = 'MxRect32'
list[1] = LF_ONEMETHOD, public, VANILLA, index = 0x1205, name = 'operator='
list[2] = LF_ONEMETHOD, public, VANILLA, index = 0x11F5, name = 'Intersect'
list[3] = LF_ONEMETHOD, public, VANILLA, index = 0x1207, name = 'SetPoint'
list[4] = LF_ONEMETHOD, public, VANILLA, index = 0x1207, name = 'AddPoint'
list[5] = LF_ONEMETHOD, public, VANILLA, index = 0x1207, name = 'SubtractPoint'
list[6] = LF_ONEMETHOD, public, VANILLA, index = 0x11F5, name = 'UpdateBounds'
list[7] = LF_ONEMETHOD, public, VANILLA, index = 0x1209, name = 'IsValid'
list[8] = LF_ONEMETHOD, public, VANILLA, index = 0x120A, name = 'IntersectsWith'
list[9] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetWidth'
list[10] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetHeight'
list[11] = LF_ONEMETHOD, public, VANILLA, index = 0x120C, name = 'GetPoint'
list[12] = LF_ONEMETHOD, public, VANILLA, index = 0x120D, name = 'GetSize'
list[13] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetLeft'
list[14] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetTop'
list[15] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetRight'
list[16] = LF_ONEMETHOD, public, VANILLA, index = 0x120B, name = 'GetBottom'
list[17] = LF_ONEMETHOD, public, VANILLA, index = 0x120E, name = 'SetLeft'
list[18] = LF_ONEMETHOD, public, VANILLA, index = 0x120E, name = 'SetTop'
list[19] = LF_ONEMETHOD, public, VANILLA, index = 0x120E, name = 'SetRight'
list[20] = LF_ONEMETHOD, public, VANILLA, index = 0x120E, name = 'SetBottom'
list[21] = LF_METHOD, count = 3, list = 0x1211, name = 'CopyFrom'
list[22] = LF_ONEMETHOD, private, STATIC, index = 0x1212, name = 'Min'
list[23] = LF_ONEMETHOD, private, STATIC, index = 0x1212, name = 'Max'
list[24] = LF_MEMBER, private, type = T_INT4(0074), offset = 0
member name = 'm_left'
list[25] = LF_MEMBER, private, type = T_INT4(0074), offset = 4
member name = 'm_top'
list[26] = LF_MEMBER, private, type = T_INT4(0074), offset = 8
member name = 'm_right'
list[27] = LF_MEMBER, private, type = T_INT4(0074), offset = 12
member name = 'm_bottom'
0x1214 : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 34, field list type 0x1213, CONSTRUCTOR, OVERLOAD,
Derivation list type 0x0000, VT shape type 0x0000
Size = 16, class name = MxRect32, UDT(0x00001214)
0x1220 : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 0, field list type 0x0000, FORWARD REF,
Derivation list type 0x0000, VT shape type 0x0000
Size = 0, class name = MxCore, UDT(0x00004060)
0x14db : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 0, field list type 0x0000, FORWARD REF,
Derivation list type 0x0000, VT shape type 0x0000
Size = 0, class name = MxString, UDT(0x00004db6)
0x19b0 : Length = 34, Leaf = 0x1505 LF_STRUCTURE
# members = 0, field list type 0x0000, FORWARD REF,
Derivation list type 0x0000, VT shape type 0x0000
Size = 0, class name = ROIColorAlias, UDT(0x00002a76)
0x19b1 : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = 0x19B0
Index type = T_SHORT(0011)
length = 440
Name =
0x2339 : Length = 26, Leaf = 0x1506 LF_UNION
# members = 0, field list type 0x0000, FORWARD REF, Size = 0 ,class name = FlagBitfield, UDT(0x00002e85)
0x2e85 : Length = 26, Leaf = 0x1506 LF_UNION
# members = 8, field list type 0x2e84, Size = 1 ,class name = FlagBitfield, UDT(0x00002e85)
0x2a75 : Length = 98, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_MEMBER, public, type = T_32PRCHAR(0470), offset = 0
member name = 'm_name'
list[1] = LF_MEMBER, public, type = T_INT4(0074), offset = 4
member name = 'm_red'
list[2] = LF_MEMBER, public, type = T_INT4(0074), offset = 8
member name = 'm_green'
list[3] = LF_MEMBER, public, type = T_INT4(0074), offset = 12
member name = 'm_blue'
list[4] = LF_MEMBER, public, type = T_INT4(0074), offset = 16
member name = 'm_unk0x10'
0x2a76 : Length = 34, Leaf = 0x1505 LF_STRUCTURE
# members = 5, field list type 0x2a75,
Derivation list type 0x0000, VT shape type 0x0000
Size = 20, class name = ROIColorAlias, UDT(0x00002a76)
0x22d4 : Length = 154, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_VFUNCTAB, type = 0x20FC
list[1] = LF_METHOD, count = 3, list = 0x22D0, name = 'MxVariable'
list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F0F,
vfptr offset = 0, name = 'GetValue'
list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F10,
vfptr offset = 4, name = 'SetValue'
list[4] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1F11,
vfptr offset = 8, name = '~MxVariable'
list[5] = LF_ONEMETHOD, public, VANILLA, index = 0x22D3, name = 'GetKey'
list[6] = LF_MEMBER, protected, type = 0x14DB, offset = 4
member name = 'm_key'
list[7] = LF_MEMBER, protected, type = 0x14DB, offset = 20
member name = 'm_value'
0x22d5 : Length = 34, Leaf = 0x1504 LF_CLASS
# members = 10, field list type 0x22d4, CONSTRUCTOR,
Derivation list type 0x0000, VT shape type 0x20fb
Size = 36, class name = MxVariable, UDT(0x00004041)
0x3c45 : Length = 50, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_ENUMERATE, public, value = 1, name = 'c_read'
list[1] = LF_ENUMERATE, public, value = 2, name = 'c_write'
list[2] = LF_ENUMERATE, public, value = 4, name = 'c_text'
0x3cc2 : Length = 38, Leaf = 0x1507 LF_ENUM
# members = 64, type = T_INT4(0074) field list type 0x3cc1
NESTED, enum name = JukeBox::JukeBoxScript, UDT(0x00003cc2)
0x3fab : Length = 10, Leaf = 0x1002 LF_POINTER
Pointer (NEAR32), Size: 0
Element type : 0x3FAA
0x405f : Length = 158, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_VFUNCTAB, type = 0x2090
list[1] = LF_ONEMETHOD, public, VANILLA, index = 0x176A, name = 'MxCore'
list[2] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176A,
vfptr offset = 0, name = '~MxCore'
list[3] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x176B,
vfptr offset = 4, name = 'Notify'
list[4] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x2087,
vfptr offset = 8, name = 'Tickle'
list[5] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x202F,
vfptr offset = 12, name = 'ClassName'
list[6] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x2030,
vfptr offset = 16, name = 'IsA'
list[7] = LF_ONEMETHOD, public, VANILLA, index = 0x2091, name = 'GetId'
list[8] = LF_MEMBER, private, type = T_UINT4(0075), offset = 4
member name = 'm_id'
0x4060 : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 9, field list type 0x405f, CONSTRUCTOR,
Derivation list type 0x0000, VT shape type 0x1266
Size = 8, class name = MxCore, UDT(0x00004060)
0x4262 : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = 0x3CC2
Index type = T_SHORT(0011)
length = 24
Name =
0x432f : Length = 14, Leaf = 0x1503 LF_ARRAY
Element type = T_INT4(0074)
Index type = T_SHORT(0011)
length = 12
Name =
0x4db5 : Length = 246, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_BCLASS, public, type = 0x1220, offset = 0
list[1] = LF_METHOD, count = 3, list = 0x14E3, name = 'MxString'
list[2] = LF_ONEMETHOD, public, VIRTUAL, index = 0x14DE, name = '~MxString'
list[3] = LF_METHOD, count = 2, list = 0x14E7, name = 'operator='
list[4] = LF_ONEMETHOD, public, VANILLA, index = 0x14DE, name = 'ToUpperCase'
list[5] = LF_ONEMETHOD, public, VANILLA, index = 0x14DE, name = 'ToLowerCase'
list[6] = LF_ONEMETHOD, public, VANILLA, index = 0x14E8, name = 'operator+'
list[7] = LF_ONEMETHOD, public, VANILLA, index = 0x14E9, name = 'operator+='
list[8] = LF_ONEMETHOD, public, VANILLA, index = 0x14EB, name = 'Compare'
list[9] = LF_ONEMETHOD, public, VANILLA, index = 0x14EC, name = 'GetData'
list[10] = LF_ONEMETHOD, public, VANILLA, index = 0x4DB4, name = 'GetLength'
list[11] = LF_MEMBER, private, type = T_32PRCHAR(0470), offset = 8
member name = 'm_data'
list[12] = LF_MEMBER, private, type = T_USHORT(0021), offset = 12
member name = 'm_length'
0x4dee : Length = 406, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_VBCLASS, public, direct base type = 0x15EA
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 3
list[1] = LF_IVBCLASS, public, indirect base type = 0x1183
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 1
list[2] = LF_IVBCLASS, public, indirect base type = 0x1468
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 2
list[3] = LF_VFUNCTAB, type = 0x2B95
list[4] = LF_ONEMETHOD, public, VANILLA, index = 0x15C2, name = 'LegoRaceMap'
list[5] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15C3, name = '~LegoRaceMap'
list[6] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15C5, name = 'Notify'
list[7] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15C4, name = 'ParseAction'
list[8] = LF_ONEMETHOD, public, VIRTUAL, index = 0x4DED, name = 'VTable0x70'
list[9] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x15C2,
vfptr offset = 0, name = 'FUN_1005d4b0'
list[10] = LF_MEMBER, private, type = T_UCHAR(0020), offset = 8
member name = 'm_parentClass2Field1'
list[11] = LF_MEMBER, private, type = T_32PVOID(0403), offset = 12
member name = 'm_parentClass2Field2'
0x4def : Length = 34, Leaf = 0x1504 LF_CLASS
# members = 21, field list type 0x4dee, CONSTRUCTOR,
Derivation list type 0x0000, VT shape type 0x12a0
Size = 436, class name = LegoRaceMap, UDT(0x00004def)
0x4db6 : Length = 30, Leaf = 0x1504 LF_CLASS
# members = 16, field list type 0x4db5, CONSTRUCTOR, OVERLOAD,
Derivation list type 0x0000, VT shape type 0x1266
Size = 16, class name = MxString, UDT(0x00004db6)
0x5591 : Length = 570, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_VBCLASS, public, direct base type = 0x15EA
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 3
list[1] = LF_IVBCLASS, public, indirect base type = 0x1183
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 1
list[2] = LF_IVBCLASS, public, indirect base type = 0x1468
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 2
list[3] = LF_VFUNCTAB, type = 0x4E11
list[4] = LF_ONEMETHOD, public, VANILLA, index = 0x1ABD, name = 'LegoCarRaceActor'
list[5] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1AE0, name = 'ClassName'
list[6] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1AE1, name = 'IsA'
list[7] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1ADD, name = 'VTable0x6c'
list[8] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1ADB, name = 'VTable0x70'
list[9] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1ADA, name = 'SwitchBoundary'
list[10] = LF_ONEMETHOD, public, VIRTUAL, index = 0x1ADC, name = 'VTable0x9c'
list[11] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x558E,
vfptr offset = 0, name = 'FUN_10080590'
list[12] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD8,
vfptr offset = 4, name = 'FUN_10012bb0'
list[13] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD9,
vfptr offset = 8, name = 'FUN_10012bc0'
list[14] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD8,
vfptr offset = 12, name = 'FUN_10012bd0'
list[15] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD9,
vfptr offset = 16, name = 'FUN_10012be0'
list[16] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD8,
vfptr offset = 20, name = 'FUN_10012bf0'
list[17] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1AD9,
vfptr offset = 24, name = 'FUN_10012c00'
list[18] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x1ABD,
vfptr offset = 28, name = 'VTable0x1c'
list[19] = LF_MEMBER, protected, type = T_REAL32(0040), offset = 8
member name = 'm_parentClass1Field1'
list[25] = LF_ONEMETHOD, public, VIRTUAL, (compgenx), index = 0x15D1, name = '~LegoCarRaceActor'
0x5592 : Length = 38, Leaf = 0x1504 LF_CLASS
# members = 26, field list type 0x5591, CONSTRUCTOR,
Derivation list type 0x0000, VT shape type 0x34c7
Size = 416, class name = LegoCarRaceActor, UDT(0x00005592)
0x5593 : Length = 638, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_BCLASS, public, type = 0x5592, offset = 0
list[1] = LF_BCLASS, public, type = 0x4DEF, offset = 32
list[2] = LF_IVBCLASS, public, indirect base type = 0x1183
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 1
list[3] = LF_IVBCLASS, public, indirect base type = 0x1468
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 2
list[4] = LF_IVBCLASS, public, indirect base type = 0x15EA
virtual base ptr = 0x43E9, vbpoff = 4, vbind = 3
list[5] = LF_ONEMETHOD, public, VANILLA, index = 0x15CD, name = 'LegoRaceCar'
list[6] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15CE, name = '~LegoRaceCar'
list[7] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15D2, name = 'Notify'
list[8] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15E8, name = 'ClassName'
list[9] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15E9, name = 'IsA'
list[10] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15D5, name = 'ParseAction'
list[11] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15D3, name = 'SetWorldSpeed'
list[12] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15DF, name = 'VTable0x6c'
list[13] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15D3, name = 'VTable0x70'
list[14] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15DC, name = 'VTable0x94'
list[15] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15E5, name = 'SwitchBoundary'
list[16] = LF_ONEMETHOD, public, VIRTUAL, index = 0x15DD, name = 'VTable0x9c'
list[17] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x15D4,
vfptr offset = 32, name = 'SetMaxLinearVelocity'
list[18] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x15D4,
vfptr offset = 36, name = 'FUN_10012ff0'
list[19] = LF_ONEMETHOD, public, INTRODUCING VIRTUAL, index = 0x5588,
vfptr offset = 40, name = 'HandleSkeletonKicks'
list[20] = LF_MEMBER, private, type = T_UCHAR(0020), offset = 84
member name = 'm_childClassField'
0x5594 : Length = 34, Leaf = 0x1504 LF_CLASS
# members = 30, field list type 0x5593, CONSTRUCTOR,
Derivation list type 0x0000, VT shape type 0x2d1e
Size = 512, class name = LegoRaceCar, UDT(0x000055bb)
"""
@pytest.fixture(name="parser")
def types_parser_fixture():
parser = CvdumpTypesParser()
for line in TEST_LINES.split("\n"):
parser.read_line(line)
return parser
def test_basic_parsing(parser: CvdumpTypesParser):
obj = parser.keys["0x4db6"]
assert obj["type"] == "LF_CLASS"
assert obj["name"] == "MxString"
assert obj["udt"] == "0x4db6"
assert len(parser.keys["0x4db5"]["members"]) == 2
def test_scalar_types(parser: CvdumpTypesParser):
"""Full tests on the scalar_* methods are in another file.
Here we are just testing the passthrough of the "T_" types."""
assert parser.get("T_CHAR").name is None
assert parser.get("T_CHAR").size == 1
assert parser.get("T_32PVOID").name is None
assert parser.get("T_32PVOID").size == 4
def test_resolve_forward_ref(parser: CvdumpTypesParser):
# Non-forward ref
assert parser.get("0x22d5").name == "MxVariable"
# Forward ref
assert parser.get("0x14db").name == "MxString"
assert parser.get("0x14db").size == 16
def test_members(parser: CvdumpTypesParser):
"""Return the list of items to compare for a given complex type.
If the class has a superclass, add those members too."""
# MxCore field list
mxcore_members = parser.get_scalars("0x405f")
assert mxcore_members == [
(0, "vftable", "T_32PVOID"),
(4, "m_id", "T_UINT4"),
]
# MxCore class id. Should be the same members
assert mxcore_members == parser.get_scalars("0x4060")
# MxString field list. Should add inherited members from MxCore
assert parser.get_scalars("0x4db5") == [
(0, "vftable", "T_32PVOID"),
(4, "m_id", "T_UINT4"),
(8, "m_data", "T_32PRCHAR"),
(12, "m_length", "T_USHORT"),
]
# LegoRaceCar with multiple superclasses
assert parser.get("0x5594").members == [
FieldListItem(offset=0, name="vftable", type="T_32PVOID"),
FieldListItem(offset=0, name="vftable", type="T_32PVOID"),
FieldListItem(offset=8, name="m_parentClass1Field1", type="T_REAL32"),
FieldListItem(offset=8, name="m_parentClass2Field1", type="T_UCHAR"),
FieldListItem(offset=12, name="m_parentClass2Field2", type="T_32PVOID"),
FieldListItem(offset=84, name="m_childClassField", type="T_UCHAR"),
]
def test_virtual_base_classes(parser: CvdumpTypesParser):
"""Make sure that virtual base classes are parsed correctly."""
lego_car_race_actor = parser.keys.get("0x5591")
assert lego_car_race_actor is not None
assert lego_car_race_actor["vbase"] == VirtualBasePointer(
vboffset=4,
bases=[
VirtualBaseClass(type="0x1183", index=1, direct=False),
VirtualBaseClass(type="0x1468", index=2, direct=False),
VirtualBaseClass(type="0x15EA", index=3, direct=True),
],
)
def test_members_recursive(parser: CvdumpTypesParser):
"""Make sure that we unwrap the dependency tree correctly."""
# MxVariable field list
assert parser.get_scalars("0x22d4") == [
(0, "vftable", "T_32PVOID"),
(4, "m_key.vftable", "T_32PVOID"),
(8, "m_key.m_id", "T_UINT4"),
(12, "m_key.m_data", "T_32PRCHAR"),
(16, "m_key.m_length", "T_USHORT"), # with padding
(20, "m_value.vftable", "T_32PVOID"),
(24, "m_value.m_id", "T_UINT4"),
(28, "m_value.m_data", "T_32PRCHAR"),
(32, "m_value.m_length", "T_USHORT"), # with padding
]
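# Offset arithmetic (added note): m_key is an MxString (size 16) at offset 4,
# so m_value begins at 4 + 16 = 20. Each m_length (T_USHORT at relative
# offset 12) is followed by 2 bytes of padding up to the 16-byte class size.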
def test_struct(parser: CvdumpTypesParser):
"""Basic test for converting type into struct.unpack format string."""
# MxCore: vftable and uint32. The vftable pointer is read as uint32.
assert parser.get_format_string("0x4060") == "<LL"
# _D3DVECTOR, three floats. Union types should already be removed.
assert parser.get_format_string("0x10e1") == "<fff"
# MxRect32, four signed ints.
assert parser.get_format_string("0x1214") == "<llll"
def test_struct_padding(parser: CvdumpTypesParser):
"""For data comparison purposes, make sure we have no gaps in the
list of scalar types. Any gap is filled by an unsigned char."""
# MxString, padded to 16 bytes. 4 actual members. 2 bytes of padding.
assert len(parser.get_scalars("0x4db6")) == 4
assert len(parser.get_scalars_gapless("0x4db6")) == 6
# MxVariable, with two MxStrings (and a vtable)
# Fill in the middle gap and the outer gap.
assert len(parser.get_scalars("0x22d5")) == 9
assert len(parser.get_scalars_gapless("0x22d5")) == 13
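# Arithmetic (added note): MxVariable is 36 bytes. m_key.m_length ends at
# offset 18 and m_value.m_length at offset 34, leaving two 2-byte gaps
# (18-19 and 34-35) that are filled with 4 unsigned chars: 9 + 4 = 13.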
def test_struct_format_string(parser: CvdumpTypesParser):
"""Generate the struct.unpack format string using the
list of scalars with padding filled in."""
# MxString, padded to 16 bytes.
assert parser.get_format_string("0x4db6") == "<LLLHBB"
# MxVariable, with two MxString members.
assert parser.get_format_string("0x22d5") == "<LLLLHBBLLLHBB"
def test_array(parser: CvdumpTypesParser):
"""LF_ARRAY members are created dynamically based on the
total array size and the size of one element."""
# unsigned char[8]
assert parser.get_scalars("0x10e4") == [
(0, "[0]", "T_UCHAR"),
(1, "[1]", "T_UCHAR"),
(2, "[2]", "T_UCHAR"),
(3, "[3]", "T_UCHAR"),
(4, "[4]", "T_UCHAR"),
(5, "[5]", "T_UCHAR"),
(6, "[6]", "T_UCHAR"),
(7, "[7]", "T_UCHAR"),
]
# float[4]
assert parser.get_scalars("0x103b") == [
(0, "[0]", "T_REAL32"),
(4, "[1]", "T_REAL32"),
(8, "[2]", "T_REAL32"),
(12, "[3]", "T_REAL32"),
]
def test_2d_array(parser: CvdumpTypesParser):
"""Make sure 2d array elements are named as we expect."""
# float[4][4]
float_array = parser.get_scalars("0x103c")
assert len(float_array) == 16
assert float_array[0] == (0, "[0][0]", "T_REAL32")
assert float_array[1] == (4, "[0][1]", "T_REAL32")
assert float_array[4] == (16, "[1][0]", "T_REAL32")
assert float_array[-1] == (60, "[3][3]", "T_REAL32")
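# Rule implied by the assertions above: for float[4][4],
# offset(i, j) = (i * 4 + j) * sizeof(float) and the name is "[i][j]".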
def test_enum(parser: CvdumpTypesParser):
"""LF_ENUM should equal 4-byte int"""
assert parser.get("0x3cc2").size == 4
assert parser.get_scalars("0x3cc2") == [(0, None, "T_INT4")]
# Now look at an array of enum, 24 bytes
enum_array = parser.get_scalars("0x4262")
assert len(enum_array) == 6 # 24 / 4
assert enum_array[0].size == 4
def test_lf_pointer(parser: CvdumpTypesParser):
"""LF_POINTER is just a wrapper for scalar pointer type"""
assert parser.get("0x3fab").size == 4
# assert parser.get("0x3fab").is_pointer is True # TODO: ?
assert parser.get_scalars("0x3fab") == [(0, None, "T_32PVOID")]
def test_key_not_exist(parser: CvdumpTypesParser):
"""Accessing a non-existent type id should raise our exception"""
with pytest.raises(CvdumpKeyError):
parser.get("0xbeef")
with pytest.raises(CvdumpKeyError):
parser.get_scalars("0xbeef")
def test_broken_forward_ref(parser: CvdumpTypesParser):
"""Raise an exception if we cannot follow a forward reference"""
# Verify forward reference on MxCore
parser.get("0x1220")
# Delete the MxCore LF_CLASS
del parser.keys["0x4060"]
# Forward ref via 0x1220 will fail
with pytest.raises(CvdumpKeyError):
parser.get("0x1220")
def test_null_forward_ref(parser: CvdumpTypesParser):
"""If the forward ref object is invalid and has no forward ref id,
raise an exception."""
# Test MxString forward reference
parser.get("0x14db")
# Delete the UDT for MxString
del parser.keys["0x14db"]["udt"]
# Cannot complete the forward reference lookup
with pytest.raises(CvdumpIntegrityError):
parser.get("0x14db")
def test_broken_array_element_ref(parser: CvdumpTypesParser):
# Test LF_ARRAY of ROIColorAlias
parser.get("0x19b1")
# Delete ROIColorAlias
del parser.keys["0x19b0"]
# Type reference lookup will fail
with pytest.raises(CvdumpKeyError):
parser.get("0x19b1")
def test_lf_modifier(parser: CvdumpTypesParser):
"""Is this an alias for another type?"""
# Modifies float
assert parser.get("0x1028").size == 4
assert parser.get_scalars("0x1028") == [(0, None, "T_REAL32")]
mxrect = parser.get_scalars("0x1214")
# Modifies MxRect32 via forward ref
assert mxrect == parser.get_scalars("0x11f2")
def test_union_members(parser: CvdumpTypesParser):
"""If there is a union somewhere in our dependency list, we can
expect to see duplicated member offsets and names. This is ok for
the TypeInfo tuple, but the list of ScalarType items should have
unique offsets to simplify comparison."""
# D3DVector type with duplicated offsets
d3dvector = parser.get("0x10e1")
assert d3dvector.members is not None
assert len(d3dvector.members) == 6
assert len([m for m in d3dvector.members if m.offset == 0]) == 2
# Deduplicated comparison list
vector_items = parser.get_scalars("0x10e1")
assert len(vector_items) == 3
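# (The six members form three pairs sharing offsets 0, 4 and 8 -- the usual
# D3DVECTOR x/dvX-style unions -- so deduplication leaves three floats.)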
def test_arglist(parser: CvdumpTypesParser):
arglist = parser.keys["0x1018"]
assert arglist["argcount"] == 3
assert arglist["args"] == ["0x100D", "0x1016", "0x1017"]
def test_procedure(parser: CvdumpTypesParser):
procedure = parser.keys["0x1019"]
assert procedure == {
"type": "LF_PROCEDURE",
"return_type": "T_LONG(0012)",
"call_type": "C Near",
"func_attr": "none",
"num_params": "3",
"arg_list_type": "0x1018",
}
def test_mfunction(parser: CvdumpTypesParser):
mfunction = parser.keys["0x101e"]
assert mfunction == {
"type": "LF_MFUNCTION",
"return_type": "T_CHAR(0010)",
"class_type": "0x101A",
"this_type": "0x101B",
"call_type": "ThisCall",
"func_attr": "none",
"num_params": "2",
"arg_list_type": "0x101d",
"this_adjust": "0",
}
def test_union_forward_ref(parser: CvdumpTypesParser):
union = parser.keys["0x2339"]
assert union["is_forward_ref"] is True
assert union["udt"] == "0x2e85"
def test_union(parser: CvdumpTypesParser):
union = parser.keys["0x2e85"]
assert union == {
"type": "LF_UNION",
"name": "FlagBitfield",
"size": 1,
"udt": "0x2e85",
}
def test_fieldlist_enumerate(parser: CvdumpTypesParser):
fieldlist_enum = parser.keys["0x3c45"]
assert fieldlist_enum == {
"type": "LF_FIELDLIST",
"variants": [
{"name": "c_read", "value": 1},
{"name": "c_write", "value": 2},
{"name": "c_text", "value": 4},
],
}
UNNAMED_UNION_DATA = """
0x369d : Length = 34, Leaf = 0x1203 LF_FIELDLIST
list[0] = LF_MEMBER, public, type = T_32PRCHAR(0470), offset = 0
member name = 'sz'
list[1] = LF_MEMBER, public, type = T_32PUSHORT(0421), offset = 0
member name = 'wz'
0x369e : Length = 22, Leaf = 0x1506 LF_UNION
# members = 2, field list type 0x369d, NESTED, Size = 4 ,class name = __unnamed
"""
def test_unnamed_union():
"""Make sure we can parse anonymous union types without a UDT"""
parser = CvdumpTypesParser()
for line in UNNAMED_UNION_DATA.split("\n"):
parser.read_line(line)
# Make sure we can parse the members line
union = parser.keys["0x369e"]
assert union["size"] == 4

View file

@ -1,83 +0,0 @@
import pytest
from isledecomp.cvdump.demangler import (
demangle_string_const,
demangle_vtable,
parse_encoded_number,
InvalidEncodedNumberError,
get_vtordisp_name,
)
string_demangle_cases = [
("??_C@_08LIDF@December?$AA@", 8, False),
("??_C@_0L@EGPP@english?9nz?$AA@", 11, False),
(
"??_C@_1O@POHA@?$AA?$CI?$AAn?$AAu?$AAl?$AAl?$AA?$CJ?$AA?$AA?$AA?$AA?$AA?$AH?$AA?$AA?$AA?$AA?$AA?$AA?$AA?$9A?$AE?$;I@",
14,
True,
),
("??_C@_00A@?$AA@", 0, False),
("??_C@_01A@?$AA?$AA@", 1, False),
]
@pytest.mark.parametrize("symbol, strlen, is_utf16", string_demangle_cases)
def test_strings(symbol, is_utf16, strlen):
s = demangle_string_const(symbol)
assert s.len == strlen
assert s.is_utf16 == is_utf16
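# Mangling sketch (added note): "??_C@_" begins a string constant. The next
# character gives the width ('0' = char, '1' = UTF-16), followed by the
# encoded length: a bare digit for small values, or 'A'-'P' hex digits
# terminated by '@' ('L' = 11, 'O' = 14 in the cases above).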
encoded_numbers = [
("A@", 0),
("AA@", 0), # would never happen?
("P@", 15),
("BA@", 16),
("BCD@", 291),
]
@pytest.mark.parametrize("string, value", encoded_numbers)
def test_encoded_numbers(string, value):
assert parse_encoded_number(string) == value
def test_invalid_encoded_number():
with pytest.raises(InvalidEncodedNumberError):
parse_encoded_number("Hello")
vtable_cases = [
("??_7LegoCarBuildAnimPresenter@@6B@", "LegoCarBuildAnimPresenter::`vftable'"),
("??_7?$MxCollection@PAVLegoWorld@@@@6B@", "MxCollection<LegoWorld *>::`vftable'"),
(
"??_7?$MxPtrList@VLegoPathController@@@@6B@",
"MxPtrList<LegoPathController>::`vftable'",
),
("??_7Renderer@Tgl@@6B@", "Tgl::Renderer::`vftable'"),
("??_7LegoExtraActor@@6B0@@", "LegoExtraActor::`vftable'{for `LegoExtraActor'}"),
(
"??_7LegoExtraActor@@6BLegoAnimActor@@@",
"LegoExtraActor::`vftable'{for `LegoAnimActor'}",
),
(
"??_7LegoAnimActor@@6B?$LegoContainer@PAM@@@",
"LegoAnimActor::`vftable'{for `LegoContainer<float *>'}",
),
]
@pytest.mark.parametrize("symbol, class_name", vtable_cases)
def test_vtable(symbol, class_name):
assert demangle_vtable(symbol) == class_name
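# Note (added for context): "??_7" is the MSVC prefix for a vftable symbol.
# A class with more than one vftable (LegoExtraActor above) gets a
# "{for `Base'}" suffix naming the base class whose table it is.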
def test_vtordisp():
"""Make sure we can accurately detect an adjuster thunk symbol"""
assert get_vtordisp_name("") is None
assert get_vtordisp_name("?ClassName@LegoExtraActor@@UBEPBDXZ") is None
assert (
get_vtordisp_name("?ClassName@LegoExtraActor@@$4PPPPPPPM@A@BEPBDXZ") is not None
)
# A function called vtordisp
assert get_vtordisp_name("?vtordisp@LegoExtraActor@@UBEPBDXZ") is None

View file

@ -1,212 +0,0 @@
from isledecomp.compare.asm.instgen import InstructGen, SectionType
def test_ret():
"""Make sure we can handle a function with one instruction."""
ig = InstructGen(b"\xc3", 0)
assert len(ig.sections) == 1
SCORE_NOTIFY = (
b"\x53\x56\x57\x8b\xd9\x33\xff\x8b\x74\x24\x10\x56\xe8\xbf\xe1\x01"
b"\x00\x80\xbb\xf6\x00\x00\x00\x00\x0f\x84\x9c\x00\x00\x00\x8b\x4e"
b"\x04\x49\x83\xf9\x17\x0f\x87\x8f\x00\x00\x00\x33\xc0\x8a\x81\xec"
b"\x14\x00\x10\xff\x24\x85\xd4\x14\x00\x10\x8b\xcb\xbf\x01\x00\x00"
b"\x00\xe8\x7a\x05\x00\x00\x8b\xc7\x5f\x5e\x5b\xc2\x04\x00\x56\x8b"
b"\xcb\xe8\xaa\x00\x00\x00\x8b\xf8\x8b\xc7\x5f\x5e\x5b\xc2\x04\x00"
b"\x80\x7e\x18\x20\x75\x07\x8b\xcb\xe8\xc3\xfe\xff\xff\xbf\x01\x00"
b"\x00\x00\x8b\xc7\x5f\x5e\x5b\xc2\x04\x00\x56\x8b\xcb\xe8\x3e\x02"
b"\x00\x00\x8b\xf8\x8b\xc7\x5f\x5e\x5b\xc2\x04\x00\x6a\x09\xa1\x4c"
b"\x45\x0f\x10\x6a\x07\x50\xe8\x35\x45\x01\x00\x83\xc4\x0c\x8b\x83"
b"\xf8\x00\x00\x00\x85\xc0\x74\x0d\x50\xe8\xa2\x42\x01\x00\x8b\xc8"
b"\xe8\x9b\x9b\x03\x00\xbf\x01\x00\x00\x00\x8b\xc7\x5f\x5e\x5b\xc2"
b"\x04\x00\x8b\xff\x4a\x14\x00\x10\x5e\x14\x00\x10\x70\x14\x00\x10"
b"\x8a\x14\x00\x10\x9c\x14\x00\x10\xca\x14\x00\x10\x00\x01\x05\x05"
b"\x05\x05\x02\x05\x05\x05\x05\x05\x05\x05\x05\x05\x03\x05\x05\x05"
b"\x05\x05\x05\x04\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
)
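# Added note: the bytes ff 24 85 d4 14 00 10 in the stream above decode to
# `jmp dword ptr [eax*4 + 0x100014d4]`, the indirect jump through the table.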
def test_score_notify():
"""Score::Notify function from 0x10001410 in LEGO1.
Good representative function for jump table (at 0x100014d4)
and switch data (at 0x100014ec)."""
ig = InstructGen(SCORE_NOTIFY, 0x10001410)
# Did we get everything?
assert len(ig.sections) == 3
types_only = tuple(s.type for s in ig.sections)
assert types_only == (SectionType.CODE, SectionType.ADDR_TAB, SectionType.DATA_TAB)
# CODE section stopped at correct place?
instructions = ig.sections[0].contents
assert instructions[-1].address == 0x100014D2
# n.b. 0x100014d2 is the dummy instruction `mov edi, edi`
# Ghidra does more thorough analysis and ignores this.
# The last real instruction should be at 0x100014cf. Not a big deal
# to include this because it is not junk data.
# 6 switch addresses
assert len(ig.sections[1].contents) == 6
# TODO: The data table at the end includes all of the 0xCC padding bytes.
SMACK_CASE = (
# LEGO1: 0x100cdc43 (modified so jump table points at +0x1016)
b"\x2e\xff\x24\x8d\x16\x10\x00\x00"
# LEGO1: 0x100cdb62 (instructions before and after jump table)
b"\x8b\xf8\xeb\x1a\x87\xdb\x87\xc9\x87\xdb\x87\xc9\x87\xdb\x50\xdc"
b"\x0c\x10\xd0\xe2\x0c\x10\xb0\xe8\x0c\x10\x50\xe9\x0c\x10\xa0\x10"
b"\x27\x10\x10\x3c\x11\x77\x17\x8a\xc8"
)
def test_smack_case():
"""Case where we have code / jump table / code.
Need to properly separate code sections, eliminate junk instructions
and continue disassembling at the proper address following the data."""
ig = InstructGen(SMACK_CASE, 0x1000)
assert len(ig.sections) == 3
assert ig.sections[0].type == ig.sections[2].type == SectionType.CODE
# Make sure we captured the instruction immediately after
assert ig.sections[2].contents[0].mnemonic == "mov"
# BETA10 0x1004c9cc
BETA_FUNC = (
b"\x55\x8b\xec\x83\xec\x08\x53\x56\x57\x89\x4d\xfc\x8b\x45\xfc\x33"
b"\xc9\x8a\x88\x19\x02\x00\x00\x89\x4d\xf8\xe9\x1e\x00\x00\x00\xe9"
b"\x41\x00\x00\x00\xe9\x3c\x00\x00\x00\xe9\x37\x00\x00\x00\xe9\x32"
b"\x00\x00\x00\xe9\x2d\x00\x00\x00\xe9\x28\x00\x00\x00\x83\x7d\xf8"
b"\x04\x0f\x87\x1e\x00\x00\x00\x8b\x45\xf8\xff\x24\x85\x1d\xca\x04"
b"\x10\xeb\xc9\x04\x10\xf0\xc9\x04\x10\xf5\xc9\x04\x10\xfa\xc9\x04"
b"\x10\xff\xc9\x04\x10\xb0\x01\xe9\x00\x00\x00\x00\x5f\x5e\x5b\xc9"
b"\xc2\x04\x00"
)
def test_beta_case():
"""Complete (and short) function with CODE / ADDR / CODE"""
ig = InstructGen(BETA_FUNC, 0x1004C9CC)
# The JMP into the jump table immediately precedes the jump table.
# We have to detect this and switch sections correctly or we will only
# get 1 section.
assert len(ig.sections) == 3
assert ig.sections[0].type == ig.sections[2].type == SectionType.CODE
# Make sure we captured the instruction immediately after
assert ig.sections[2].contents[0].mnemonic == "mov"
# LEGO1 0x1000fb50
# TODO: The test data here is longer than it needs to be.
THUNK_TEST = (
b"\x2b\x49\xfc\xe9\x08\x00\x00\x00\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
b"\x56\x8b\xf1\xe8\xd8\xc5\x00\x00\x8b\xce\xe8\xb1\xdc\x01\x00\xf6"
b"\x44\x24\x08\x01\x74\x0c\x8d\x46\xe0\x50\xe8\xe1\x66\x07\x00\x83"
b"\xc4\x04\x8d\x46\xe0\x5e\xc2\x04\x00\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
b"\x2b\x49\xfc\xe9\x08\x00\x00\x00\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
b"\xb8\x7c\x05\x0f\x10\xc3\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
b"\x2b\x49\xfc\xe9\x08\x00\x00\x00\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
b"\x8b\x54"
# The problem is here: the last two bytes are the start of the next
# function 0x1000fbc0. This is not enough data to read an instruction.
)
def test_thunk_case():
"""Adjuster thunk incorrectly annotated.
We are reading way more bytes than we should for this function."""
ig = InstructGen(THUNK_TEST, 0x1000FB50)
# No switch cases here, so the only section is code.
# This caused an infinite loop during testing so the goal is just to finish.
assert len(ig.sections) == 1
# TODO: We might detect the 0xCC padding bytes and cut off the function.
# If we did that, we would correctly read only 2 instructions.
# assert len(ig.sections[0].contents) == 2
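# (0xCC is the int3 breakpoint opcode; MSVC pads between functions with it.)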
# LEGO1 0x1006f080, Infocenter::HandleEndAction
HANDLE_END_ACTION = (
b"\x53\x56\x57\x8b\xf1\x8b\x5c\x24\x10\x8b\x0d\x84\x45\x0f\x10\x8b"
b"\x7b\x0c\x8b\x47\x20\x39\x01\x75\x29\x81\x7f\x1c\xf3\x01\x00\x00"
b"\x75\x20\xe8\x59\x66\xfa\xff\x6a\x00\x8b\x40\x18\x6a\x00\x6a\x10"
b"\x50\xff\x15\x38\xb5\x10\x10\xb8\x01\x00\x00\x00\x5f\x5e\x5b\xc2"
b"\x04\x00\x39\x46\x0c\x0f\x85\xa2\x00\x00\x00\x8b\x47\x1c\x83\xf8"
b"\x28\x74\x18\x83\xf8\x29\x74\x13\x83\xf8\x2a\x74\x0e\x83\xf8\x2b"
b"\x74\x09\x83\xf8\x2c\x0f\x85\x82\x00\x00\x00\x66\x8b\x86\xd4\x01"
b"\x00\x00\x66\x85\xc0\x74\x09\x66\x48\x66\x89\x86\xd4\x01\x00\x00"
b"\x66\x83\xbe\xd4\x01\x00\x00\x00\x75\x63\x6a\x0b\xe8\xff\x67\xfa"
b"\xff\x66\x8b\x86\xfc\x00\x00\x00\x83\xc4\x04\x50\xe8\x3f\x66\xfa"
b"\xff\x8b\xc8\xe8\x58\xa6\xfc\xff\x0f\xbf\x86\xfc\x00\x00\x00\x48"
b"\x83\xf8\x04\x77\x2f\xff\x24\x85\x78\xf4\x06\x10\x68\x1d\x02\x00"
b"\x00\xeb\x1a\x68\x1e\x02\x00\x00\xeb\x13\x68\x1f\x02\x00\x00\xeb"
b"\x0c\x68\x20\x02\x00\x00\xeb\x05\x68\x21\x02\x00\x00\x8b\xce\xe8"
b"\x9c\x21\x00\x00\x6a\x01\x8b\xce\xe8\x53\x1c\x00\x00\x8d\x8e\x0c"
b"\x01\x00\x00\x53\x8b\x01\xff\x50\x04\x85\xc0\x0f\x85\xef\x02\x00"
b"\x00\x8b\x56\x0c\x8b\x4f\x20\x3b\xd1\x74\x0e\x8b\x1d\x74\x45\x0f"
b"\x10\x39\x0b\x0f\x85\xd7\x02\x00\x00\x81\x7f\x1c\x02\x02\x00\x00"
b"\x75\x1a\x6a\x00\x52\x6a\x10\xe8\xa4\x65\xfa\xff\x8b\xc8\xe8\x0d"
b"\xa2\xfb\xff\x66\xc7\x86\xd6\x01\x00\x00\x00\x00\x8b\x96\x00\x01"
b"\x00\x00\x8d\x42\x74\x8b\x18\x83\xfb\x0c\x0f\x87\x9b\x02\x00\x00"
b"\x33\xc9\x8a\x8b\xac\xf4\x06\x10\xff\x24\x8d\x8c\xf4\x06\x10\x8b"
b"\x86\x08\x01\x00\x00\x83\xf8\x05\x77\x07\xff\x24\x85\xbc\xf4\x06"
b"\x10\x8b\xce\xe8\xb8\x1a\x00\x00\x8b\x86\x00\x01\x00\x00\x68\xf4"
b"\x01\x00\x00\x8b\xce\xc7\x40\x74\x0b\x00\x00\x00\xe8\xef\x20\x00"
b"\x00\x8b\x86\x00\x01\x00\x00\xc7\x86\x08\x01\x00\x00\xff\xff\xff"
b"\xff\x83\x78\x78\x00\x0f\x85\x40\x02\x00\x00\xb8\x01\x00\x00\x00"
b"\x5f\x66\xc7\x86\xd2\x01\x00\x00\x01\x00\x5e\x5b\xc2\x04\x00\x6a"
b"\x00\x8b\xce\x6a\x01\xe8\xd6\x19\x00\x00\xb8\x01\x00\x00\x00\x5f"
b"\x5e\x5b\xc2\x04\x00\x6a\x01\x8b\xce\x6a\x02\xe8\xc0\x19\x00\x00"
b"\xb8\x01\x00\x00\x00\x5f\x5e\x5b\xc2\x04\x00\x8b\xce\xe8\x3e\x1a"
b"\x00\x00\x8b\x86\x00\x01\x00\x00\x68\x1c\x02\x00\x00\x8b\xce\xc7"
b"\x40\x74\x0b\x00\x00\x00\xe8\x75\x20\x00\x00\xb8\x01\x00\x00\x00"
b"\x5f\xc7\x86\x08\x01\x00\x00\xff\xff\xff\xff\x5e\x5b\xc2\x04\x00"
b"\x8b\xce\xe8\x09\x1a\x00\x00\x8b\x86\x00\x01\x00\x00\x68\x1b\x02"
b"\x00\x00\x8b\xce\xc7\x40\x74\x0b\x00\x00\x00\xe8\x40\x20\x00\x00"
b"\xb8\x01\x00\x00\x00\x5f\xc7\x86\x08\x01\x00\x00\xff\xff\xff\xff"
b"\x5e\x5b\xc2\x04\x00\xc7\x00\x0b\x00\x00\x00\x8b\x86\x08\x01\x00"
b"\x00\x83\xf8\x04\x74\x0c\x83\xf8\x05\x74\x0e\x68\xf4\x01\x00\x00"
b"\xeb\x0c\x68\x1c\x02\x00\x00\xeb\x05\x68\x1b\x02\x00\x00\x8b\xce"
b"\xe8\xfb\x1f\x00\x00\xb8\x01\x00\x00\x00\x5f\xc7\x86\x08\x01\x00"
b"\x00\xff\xff\xff\xff\x5e\x5b\xc2\x04\x00\x6a\x00\xa1\xa0\x76\x0f"
b"\x10\x50\xe8\x39\x65\xfa\xff\x83\xc4\x08\xa1\xa4\x76\x0f\x10\x6a"
b"\x00\x50\xe8\x29\x65\xfa\xff\x83\xc4\x08\xe8\xf1\x63\xfa\xff\x8b"
b"\xc8\xe8\x6a\x02\x01\x00\xb8\x01\x00\x00\x00\x5f\x5e\x5b\xc2\x04"
b"\x00\x8b\x47\x1c\x83\xf8\x46\x74\x09\x83\xf8\x47\x0f\x85\x09\x01"
b"\x00\x00\x6a\x00\x6a\x00\x6a\x32\x6a\x03\xe8\x91\x65\xfa\xff\x8b"
b"\xc8\xe8\xfa\xc7\xfd\xff\x8b\x86\x00\x01\x00\x00\x5f\x5e\x5b\xc7"
b"\x40\x74\x0e\x00\x00\x00\xb8\x01\x00\x00\x00\xc2\x04\x00\x8b\x47"
b"\x1c\x39\x86\xf8\x00\x00\x00\x0f\x85\xce\x00\x00\x00\xe8\xbe\x63"
b"\xfa\xff\x83\x78\x10\x02\x74\x19\x66\x8b\x86\xfc\x00\x00\x00\x66"
b"\x85\xc0\x74\x0d\x50\xe8\xa6\x63\xfa\xff\x8b\xc8\xe8\xbf\xa3\xfc"
b"\xff\x6a\x00\x6a\x00\x6a\x32\x6a\x03\xe8\x32\x65\xfa\xff\x8b\xc8"
b"\xe8\x9b\xc7\xfd\xff\x8b\x86\x00\x01\x00\x00\x5f\x5e\x5b\xc7\x40"
b"\x74\x0e\x00\x00\x00\xb8\x01\x00\x00\x00\xc2\x04\x00\x83\x7a\x78"
b"\x00\x75\x32\x8b\x86\xf8\x00\x00\x00\x83\xf8\x28\x74\x27\x83\xf8"
b"\x29\x74\x22\x83\xf8\x2a\x74\x1d\x83\xf8\x2b\x74\x18\x83\xf8\x2c"
b"\x74\x13\x66\xc7\x86\xd0\x01\x00\x00\x01\x00\x6a\x0b\xe8\xee\x64"
b"\xfa\xff\x83\xc4\x04\x8b\x86\x00\x01\x00\x00\x6a\x01\x68\xdc\x44"
b"\x0f\x10\xc7\x40\x74\x02\x00\x00\x00\xe8\x22\x64\xfa\xff\x83\xc4"
b"\x08\xb8\x01\x00\x00\x00\x5f\x5e\x5b\xc2\x04\x00\x8b\x47\x1c\x39"
b"\x86\xf8\x00\x00\x00\x75\x14\x6a\x00\x6a\x00\x6a\x32\x6a\x03\xe8"
b"\x9c\x64\xfa\xff\x8b\xc8\xe8\x05\xc7\xfd\xff\xb8\x01\x00\x00\x00"
b"\x5f\x5e\x5b\xc2\x04\x00\x8b\xff\x3c\xf1\x06\x10\x43\xf1\x06\x10"
b"\x4a\xf1\x06\x10\x51\xf1\x06\x10\x58\xf1\x06\x10\xdf\xf1\x06\x10"
b"\xd5\xf2\x06\x10\x1a\xf3\x06\x10\x51\xf3\x06\x10\x8e\xf3\x06\x10"
b"\xed\xf3\x06\x10\x4c\xf4\x06\x10\x6b\xf4\x06\x10\x00\x01\x02\x07"
b"\x03\x04\x07\x07\x07\x07\x07\x05\x06\x8d\x49\x00\x3f\xf2\x06\x10"
b"\x55\xf2\x06\x10\xf1\xf1\x06\x10\xf1\xf1\x06\x10\x6b\xf2\x06\x10"
b"\xa0\xf2\x06\x10\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc"
)
def test_action_case():
"""3 switches: 3 jump tables, 1 data table"""
ig = InstructGen(HANDLE_END_ACTION, 0x1006F080)
# Two of the jump tables (0x1006f478 with 5, 0x1006f48c with 8)
# are contiguous.
assert len(ig.sections) == 5

View file

@ -1,152 +0,0 @@
"""Tests for the Bin (or IsleBin) module that:
1. Parses relevant data from the PE header and other structures.
2. Provides an interface to read from the DLL or EXE using a virtual address.
These are some basic smoke tests."""
import hashlib
from typing import Tuple
import pytest
from isledecomp.bin import (
Bin as IsleBin,
SectionNotFoundError,
InvalidVirtualAddressError,
)
# LEGO1.DLL: v1.1 English, September
LEGO1_SHA256 = "14645225bbe81212e9bc1919cd8a692b81b8622abb6561280d99b0fc4151ce17"
@pytest.fixture(name="binfile", scope="session")
def fixture_binfile(pytestconfig) -> IsleBin:
filename = pytestconfig.getoption("--lego1")
# Skip this if we have not provided the path to LEGO1.dll.
if filename is None:
pytest.skip(allow_module_level=True, reason="No path to LEGO1")
with open(filename, "rb") as f:
digest = hashlib.sha256(f.read()).hexdigest()
if digest != LEGO1_SHA256:
pytest.fail(reason="Did not match expected LEGO1.DLL")
with IsleBin(filename, find_str=True) as islebin:
yield islebin
def test_basic(binfile: IsleBin):
assert binfile.entry == 0x1008C860
assert len(binfile.sections) == 6
with pytest.raises(SectionNotFoundError):
binfile.get_section_by_name(".hello")
SECTION_INFO = (
(".text", 0x10001000, 0xD2A66, 0xD2C00),
(".rdata", 0x100D4000, 0x1B5B6, 0x1B600),
(".data", 0x100F0000, 0x1A734, 0x12C00),
(".idata", 0x1010B000, 0x1006, 0x1200),
(".rsrc", 0x1010D000, 0x21D8, 0x2200),
(".reloc", 0x10110000, 0x10C58, 0x10E00),
)
@pytest.mark.parametrize("name, v_addr, v_size, raw_size", SECTION_INFO)
def test_sections(name: str, v_addr: int, v_size: int, raw_size: int, binfile: IsleBin):
section = binfile.get_section_by_name(name)
assert section.virtual_address == v_addr
assert section.virtual_size == v_size
assert section.size_of_raw_data == raw_size
DOUBLE_PI_BYTES = b"\x18\x2d\x44\x54\xfb\x21\x09\x40"
# Now that's a lot of pi
PI_ADDRESSES = (
0x100D4000,
0x100D4700,
0x100D7180,
0x100DB8F0,
0x100DC030,
)
@pytest.mark.parametrize("addr", PI_ADDRESSES)
def test_read_pi(addr: int, binfile: IsleBin):
assert binfile.read(addr, 8) == DOUBLE_PI_BYTES
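def test_pi_bytes_sketch():
    """Added illustration (not in the original file): DOUBLE_PI_BYTES is just
    math.pi encoded as a little-endian IEEE-754 double."""
    import math
    import struct

    assert struct.pack("<d", math.pi) == DOUBLE_PI_BYTES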
def test_unusual_reads(binfile: IsleBin):
"""Reads that return an error or some specific value based on context"""
# Reading an address earlier than the imagebase
with pytest.raises(InvalidVirtualAddressError):
binfile.read(0, 1)
# Really big address
with pytest.raises(InvalidVirtualAddressError):
binfile.read(0xFFFFFFFF, 1)
# Uninitialized part of .data
assert binfile.read(0x1010A600, 4) is None
# Past the end of virtual size in .text
assert binfile.read(0x100D3A70, 4) == b"\x00\x00\x00\x00"
STRING_ADDRESSES = (
(0x100DB588, b"November"),
(0x100F0130, b"Helicopter"),
(0x100F0144, b"HelicopterState"),
(0x100F0BE4, b"valerie"),
(0x100F4080, b"TARGET"),
)
@pytest.mark.parametrize("addr, string", STRING_ADDRESSES)
def test_strings(addr: int, string: bytes, binfile: IsleBin):
"""Test string read utility function and the string search feature"""
assert binfile.read_string(addr) == string
assert binfile.find_string(string) == addr
def test_relocation(binfile: IsleBin):
# n.b. This is not the number of *relocations* read from .reloc.
# It is the number of unique addresses in the binary that get relocated.
assert len(binfile.get_relocated_addresses()) == 14066
# Score::Score is referenced only by CALL instructions. No need to relocate.
assert binfile.is_relocated_addr(0x10001000) is False
# MxEntity::SetEntityId is in the vtable and must be relocated.
assert binfile.is_relocated_addr(0x10001070) is True
# Not sanitizing dll name case. Do we care?
IMPORT_REFS = (
("KERNEL32.dll", "CreateMutexA", 0x1010B3D0),
("WINMM.dll", "midiOutPrepareHeader", 0x1010B550),
)
@pytest.mark.parametrize("import_ref", IMPORT_REFS)
def test_imports(import_ref: Tuple[str, str, int], binfile: IsleBin):
assert import_ref in binfile.imports
# Location of the JMP instruction and the import address.
THUNKS = (
(0x100D3728, 0x1010B32C), # DirectDrawCreate
(0x10098F9E, 0x1010B3D4), # RtlUnwind
)
@pytest.mark.parametrize("thunk_ref", THUNKS)
def test_thunks(thunk_ref: Tuple[int, int], binfile: IsleBin):
assert thunk_ref in binfile.thunks
def test_exports(binfile: IsleBin):
assert len(binfile.exports) == 130
assert (0x1003BFB0, b"??0LegoBackgroundColor@@QAE@PBD0@Z") in binfile.exports
assert (0x10091EE0, b"_DllMain@12") in binfile.exports

View file

@ -1,144 +0,0 @@
import pytest
from isledecomp.parser import DecompLinter
from isledecomp.parser.error import ParserError
@pytest.fixture(name="linter")
def fixture_linter():
return DecompLinter()
def test_simple_in_order(linter):
lines = [
"// FUNCTION: TEST 0x1000",
"void function1() {}",
"// FUNCTION: TEST 0x2000",
"void function2() {}",
"// FUNCTION: TEST 0x3000",
"void function3() {}",
]
assert linter.check_lines(lines, "test.cpp", "TEST") is True
def test_simple_not_in_order(linter):
lines = [
"// FUNCTION: TEST 0x1000",
"void function1() {}",
"// FUNCTION: TEST 0x3000",
"void function3() {}",
"// FUNCTION: TEST 0x2000",
"void function2() {}",
]
assert linter.check_lines(lines, "test.cpp", "TEST") is False
assert len(linter.alerts) == 1
assert linter.alerts[0].code == ParserError.FUNCTION_OUT_OF_ORDER
# N.B. Line number given is the start of the function, not the marker
assert linter.alerts[0].line_number == 6
def test_byname_ignored(linter):
"""Should ignore lookup-by-name markers when checking order."""
lines = [
"// FUNCTION: TEST 0x1000",
"void function1() {}",
"// FUNCTION: TEST 0x3000",
"// MyClass::MyMethod",
"// FUNCTION: TEST 0x2000",
"void function2() {}",
]
# This will fail because byname lookup does not belong in the cpp file
assert linter.check_lines(lines, "test.cpp", "TEST") is False
# but it should not fail for function order.
assert all(
alert.code != ParserError.FUNCTION_OUT_OF_ORDER for alert in linter.alerts
)
def test_module_isolation(linter):
"""Should check the order of markers from a single module only."""
lines = [
"// FUNCTION: ALPHA 0x0001",
"// FUNCTION: TEST 0x1000",
"void function1() {}",
"// FUNCTION: ALPHA 0x0002",
"// FUNCTION: TEST 0x2000",
"void function2() {}",
"// FUNCTION: ALPHA 0x0003",
"// FUNCTION: TEST 0x3000",
"void function3() {}",
]
assert linter.check_lines(lines, "test.cpp", "TEST") is True
linter.reset(True)
assert linter.check_lines(lines, "test.cpp", "ALPHA") is True
def test_byname_headers_only(linter):
"""Markers that ar referenced by name with cvdump belong in header files only."""
lines = [
"// FUNCTION: TEST 0x1000",
"// MyClass::~MyClass",
]
assert linter.check_lines(lines, "test.h", "TEST") is True
linter.reset(True)
assert linter.check_lines(lines, "test.cpp", "TEST") is False
assert linter.alerts[0].code == ParserError.BYNAME_FUNCTION_IN_CPP
def test_duplicate_offsets(linter):
"""The linter will retain module/offset pairs found until we do a full reset."""
lines = [
"// FUNCTION: TEST 0x1000",
"// FUNCTION: HELLO 0x1000",
"// MyClass::~MyClass",
]
# Should not fail for duplicate offset 0x1000 because the modules are unique.
assert linter.check_lines(lines, "test.h", "TEST") is True
# Simulate a failure by reading the same file twice.
assert linter.check_lines(lines, "test.h", "TEST") is False
# Two errors because offsets from both modules are duplicated
assert len(linter.alerts) == 2
assert all(a.code == ParserError.DUPLICATE_OFFSET for a in linter.alerts)
# Partial reset will retain the list of seen offsets.
linter.reset(False)
assert linter.check_lines(lines, "test.h", "TEST") is False
# Full reset will forget seen offsets.
linter.reset(True)
assert linter.check_lines(lines, "test.h", "TEST") is True
def test_duplicate_strings(linter):
"""Duplicate string markers are okay if the string value is the same."""
string_lines = [
"// STRING: TEST 0x1000",
'return "hello world";',
]
# No problem to use this marker twice.
assert linter.check_lines(string_lines, "test.h", "TEST") is True
assert linter.check_lines(string_lines, "test.h", "TEST") is True
different_string = [
"// STRING: TEST 0x1000",
'return "hi there";',
]
# Same address but the string is different
assert linter.check_lines(different_string, "greeting.h", "TEST") is False
assert len(linter.alerts) == 1
assert linter.alerts[0].code == ParserError.WRONG_STRING
same_addr_reused = [
"// GLOBAL:TEXT 0x1000",
"int g_test = 123;",
]
# This will fail like any other offset reuse.
assert linter.check_lines(same_addr_reused, "other.h", "TEST") is False

View file

@ -1,773 +0,0 @@
import pytest
from isledecomp.parser.parser import (
ReaderState,
DecompParser,
)
from isledecomp.parser.error import ParserError
@pytest.fixture(name="parser")
def fixture_parser():
return DecompParser()
def test_missing_sig(parser):
"""In the hopefully rare scenario that the function signature and marker
are swapped, we still have enough to match with reccmp"""
parser.read_lines(
[
"void my_function()",
"// FUNCTION: TEST 0x1234",
"{",
"}",
]
)
assert parser.state == ReaderState.SEARCH
assert len(parser.functions) == 1
assert parser.functions[0].line_number == 3
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.MISSED_START_OF_FUNCTION
def test_not_exact_syntax(parser):
"""Alert to inexact syntax right here in the parser instead of kicking it downstream.
Doing this means we don't have to save the actual text."""
parser.read_line("// function: test 0x1234")
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.BAD_DECOMP_MARKER
def test_invalid_marker(parser):
"""We matched a decomp marker, but it's not one we care about"""
parser.read_line("// BANANA: TEST 0x1234")
assert parser.state == ReaderState.SEARCH
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.BOGUS_MARKER
def test_incompatible_marker(parser):
"""The marker we just read cannot be handled in the current parser state"""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"// GLOBAL: TEST 0x5000",
]
)
assert parser.state == ReaderState.SEARCH
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.INCOMPATIBLE_MARKER
def test_variable(parser):
"""Should identify a global variable"""
parser.read_lines(
[
"// GLOBAL: HELLO 0x1234",
"int g_value = 5;",
]
)
assert len(parser.variables) == 1
def test_synthetic_plus_marker(parser):
"""Marker tracking preempts synthetic name detection.
Should fail with error and not log the synthetic"""
parser.read_lines(
[
"// SYNTHETIC: HEY 0x555",
"// FUNCTION: HOWDY 0x1234",
]
)
assert len(parser.functions) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.INCOMPATIBLE_MARKER
def test_different_markers_different_module(parser):
"""Does it make any sense for a function to be a stub in one module,
but not in another? I don't know. But it's no problem for us."""
parser.read_lines(
[
"// FUNCTION: HOWDY 0x1234",
"// STUB: SUP 0x5555",
"void interesting_function() {",
"}",
]
)
assert len(parser.alerts) == 0
assert len(parser.functions) == 2
def test_different_markers_same_module(parser):
"""Now, if something is a regular function but then a stub,
what do we say about that?"""
parser.read_lines(
[
"// FUNCTION: HOWDY 0x1234",
"// STUB: HOWDY 0x5555",
"void interesting_function() {",
"}",
]
)
# Use first marker declaration, don't replace
assert len(parser.functions) == 1
assert parser.functions[0].should_skip() is False
# Should alert to this
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.DUPLICATE_MODULE
def test_unexpected_synthetic(parser):
"""FUNCTION then SYNTHETIC should fail to report either one"""
parser.read_lines(
[
"// FUNCTION: HOWDY 0x1234",
"// SYNTHETIC: HOWDY 0x5555",
"void interesting_function() {",
"}",
]
)
assert parser.state == ReaderState.SEARCH
assert len(parser.functions) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.INCOMPATIBLE_MARKER
@pytest.mark.skip(reason="not implemented yet")
def test_duplicate_offset(parser):
"""Repeating the same module/offset in the same file is probably a typo"""
parser.read_lines(
[
"// GLOBAL: HELLO 0x1234",
"int x = 1;",
"// GLOBAL: HELLO 0x1234",
"int y = 2;",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.DUPLICATE_OFFSET
def test_multiple_variables(parser):
"""Theoretically the same global variable can appear in multiple modules"""
parser.read_lines(
[
"// GLOBAL: HELLO 0x1234",
"// GLOBAL: WUZZUP 0x555",
"const char *g_greeting;",
]
)
assert len(parser.alerts) == 0
assert len(parser.variables) == 2
def test_multiple_variables_same_module(parser):
"""Should not overwrite offset"""
parser.read_lines(
[
"// GLOBAL: HELLO 0x1234",
"// GLOBAL: HELLO 0x555",
"const char *g_greeting;",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.DUPLICATE_MODULE
assert len(parser.variables) == 1
assert parser.variables[0].offset == 0x1234
def test_multiple_vtables(parser):
parser.read_lines(
[
"// VTABLE: HELLO 0x1234",
"// VTABLE: TEST 0x5432",
"class MxString : public MxCore {",
]
)
assert len(parser.alerts) == 0
assert len(parser.vtables) == 2
assert parser.vtables[0].name == "MxString"
def test_multiple_vtables_same_module(parser):
"""Should not overwrite offset"""
parser.read_lines(
[
"// VTABLE: HELLO 0x1234",
"// VTABLE: HELLO 0x5432",
"class MxString : public MxCore {",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.DUPLICATE_MODULE
assert len(parser.vtables) == 1
assert parser.vtables[0].offset == 0x1234
def test_synthetic(parser):
parser.read_lines(
[
"// SYNTHETIC: TEST 0x1234",
"// TestClass::TestMethod",
]
)
assert len(parser.functions) == 1
assert parser.functions[0].lookup_by_name is True
assert parser.functions[0].name == "TestClass::TestMethod"
def test_synthetic_same_module(parser):
parser.read_lines(
[
"// SYNTHETIC: TEST 0x1234",
"// SYNTHETIC: TEST 0x555",
"// TestClass::TestMethod",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.DUPLICATE_MODULE
assert len(parser.functions) == 1
assert parser.functions[0].offset == 0x1234
def test_synthetic_no_comment(parser):
"""Synthetic marker followed by a code line (i.e. non-comment)"""
parser.read_lines(
[
"// SYNTHETIC: TEST 0x1234",
"int x = 123;",
]
)
assert len(parser.functions) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.BAD_NAMEREF
assert parser.state == ReaderState.SEARCH
def test_single_line_function(parser):
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"int hello() { return 1234; }",
]
)
assert len(parser.functions) == 1
assert parser.functions[0].line_number == 2
assert parser.functions[0].end_line == 2
def test_indented_function(parser):
"""Track the number of whitespace characters when we begin the function
and check that against each closing curly brace we read.
Should not report a syntax warning if the function is indented"""
parser.read_lines(
[
" // FUNCTION: TEST 0x1234",
" void indented()",
" {",
" // TODO",
" }",
" // FUNCTION: NEXT 0x555",
]
)
assert len(parser.alerts) == 0
@pytest.mark.xfail(reason="todo")
def test_indented_no_curly_hint(parser):
"""Same as above, but opening curly brace is on the same line.
Without the hint of how many whitespace characters to check, can we
still identify the end of the function?"""
parser.read_lines(
[
" // FUNCTION: TEST 0x1234",
" void indented() {",
" }",
" // FUNCTION: NEXT 0x555",
]
)
assert len(parser.alerts) == 0
def test_implicit_lookup_by_name(parser):
"""FUNCTION (or STUB) offsets must directly precede the function signature.
If we detect a comment instead, we assume that this is a lookup-by-name
function and end here."""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"// TestClass::TestMethod()",
]
)
assert parser.state == ReaderState.SEARCH
assert len(parser.functions) == 1
assert parser.functions[0].lookup_by_name is True
assert parser.functions[0].name == "TestClass::TestMethod()"
def test_function_with_spaces(parser):
"""There should not be any spaces between the end of FUNCTION markers
and the start or name of the function. If it's a blank line, we can safely
ignore but should alert to this."""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
" ",
"inline void test_function() { };",
]
)
assert len(parser.functions) == 1
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.UNEXPECTED_BLANK_LINE
def test_function_with_spaces_implicit(parser):
"""Same as above, but for implicit lookup-by-name"""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
" ",
"// Implicit::Method",
]
)
assert len(parser.functions) == 1
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.UNEXPECTED_BLANK_LINE
@pytest.mark.xfail(reason="will assume implicit lookup-by-name function")
def test_function_is_commented(parser):
"""In an ideal world, we would recognize that there is no code here.
Some editors (or users) might comment the function on each line like this
but hopefully it is rare."""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"// int my_function()",
"// {",
"// return 5;",
"// }",
]
)
assert len(parser.functions) == 0
def test_unexpected_eof(parser):
"""If a decomp marker finds its way to the last line of the file,
report that we could not get anything from it."""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"// Cls::Method",
"// FUNCTION: TEST 0x5555",
]
)
parser.finish()
assert len(parser.functions) == 1
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.UNEXPECTED_END_OF_FILE
@pytest.mark.xfail(reason="no longer applies")
def test_global_variable_prefix(parser):
"""Global and static variables should have the g_ prefix."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
'const char* g_msg = "hello";',
]
)
assert len(parser.variables) == 1
assert len(parser.alerts) == 0
parser.read_lines(
[
"// GLOBAL: TEXT 0x5555",
"int test = 5;",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.GLOBAL_MISSING_PREFIX
# In spite of that, we should still grab the variable name.
assert parser.variables[1].name == "test"
def test_global_nomatch(parser):
"""We do our best to grab the variable name, even without the g_ prefix
but this (by design) will not match everything."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"FunctionCall();",
]
)
assert len(parser.variables) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.NO_SUITABLE_NAME
def test_static_variable(parser):
"""We can detect whether a variable is a static function variable
based on the parser's state when we detect it.
Checking for the word `static` alone is not a good test.
Static class variables are filed as S_GDATA32, same as regular globals."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"int g_test = 1234;",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].is_static is False
parser.read_lines(
[
"// FUNCTION: TEST 0x5555",
"void test_function() {",
"// GLOBAL: TEST 0x8888",
"static int g_internal = 0;",
"}",
]
)
assert len(parser.variables) == 2
assert parser.variables[1].is_static is True
def test_reject_global_return(parser):
"""Previously we had annotated strings with the GLOBAL marker.
For example: if a function returned a string. We now want these to be
annotated with the STRING marker."""
parser.read_lines(
[
"// FUNCTION: TEST 0x5555",
"void test_function() {",
" // GLOBAL: TEST 0x8888",
' return "test";',
"}",
]
)
assert len(parser.variables) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.GLOBAL_NOT_VARIABLE
def test_global_string(parser):
"""We now allow GLOBAL and STRING markers for the same item."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"// STRING: TEXT 0x5555",
'char* g_test = "hello";',
]
)
assert len(parser.variables) == 1
assert len(parser.strings) == 1
assert len(parser.alerts) == 0
assert parser.variables[0].name == "g_test"
assert parser.strings[0].name == "hello"
def test_comment_variables(parser):
"""Match on hidden variables from libraries."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"// g_test",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].name == "g_test"
def test_flexible_variable_prefix(parser):
"""Don't alert to library variables that lack the g_ prefix.
This is out of our control."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"// some_other_variable",
]
)
assert len(parser.variables) == 1
assert len(parser.alerts) == 0
assert parser.variables[0].name == "some_other_variable"
def test_string_ignore_g_prefix(parser):
"""String annotations above a regular variable should not alert to
the missing g_ prefix. This is only required for GLOBAL markers."""
parser.read_lines(
[
"// STRING: TEST 0x1234",
'const char* value = "";',
]
)
assert len(parser.strings) == 1
assert len(parser.alerts) == 0
def test_class_variable(parser):
"""We should accurately name static variables that are class members."""
parser.read_lines(
[
"class Test {",
"protected:",
" // GLOBAL: TEST 0x1234",
" static int g_test;",
"};",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].name == "Test::g_test"
def test_namespace_variable(parser):
"""We should identify a namespace surrounding any global variables"""
parser.read_lines(
[
"namespace Test {",
"// GLOBAL: TEST 0x1234",
"int g_test = 1234;",
"}",
"// GLOBAL: TEST 0x5555",
"int g_second = 2;",
]
)
assert len(parser.variables) == 2
assert parser.variables[0].name == "Test::g_test"
assert parser.variables[1].name == "g_second"
def test_namespace_vtable(parser):
parser.read_lines(
[
"namespace Tgl {",
"// VTABLE: TEST 0x1234",
"class Renderer {",
"};",
"}",
"// VTABLE: TEST 0x5555",
"class Hello { };",
]
)
assert len(parser.vtables) == 2
assert parser.vtables[0].name == "Tgl::Renderer"
assert parser.vtables[1].name == "Hello"
@pytest.mark.xfail(reason="no longer applies")
def test_global_prefix_namespace(parser):
"""Should correctly identify namespaces before checking for the g_ prefix"""
parser.read_lines(
[
"class Test {",
" // GLOBAL: TEST 0x1234",
" static int g_count = 0;",
" // GLOBAL: TEST 0x5555",
" static int count = 0;",
"};",
]
)
assert len(parser.variables) == 2
assert parser.variables[0].name == "Test::g_count"
assert parser.variables[1].name == "Test::count"
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.GLOBAL_MISSING_PREFIX
def test_nested_namespace(parser):
parser.read_lines(
[
"namespace Tgl {",
"class Renderer {",
" // GLOBAL: TEST 0x1234",
" static int g_count = 0;",
"};",
"};",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].name == "Tgl::Renderer::g_count"
def test_match_qualified_variable(parser):
"""If a variable belongs to a scope and we use a fully qualified reference
below a GLOBAL marker, make sure we capture the full name."""
parser.read_lines(
[
"// GLOBAL: TEST 0x1234",
"int MxTest::g_count = 0;",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].name == "MxTest::g_count"
assert len(parser.alerts) == 0
def test_static_variable_parent(parser):
"""Report the address of the parent function that contains a static variable."""
parser.read_lines(
[
"// FUNCTION: TEST 0x1234",
"void test()",
"{",
" // GLOBAL: TEST 0x5555",
" static int g_count = 0;",
"}",
]
)
assert len(parser.variables) == 1
assert parser.variables[0].is_static is True
assert parser.variables[0].parent_function == 0x1234
@pytest.mark.xfail(
reason="""Without the FUNCTION marker we don't know that we are inside a function,
so we do not identify this variable as static."""
)
def test_static_variable_no_parent(parser):
"""If the function that contains a static variable is not marked, we
cannot match it with cvdump so we should skip it and report an error."""
parser.read_lines(
[
"void test()",
"{",
" // GLOBAL: TEST 0x5555",
" static int g_count = 0;",
"}",
]
)
# No way to match this variable so don't report it
assert len(parser.variables) == 0
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.ORPHANED_STATIC_VARIABLE
def test_static_variable_incomplete_coverage(parser):
"""If the function that contains a static variable is marked, but
not for each module used for the variable itself, this is an error."""
parser.read_lines(
[
"// FUNCTION: HELLO 0x1234",
"void test()",
"{",
" // GLOBAL: HELLO 0x5555",
" // GLOBAL: TEST 0x5555",
" static int g_count = 0;",
"}",
]
)
# Match for HELLO module
assert len(parser.variables) == 1
# Failed for TEST module
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.ORPHANED_STATIC_VARIABLE
def test_header_function_declaration(parser):
"""This is either a forward reference or a declaration in a header file.
Meaning: The implementation is not here. This is not the correct place
for the FUNCTION marker and it will probably not match anything."""
parser.read_lines(
[
"// FUNCTION: HELLO 0x1234",
"void sample_function(int);",
]
)
assert len(parser.alerts) == 1
assert parser.alerts[0].code == ParserError.NO_IMPLEMENTATION
def test_extra(parser):
"""Allow a fourth field in the decomp annotation. Its use will vary
depending on the marker type. Currently this is only used to identify
a vtable with virtual inheritance."""
# Intentionally using non-vtable markers here.
# We might want to emit a parser warning for unnecessary extra info.
parser.read_lines(
[
"// GLOBAL: TEST 0x5555 Haha",
"int g_variable = 0;",
"// FUNCTION: TEST 0x1234 Something",
"void Test() { g_variable++; }",
"// LIBRARY: TEST 0x8080 Printf",
"// _printf",
]
)
# We don't use this information (yet) but this is all fine.
assert len(parser.alerts) == 0
def test_virtual_inheritance(parser):
"""Indicate the base class for a vtable where the class uses
virtual inheritance."""
parser.read_lines(
[
"// VTABLE: HELLO 0x1234",
"// VTABLE: HELLO 0x1238 Greetings",
"// VTABLE: HELLO 0x123c Howdy",
"class HiThere : public virtual Greetings {",
"};",
]
)
assert len(parser.alerts) == 0
assert len(parser.vtables) == 3
assert parser.vtables[0].base_class is None
assert parser.vtables[1].base_class == "Greetings"
assert parser.vtables[2].base_class == "Howdy"
assert all(v.name == "HiThere" for v in parser.vtables)
def test_namespace_in_comment(parser):
parser.read_lines(
[
"// VTABLE: HELLO 0x1234",
"// class Tgl::Object",
"// VTABLE: HELLO 0x5555",
"// class TglImpl::RendererImpl<D3DRMImpl::D3DRM>",
]
)
assert len(parser.vtables) == 2
assert parser.vtables[0].name == "Tgl::Object"
assert parser.vtables[1].name == "TglImpl::RendererImpl<D3DRMImpl::D3DRM>"

View file

@ -1,141 +0,0 @@
import os
from typing import List, TextIO
import pytest
from isledecomp.parser import DecompParser
from isledecomp.parser.node import ParserSymbol
SAMPLE_DIR = os.path.join(os.path.dirname(__file__), "samples")
def sample_file(filename: str) -> TextIO:
"""Wrapper for opening the samples from the directory that does not
depend on the cwd where we run the test"""
full_path = os.path.join(SAMPLE_DIR, filename)
return open(full_path, "r", encoding="utf-8")
def code_blocks_are_sorted(blocks: List[ParserSymbol]) -> bool:
"""Helper to make this more idiomatic"""
just_offsets = [block.offset for block in blocks]
return just_offsets == sorted(just_offsets)
@pytest.fixture(name="parser")
def fixture_parser():
return DecompParser()
# Tests are below #
def test_sanity(parser):
"""Read a very basic file"""
with sample_file("basic_file.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 3
assert code_blocks_are_sorted(parser.functions) is True
# n.b. The parser returns line numbers as 1-based
# Function starts when we see the opening curly brace
assert parser.functions[0].line_number == 8
assert parser.functions[0].end_line == 10
def test_oneline(parser):
"""(Assuming clang-format permits this) This sample has a function
on a single line. This will test the end-of-function detection"""
with sample_file("oneline_function.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 2
assert parser.functions[0].line_number == 5
assert parser.functions[0].end_line == 5
def test_missing_offset(parser):
"""What if the function doesn't have an offset comment?"""
with sample_file("missing_offset.cpp") as f:
parser.read_lines(f)
# TODO: For now, the function without the offset will just be ignored.
# Would be the same outcome if the comment was present but mangled and
# we failed to match it. We should detect these cases in the future.
assert len(parser.functions) == 1
def test_jumbled_case(parser):
"""The parser just reports what it sees. It is the responsibility of
the downstream tools to do something about a jumbled file.
Just verify that we are reading it correctly."""
with sample_file("out_of_order.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 3
assert code_blocks_are_sorted(parser.functions) is False
def test_bad_file(parser):
with sample_file("poorly_formatted.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 3
def test_indented(parser):
"""Offsets for functions inside of a class will probably be indented."""
with sample_file("basic_class.cpp") as f:
parser.read_lines(f)
# TODO: We don't properly detect the end of these functions
# because the closing brace is indented. However... knowing where each
# function ends is less important (for now) than capturing
# all the functions that are there.
assert len(parser.functions) == 2
assert parser.functions[0].offset == int("0x12345678", 16)
assert parser.functions[0].line_number == 16
# assert parser.functions[0].end_line == 19
assert parser.functions[1].offset == int("0xdeadbeef", 16)
assert parser.functions[1].line_number == 23
# assert parser.functions[1].end_line == 25
def test_inline(parser):
with sample_file("inline.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 2
for fun in parser.functions:
assert fun.line_number is not None
assert fun.line_number == fun.end_line
def test_multiple_offsets(parser):
"""If multiple offset marks appear before for a code block, take them
all but ensure module name (case-insensitive) is distinct.
Use first module occurrence in case of duplicates."""
with sample_file("multiple_offsets.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 4
assert parser.functions[0].module == "TEST"
assert parser.functions[0].line_number == 9
assert parser.functions[1].module == "HELLO"
assert parser.functions[1].line_number == 9
# Duplicate modules are ignored
assert parser.functions[2].line_number == 16
assert parser.functions[2].offset == 0x2345
assert parser.functions[3].module == "TEST"
assert parser.functions[3].offset == 0x2002
def test_variables(parser):
with sample_file("global_variables.cpp") as f:
parser.read_lines(f)
assert len(parser.functions) == 1
assert len(parser.variables) == 2

View file

@ -1,141 +0,0 @@
from typing import Optional
import pytest
from isledecomp.parser.parser import (
ReaderState as _rs,
DecompParser,
)
from isledecomp.parser.error import ParserError as _pe
# fmt: off
state_change_marker_cases = [
(_rs.SEARCH, "FUNCTION", _rs.WANT_SIG, None),
(_rs.SEARCH, "GLOBAL", _rs.IN_GLOBAL, None),
(_rs.SEARCH, "STUB", _rs.WANT_SIG, None),
(_rs.SEARCH, "SYNTHETIC", _rs.IN_SYNTHETIC, None),
(_rs.SEARCH, "TEMPLATE", _rs.IN_TEMPLATE, None),
(_rs.SEARCH, "VTABLE", _rs.IN_VTABLE, None),
(_rs.SEARCH, "LIBRARY", _rs.IN_LIBRARY, None),
(_rs.SEARCH, "STRING", _rs.IN_GLOBAL, None),
(_rs.WANT_SIG, "FUNCTION", _rs.WANT_SIG, None),
(_rs.WANT_SIG, "GLOBAL", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_SIG, "STUB", _rs.WANT_SIG, None),
(_rs.WANT_SIG, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_SIG, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_SIG, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_SIG, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_SIG, "STRING", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC, "FUNCTION", _rs.WANT_SIG, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "GLOBAL", _rs.IN_FUNC_GLOBAL, None),
(_rs.IN_FUNC, "STUB", _rs.WANT_SIG, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "SYNTHETIC", _rs.IN_SYNTHETIC, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "TEMPLATE", _rs.IN_TEMPLATE, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "VTABLE", _rs.IN_VTABLE, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "LIBRARY", _rs.IN_LIBRARY, _pe.MISSED_END_OF_FUNCTION),
(_rs.IN_FUNC, "STRING", _rs.IN_FUNC_GLOBAL, None),
(_rs.IN_TEMPLATE, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "GLOBAL", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "TEMPLATE", _rs.IN_TEMPLATE, None),
(_rs.IN_TEMPLATE, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_TEMPLATE, "STRING", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.WANT_CURLY, "FUNCTION", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "GLOBAL", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "STUB", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "SYNTHETIC", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "TEMPLATE", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "VTABLE", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "LIBRARY", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.WANT_CURLY, "STRING", _rs.SEARCH, _pe.UNEXPECTED_MARKER),
(_rs.IN_GLOBAL, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "GLOBAL", _rs.IN_GLOBAL, None),
(_rs.IN_GLOBAL, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_GLOBAL, "STRING", _rs.IN_GLOBAL, None),
(_rs.IN_FUNC_GLOBAL, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "GLOBAL", _rs.IN_FUNC_GLOBAL, None),
(_rs.IN_FUNC_GLOBAL, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_FUNC_GLOBAL, "STRING", _rs.IN_FUNC_GLOBAL, None),
(_rs.IN_VTABLE, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "GLOBAL", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "VTABLE", _rs.IN_VTABLE, None),
(_rs.IN_VTABLE, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_VTABLE, "STRING", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "GLOBAL", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "SYNTHETIC", _rs.IN_SYNTHETIC, None),
(_rs.IN_SYNTHETIC, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "LIBRARY", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_SYNTHETIC, "STRING", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "FUNCTION", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "GLOBAL", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "STUB", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "SYNTHETIC", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "TEMPLATE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "VTABLE", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
(_rs.IN_LIBRARY, "LIBRARY", _rs.IN_LIBRARY, None),
(_rs.IN_LIBRARY, "STRING", _rs.SEARCH, _pe.INCOMPATIBLE_MARKER),
]
# fmt: on
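# To trace one path through the table above: in SEARCH, reading
# "// FUNCTION: TEST 0x1234" moves the parser to WANT_SIG with no error,
# and a GLOBAL marker read while still in WANT_SIG bounces back to SEARCH
# with INCOMPATIBLE_MARKER.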
@pytest.mark.parametrize(
"state, marker_type, new_state, expected_error", state_change_marker_cases
)
def test_state_change_by_marker(
state: _rs, marker_type: str, new_state: _rs, expected_error: Optional[_pe]
):
p = DecompParser()
p.state = state
mock_line = f"// {marker_type}: TEST 0x1234"
p.read_line(mock_line)
assert p.state == new_state
if expected_error is not None:
assert len(p.alerts) > 0
assert p.alerts[0].code == expected_error
# Reading any of these lines should have no effect in ReaderState.SEARCH
search_lines_no_effect = [
"",
"\t",
" ",
"int x = 0;",
"// Comment",
"/*",
"*/",
"/* Block comment */",
"{",
"}",
]
@pytest.mark.parametrize("line", search_lines_no_effect)
def test_state_search_line(line: str):
p = DecompParser()
p.read_line(line)
assert p.state == _rs.SEARCH
assert len(p.alerts) == 0

View file

@ -1,209 +0,0 @@
import pytest
from isledecomp.parser.parser import MarkerDict
from isledecomp.parser.marker import (
DecompMarker,
MarkerType,
match_marker,
is_marker_exact,
)
from isledecomp.parser.util import (
is_blank_or_comment,
get_class_name,
get_variable_name,
get_string_contents,
)
blank_or_comment_param = [
(True, ""),
(True, "\t"),
(True, " "),
(False, "\tint abc=123;"),
(True, "// OFFSET: LEGO1 0xdeadbeef"),
(True, " /* Block comment beginning"),
(True, "Block comment ending */ "),
# TODO: does clang-format have anything to say about these cases?
(False, "x++; // Comment folows"),
(False, "x++; /* Block comment begins"),
]
@pytest.mark.parametrize("expected, line", blank_or_comment_param)
def test_is_blank_or_comment(line: str, expected: bool):
assert is_blank_or_comment(line) is expected
marker_samples = [
# (can_parse: bool, exact_match: bool, line: str)
(True, True, "// FUNCTION: LEGO1 0xdeadbeef"),
(True, True, "// FUNCTION: ISLE 0x12345678"),
# No trailing spaces allowed
(True, False, "// FUNCTION: LEGO1 0xdeadbeef "),
# Must have exactly one space between elements
(True, False, "//FUNCTION: ISLE 0xdeadbeef"),
(True, False, "// FUNCTION:ISLE 0xdeadbeef"),
(True, False, "// FUNCTION: ISLE 0xdeadbeef"),
(True, False, "// FUNCTION: ISLE 0xdeadbeef"),
(True, False, "// FUNCTION: ISLE 0xdeadbeef"),
# Must have 0x prefix for hex number to match at all
(False, False, "// FUNCTION: ISLE deadbeef"),
# Marker type (e.g. FUNCTION, STUB) and module name must be uppercase
(True, False, "// function: ISLE 0xdeadbeef"),
(True, False, "// function: isle 0xdeadbeef"),
# Hex string must be lowercase
(True, False, "// FUNCTION: ISLE 0xDEADBEEF"),
# TODO: How flexible should we be with matching the module name?
(True, True, "// FUNCTION: OMNI 0x12345678"),
(True, True, "// FUNCTION: LEG01 0x12345678"),
(True, False, "// FUNCTION: hello 0x12345678"),
# Not close enough to match
(False, False, "// FUNCTION: ISLE0x12345678"),
(False, False, "// FUNCTION: 0x12345678"),
(False, False, "// LEGO1: 0x12345678"),
# Hex string shorter than 8 characters
(True, True, "// FUNCTION: LEGO1 0x1234"),
# TODO: These match but shouldn't.
# (False, False, '// FUNCTION: LEGO1 0'),
# (False, False, '// FUNCTION: LEGO1 0x'),
# Extra field
(True, True, "// VTABLE: HELLO 0x1234 Extra"),
# Extra with spaces
(True, True, "// VTABLE: HELLO 0x1234 Whatever<SubClass *>"),
# Extra, no space (if the first non-hex character is not in [a-f])
(True, False, "// VTABLE: HELLO 0x1234Hello"),
# Extra, many spaces
(True, False, "// VTABLE: HELLO 0x1234 Hello"),
]
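# Summarizing the grammar exercised above: the exact form is
# "// <TYPE>: <MODULE> 0x<lowercase hex>[ <extra>]" with single spaces
# between fields. match_marker() tolerates deviations (can_parse=True),
# while is_marker_exact() accepts only the canonical form (exact_match=True).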
@pytest.mark.parametrize("match, _, line", marker_samples)
def test_marker_match(line: str, match: bool, _):
did_match = match_marker(line) is not None
assert did_match is match
@pytest.mark.parametrize("_, exact, line", marker_samples)
def test_marker_exact(line: str, exact: bool, _):
assert is_marker_exact(line) is exact
def test_marker_dict_simple():
d = MarkerDict()
d.insert(DecompMarker("FUNCTION", "TEST", 0x1234))
markers = list(d.iter())
assert len(markers) == 1
def test_marker_dict_ofs_replace():
d = MarkerDict()
d.insert(DecompMarker("FUNCTION", "TEST", 0x1234))
d.insert(DecompMarker("FUNCTION", "TEST", 0x555))
markers = list(d.iter())
assert len(markers) == 1
assert markers[0].offset == 0x1234
def test_marker_dict_type_replace():
d = MarkerDict()
d.insert(DecompMarker("FUNCTION", "TEST", 0x1234))
d.insert(DecompMarker("STUB", "TEST", 0x1234))
markers = list(d.iter())
assert len(markers) == 1
assert markers[0].type == MarkerType.FUNCTION
class_name_match_cases = [
("struct MxString {", "MxString"),
("class MxString {", "MxString"),
("// class MxString", "MxString"),
("class MxString : public MxCore {", "MxString"),
("class MxPtrList<MxPresenter>", "MxPtrList<MxPresenter>"),
# If it is possible to match the symbol MxList<LegoPathController *>::`vftable'
# we should get the correct class name if possible. If the template type is a pointer,
# the asterisk and class name are separated by one space.
("// class MxList<LegoPathController *>", "MxList<LegoPathController *>"),
("// class MxList<LegoPathController*>", "MxList<LegoPathController *>"),
("// class MxList<LegoPathController* >", "MxList<LegoPathController *>"),
# I don't know if this would ever come up, but sure, why not?
("// class MxList<LegoPathController**>", "MxList<LegoPathController **>"),
("// class Many::Name::Spaces", "Many::Name::Spaces"),
]
@pytest.mark.parametrize("line, class_name", class_name_match_cases)
def test_get_class_name(line: str, class_name: str):
assert get_class_name(line) == class_name
class_name_no_match_cases = [
"MxString { ",
"clas MxString",
"// MxPtrList<MxPresenter>::`scalar deleting destructor'",
]
@pytest.mark.parametrize("line", class_name_no_match_cases)
def test_get_class_name_none(line: str):
assert get_class_name(line) is None
variable_name_cases = [
# with prefix for easy access
("char* g_test;", "g_test"),
("g_test;", "g_test"),
("void (*g_test)(int);", "g_test"),
("char g_test[50];", "g_test"),
("char g_test[50] = {1234,", "g_test"),
("int g_test = 500;", "g_test"),
# no prefix
("char* hello;", "hello"),
("hello;", "hello"),
("void (*hello)(int);", "hello"),
("char hello[50];", "hello"),
("char hello[50] = {1234,", "hello"),
("int hello = 500;", "hello"),
]
@pytest.mark.parametrize("line,name", variable_name_cases)
def test_get_variable_name(line: str, name: str):
assert get_variable_name(line) == name
string_match_cases = [
('return "hello world";', "hello world"),
('"hello\\\\"', "hello\\"),
('"hello \\"world\\""', 'hello "world"'),
('"hello\\nworld"', "hello\nworld"),
# Only match first string if there are multiple options
('Method("hello", "world");', "hello"),
]
@pytest.mark.parametrize("line, string", string_match_cases)
def test_get_string_contents(line: str, string: str):
assert get_string_contents(line) == string
def test_marker_extra_spaces():
"""The extra field can contain spaces"""
marker = match_marker("// VTABLE: TEST 0x1234 S p a c e s")
assert marker.extra == "S p a c e s"
# Trailing spaces removed
marker = match_marker("// VTABLE: TEST 0x8888 spaces ")
assert marker.extra == "spaces"
# Trailing newline removed if present
marker = match_marker("// VTABLE: TEST 0x5555 newline\n")
assert marker.extra == "newline"
def test_marker_trailing_spaces():
"""Should ignore trailing spaces. (Invalid extra field)
Offset field not truncated, extra field set to None."""
marker = match_marker("// VTABLE: TEST 0x1234 ")
assert marker is not None
assert marker.offset == 0x1234
assert marker.extra is None

View file

@ -1,32 +0,0 @@
from os import name as os_name
import pytest
from isledecomp.dir import PathResolver
if os_name != "nt":
pytest.skip(reason="Skip Windows-only tests", allow_module_level=True)
@pytest.fixture(name="resolver")
def fixture_resolver_win():
yield PathResolver("C:\\isle")
def test_identity(resolver):
assert resolver.resolve_cvdump("C:\\isle\\test.h") == "C:\\isle\\test.h"
def test_outside_basedir(resolver):
assert resolver.resolve_cvdump("C:\\lego\\test.h") == "C:\\lego\\test.h"
def test_relative(resolver):
assert resolver.resolve_cvdump(".\\test.h") == "C:\\isle\\test.h"
assert resolver.resolve_cvdump("..\\test.h") == "C:\\test.h"
def test_intermediate_relative(resolver):
"""These paths may not register as `relative` paths, but we want to
produce a single absolute path for each."""
assert resolver.resolve_cvdump("C:\\isle\\test\\..\\test.h") == "C:\\isle\\test.h"
assert resolver.resolve_cvdump(".\\subdir\\..\\test.h") == "C:\\isle\\test.h"

View file

@ -1,69 +0,0 @@
from os import name as os_name
from unittest.mock import patch
import pytest
from isledecomp.dir import PathResolver
if os_name == "nt":
pytest.skip(reason="Skip Posix-only tests", allow_module_level=True)
@pytest.fixture(name="resolver")
def fixture_resolver_posix():
# Skip the call to winepath by using a patch, although this is not strictly necessary.
with patch("isledecomp.dir.winepath_unix_to_win", return_value="Z:\\usr\\isle"):
yield PathResolver("/usr/isle")
@patch("isledecomp.dir.winepath_win_to_unix")
def test_identity(winepath_mock, resolver):
"""Test with an absolute Wine path where a path swap is possible."""
# In this and upcoming tests, patch is_file so we always assume there is
# a file at the given unix path. We want to test the conversion logic only.
with patch("pathlib.Path.is_file", return_value=True):
assert resolver.resolve_cvdump("Z:\\usr\\isle\\test.h") == "/usr/isle/test.h"
winepath_mock.assert_not_called()
# Without the patch, this should call the winepath_mock, but we have
# memoized the value from the previous run.
assert resolver.resolve_cvdump("Z:\\usr\\isle\\test.h") == "/usr/isle/test.h"
winepath_mock.assert_not_called()
@patch("isledecomp.dir.winepath_win_to_unix")
def test_file_does_not_exist(winepath_mock, resolver):
"""These test files (probably) don't exist, so we always assume
the path swap failed and defer to winepath."""
resolver.resolve_cvdump("Z:\\usr\\isle\\test.h")
winepath_mock.assert_called_once_with("Z:\\usr\\isle\\test.h")
@patch("isledecomp.dir.winepath_win_to_unix")
def test_outside_basedir(winepath_mock, resolver):
"""Test an absolute path where we cannot do a path swap."""
with patch("pathlib.Path.is_file", return_value=True):
resolver.resolve_cvdump("Z:\\lego\\test.h")
winepath_mock.assert_called_once_with("Z:\\lego\\test.h")
@patch("isledecomp.dir.winepath_win_to_unix")
def test_relative(winepath_mock, resolver):
"""Test relative paths inside and outside of the base dir."""
with patch("pathlib.Path.is_file", return_value=True):
assert resolver.resolve_cvdump("./test.h") == "/usr/isle/test.h"
# This works because we will resolve "/usr/isle/test/../test.h"
assert resolver.resolve_cvdump("../test.h") == "/usr/test.h"
winepath_mock.assert_not_called()
@patch("isledecomp.dir.winepath_win_to_unix")
def test_intermediate_relative(winepath_mock, resolver):
"""We can resolve intermediate backdirs if they are relative to the basedir."""
with patch("pathlib.Path.is_file", return_value=True):
assert (
resolver.resolve_cvdump("Z:\\usr\\isle\\test\\..\\test.h")
== "/usr/isle/test.h"
)
assert resolver.resolve_cvdump(".\\subdir\\..\\test.h") == "/usr/isle/test.h"
winepath_mock.assert_not_called()

View file

@ -1,296 +0,0 @@
from typing import Optional
import pytest
from isledecomp.compare.asm.parse import DisasmLiteInst, ParseAsm
def mock_inst(mnemonic: str, op_str: str) -> DisasmLiteInst:
"""Mock up the named tuple DisasmLite from just a mnemonic and op_str.
To be used for tests on sanitize that do not require the instruction address
or size. i.e. any non-jump instruction."""
return DisasmLiteInst(0, 0, mnemonic, op_str)
identity_cases = [
("", ""),
("sti", ""),
("push", "ebx"),
("ret", ""),
("ret", "4"),
("mov", "eax, 0x1234"),
]
@pytest.mark.parametrize("mnemonic, op_str", identity_cases)
def test_identity(mnemonic, op_str):
"""Confirm that nothing is substituted."""
p = ParseAsm()
inst = mock_inst(mnemonic, op_str)
result = p.sanitize(inst)
assert result == (mnemonic, op_str)
ptr_replace_cases = [
("byte ptr [0x5555]", "byte ptr [<OFFSET1>]"),
("word ptr [0x5555]", "word ptr [<OFFSET1>]"),
("dword ptr [0x5555]", "dword ptr [<OFFSET1>]"),
("qword ptr [0x5555]", "qword ptr [<OFFSET1>]"),
("eax, dword ptr [0x5555]", "eax, dword ptr [<OFFSET1>]"),
("dword ptr [0x5555], eax", "dword ptr [<OFFSET1>], eax"),
("dword ptr [0x5555], 0", "dword ptr [<OFFSET1>], 0"),
("dword ptr [0x5555], 8", "dword ptr [<OFFSET1>], 8"),
# Same value, assumed to be an addr in the first appearance
# because it is designated as 'ptr', but we have not provided the
# relocation table lookup method so we do not replace the second appearance.
("dword ptr [0x5555], 0x5555", "dword ptr [<OFFSET1>], 0x5555"),
]
@pytest.mark.parametrize("start, end", ptr_replace_cases)
def test_ptr_replace(start, end):
"""Anything in square brackets (with the 'ptr' prefix) will always be replaced."""
p = ParseAsm()
inst = mock_inst("", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
call_replace_cases = [
("ebx", "ebx"),
("0x1234", "<OFFSET1>"),
("dword ptr [0x1234]", "dword ptr [<OFFSET1>]"),
("dword ptr [ecx + 0x10]", "dword ptr [ecx + 0x10]"),
]
@pytest.mark.parametrize("start, end", call_replace_cases)
def test_call_replace(start, end):
"""Call with hex operand is always replaced.
Otherwise, ptr replacement rules apply, but skip `this` calls."""
p = ParseAsm()
inst = mock_inst("call", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
def test_jump_displacement():
"""Display jump displacement (offset from end of jump instruction)
instead of destination address."""
p = ParseAsm()
inst = DisasmLiteInst(0x1000, 2, "je", "0x1000")
(_, op_str) = p.sanitize(inst)
assert op_str == "-0x2"
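# The arithmetic behind the test above: the 2-byte `je` at 0x1000 ends at
# 0x1002, so the target 0x1000 is rendered as 0x1000 - 0x1002 = -0x2.
# Using the displacement keeps matching code comparable even when the two
# binaries place the function at different addresses.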
def test_jmp_table():
"""To ignore cases where it would be inappropriate to replace pointer
displacement (i.e. the vast majority of them) we require the address
to be relocated. This excludes any address less than the imagebase."""
p = ParseAsm()
inst = mock_inst("jmp", "dword ptr [eax*4 + 0x5555]")
(_, op_str) = p.sanitize(inst)
# i.e. no change
assert op_str == "dword ptr [eax*4 + 0x5555]"
def relocate_lookup(addr: int) -> bool:
return addr == 0x5555
# Now add the relocation lookup
p = ParseAsm(relocate_lookup=relocate_lookup)
(_, op_str) = p.sanitize(inst)
# Should replace it now
assert op_str == "dword ptr [eax*4 + <OFFSET1>]"
name_replace_cases = [
("byte ptr [0x5555]", "byte ptr [_substitute_]"),
("word ptr [0x5555]", "word ptr [_substitute_]"),
("dword ptr [0x5555]", "dword ptr [_substitute_]"),
("qword ptr [0x5555]", "qword ptr [_substitute_]"),
]
@pytest.mark.parametrize("start, end", name_replace_cases)
def test_name_replace(start, end):
"""Make sure the name lookup function is called if present"""
def substitute(_: int, __: bool) -> str:
return "_substitute_"
p = ParseAsm(name_lookup=substitute)
inst = mock_inst("mov", start)
(_, op_str) = p.sanitize(inst)
assert op_str == end
def test_replacement_cache():
p = ParseAsm()
inst = mock_inst("inc", "dword ptr [0x1234]")
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [<OFFSET1>]"
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [<OFFSET1>]"
def test_replacement_numbering():
"""If we can use the name lookup for the first address but not the second,
the second replacement should be <OFFSET2> not <OFFSET1>."""
def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
(_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x1234]"))
assert op_str == "dword ptr [_substitute_]"
(_, op_str) = p.sanitize(mock_inst("inc", "dword ptr [0x5555]"))
assert op_str == "dword ptr [<OFFSET2>]"
def test_relocate_lookup():
"""Immediate values would be relocated if they are actually addresses.
So we can use the relocation table to check whether a given value is an
address or just some number."""
def relocate_lookup(addr: int) -> bool:
return addr == 0x1234
p = ParseAsm(relocate_lookup=relocate_lookup)
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x1234"))
assert op_str == "eax, <OFFSET1>"
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x5555"))
assert op_str == "eax, 0x5555"
def test_jump_to_function():
"""A jmp instruction can lead us directly to a function. This can be found
in the unwind section at the end of a function. However: we do not want to
assume this is the case for all jumps. Only replace the jump with a name
if we can find it using our lookup."""
def substitute_1234(addr: int, _: bool) -> Optional[str]:
return "_substitute_" if addr == 0x1234 else None
p = ParseAsm(name_lookup=substitute_1234)
inst = DisasmLiteInst(0x1000, 2, "jmp", "0x1234")
(_, op_str) = p.sanitize(inst)
assert op_str == "_substitute_"
# Should not replace this jump.
# 0x1000 (start addr)
# + 2 (size of jump instruction)
# + 0x5555 (displacement, the value we want)
# = 0x6557
inst = DisasmLiteInst(0x1000, 2, "jmp", "0x6557")
(_, op_str) = p.sanitize(inst)
assert op_str == "0x5555"
@pytest.mark.skip(reason="changed implementation")
def test_float_replacement():
"""Floating point constants often appear as pointers to data.
A good example is ViewROI::IntrinsicImportance and the subclass override
LegoROI::IntrinsicImportance. Both return 0.5, but this is done via the
FLD instruction and a dword value at 0x100dbdec. In this case it is more
valuable to just read the constant value rather than use a placeholder.
The float constants don't appear to be deduplicated (like strings are)
because there is another 0.5 at 0x100d40b0."""
def bin_lookup(addr: int, _: int) -> Optional[bytes]:
return b"\xdb\x0f\x49\x40" if addr == 0x1234 else None
p = ParseAsm(bin_lookup=bin_lookup)
inst = DisasmLiteInst(0x1000, 6, "fld", "dword ptr [0x1234]")
(_, op_str) = p.sanitize(inst)
# Single-precision float. struct.unpack("<f", struct.pack("<f", math.pi))
assert op_str == "dword ptr [3.1415927410125732 (FLOAT)]"
@pytest.mark.skip(reason="changed implementation")
def test_float_variable():
"""If there is a variable at the address referenced by a float instruction,
use the name instead of calling into the float replacement handler."""
def name_lookup(addr: int, _: bool) -> Optional[str]:
return "g_myFloatVariable" if addr == 0x1234 else None
p = ParseAsm(name_lookup=name_lookup)
inst = DisasmLiteInst(0x1000, 6, "fld", "dword ptr [0x1234]")
(_, op_str) = p.sanitize(inst)
assert op_str == "dword ptr [g_myFloatVariable]"
def test_pointer_compare():
"""A loop on an array could get optimized into comparing on the address
that immediately follows the array. This may or may not be a valid address
and it may or may not be annotated. To avoid a situation where an
erroneous address value would get replaced with a placeholder and silently
pass the comparison check, we will only replace an immediate value on the
CMP instruction if it is a known address."""
# 0x1234 and 0x5555 are relocated and so are considered to be addresses.
def relocate_lookup(addr: int) -> bool:
return addr in (0x1234, 0x5555)
# Only 0x5555 is a "known" address
def name_lookup(addr: int, _: bool) -> Optional[str]:
return "hello" if addr == 0x5555 else None
p = ParseAsm(relocate_lookup=relocate_lookup, name_lookup=name_lookup)
# Will always replace on MOV instruction
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x1234"))
assert op_str == "eax, <OFFSET1>"
(_, op_str) = p.sanitize(mock_inst("mov", "eax, 0x5555"))
assert op_str == "eax, hello"
# n.b. We have already cached the replacement for 0x1234, but the
# special handling for CMP should skip the cache and not use it.
# Do not replace here
(_, op_str) = p.sanitize(mock_inst("cmp", "eax, 0x1234"))
assert op_str == "eax, 0x1234"
# Should replace here
(_, op_str) = p.sanitize(mock_inst("cmp", "eax, 0x5555"))
assert op_str == "eax, hello"
def test_absolute_indirect():
"""The instruction `call dword ptr [0x1234]` means we call the function
whose address is at 0x1234. (i.e. absolute indirect addressing mode)
It is probably more useful to show the name of the function itself if
we have it, but there are some circumstances where we want to replace
with the pointer's name (i.e. an import function)."""
def name_lookup(addr: int, _: bool) -> Optional[str]:
return {
0x1234: "Hello",
0x4321: "xyz",
0x5555: "Test",
}.get(addr)
def bin_lookup(addr: int, _: int) -> Optional[bytes]:
return (
{
0x1234: b"\x55\x55\x00\x00",
0x4321: b"\x99\x99\x00\x00",
}
).get(addr)
p = ParseAsm(name_lookup=name_lookup, bin_lookup=bin_lookup)
# If we know the indirect address (0x5555)
# Arrow to indicate this is an indirect replacement
(_, op_str) = p.sanitize(mock_inst("call", "dword ptr [0x1234]"))
assert op_str == "dword ptr [->Test]"
# If we do not know the indirect address (0x9999)
(_, op_str) = p.sanitize(mock_inst("call", "dword ptr [0x4321]"))
assert op_str == "dword ptr [xyz]"
# If we can't read the indirect address
(_, op_str) = p.sanitize(mock_inst("call", "dword ptr [0x5555]"))
assert op_str == "dword ptr [Test]"

View file

@ -1,867 +0,0 @@
// reccmp.js
/* global data */
// Unwrap array of functions into a dictionary with address as the key.
const dataDict = Object.fromEntries(data.map(row => [row.address, row]));
function getDataByAddr(addr) {
return dataDict[addr];
}
//
// Pure functions
//
function formatAsm(entries, addrOption) {
const output = [];
const createTh = (text) => {
const th = document.createElement('th');
th.innerText = text;
return th;
};
const createTd = (text, className = '') => {
const td = document.createElement('td');
td.innerText = text;
td.className = className;
return td;
};
entries.forEach(obj => {
// These won't all be present. You get "both" for an equal node
// and orig/recomp for a diff.
const { both = [], orig = [], recomp = [] } = obj;
output.push(...both.map(([addr, line, recompAddr]) => {
const tr = document.createElement('tr');
tr.appendChild(createTh(addr));
tr.appendChild(createTh(recompAddr));
tr.appendChild(createTd(line));
return tr;
}));
output.push(...orig.map(([addr, line]) => {
const tr = document.createElement('tr');
tr.appendChild(createTh(addr));
tr.appendChild(createTh(''));
tr.appendChild(createTd(`-${line}`, 'diffneg'));
return tr;
}));
output.push(...recomp.map(([addr, line]) => {
const tr = document.createElement('tr');
tr.appendChild(createTh(''));
tr.appendChild(createTh(addr));
tr.appendChild(createTd(`+${line}`, 'diffpos'));
return tr;
}));
});
return output;
}
// Special internal values to ensure this sort order for matching column:
// 1. Stub
// 2. Any match percentage [0.0, 1.0)
// 3. Effective match
// 4. Actual 100% match
function matchingColAdjustment(row) {
if ('stub' in row) {
return -1;
}
if ('effective' in row) {
return 1.0;
}
if (row.matching === 1.0) {
return 1000;
}
return row.matching;
}
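// Example ordering under this adjustment (ascending): a stub (-1) sorts
// before a 35% match (0.35), which sorts before an effective match (1.0),
// which sorts before a true 100% match (1000).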
function getCppClass(str) {
const idx = str.indexOf('::');
if (idx !== -1) {
return str.slice(0, idx);
}
return str;
}
// Clamp string length to specified length and pad with ellipsis
function stringTruncate(str, maxlen = 20) {
str = getCppClass(str);
if (str.length > maxlen) {
return `${str.slice(0, maxlen)}...`;
}
return str;
}
function getMatchPercentText(row) {
if ('stub' in row) {
return 'stub';
}
if ('effective' in row) {
return '100.00%*';
}
return (row.matching * 100).toFixed(2) + '%';
}
function countDiffs(row) {
const { diff = '' } = row;
if (diff === '') {
return '';
}
const diffs = diff.map(([slug, subgroups]) => subgroups).flat();
const diffLength = diffs.filter(d => !('both' in d)).length;
const diffWord = diffLength === 1 ? 'diff' : 'diffs';
return diffLength === 0 ? '' : `${diffLength} ${diffWord}`;
}
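// e.g. a diff with three subgroups, one of which is shared context
// ("both"), yields "2 diffs"; a row with no diff data yields ''.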
// Helper for this set/remove attribute block
function setBooleanAttribute(element, attribute, value) {
if (value) {
element.setAttribute(attribute, '');
} else {
element.removeAttribute(attribute);
}
}
function copyToClipboard(value) {
navigator.clipboard.writeText(value);
}
const PAGE_SIZE = 200;
//
// Global state
//
class ListingState {
constructor() {
this._query = '';
this._sortCol = 'address';
this._filterType = 1;
this._sortDesc = false;
this._hidePerfect = false;
this._hideStub = false;
this._showRecomp = false;
this._expanded = {};
this._page = 0;
this._listeners = [];
this._results = [];
this.updateResults();
}
addListener(fn) {
this._listeners.push(fn);
}
callListeners() {
for (const fn of this._listeners) {
fn();
}
}
isExpanded(addr) {
return addr in this._expanded;
}
toggleExpanded(addr) {
this.setExpanded(addr, !this.isExpanded(addr));
}
setExpanded(addr, value) {
if (value) {
this._expanded[addr] = true;
} else {
delete this._expanded[addr];
}
}
updateResults() {
const filterFn = this.rowFilterFn.bind(this);
const sortFn = this.rowSortFn.bind(this);
this._results = data.filter(filterFn).sort(sortFn);
// Set _page directly to avoid double call to listeners.
this._page = this.pageClamp(this.page);
this.callListeners();
}
pageSlice() {
return this._results.slice(this.page * PAGE_SIZE, (this.page + 1) * PAGE_SIZE);
}
resultsCount() {
return this._results.length;
}
pageCount() {
return Math.ceil(this._results.length / PAGE_SIZE);
}
maxPage() {
return Math.max(0, this.pageCount() - 1);
}
// A list showing the range of each page based on the sort column and direction.
pageHeadings() {
if (this._results.length === 0) {
return [];
}
const headings = [];
for (let i = 0; i < this.pageCount(); i++) {
const startIdx = i * PAGE_SIZE;
const endIdx = Math.min(this._results.length, ((i + 1) * PAGE_SIZE)) - 1;
let start = this._results[startIdx][this.sortCol];
let end = this._results[endIdx][this.sortCol];
if (this.sortCol === 'matching') {
start = getMatchPercentText(this._results[startIdx]);
end = getMatchPercentText(this._results[endIdx]);
}
headings.push([i, stringTruncate(start), stringTruncate(end)]);
}
return headings;
}
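// Illustration with hypothetical data: PAGE_SIZE = 200 and 450 results
// sorted by address would yield three headings, each shaped like
// [pageIndex, firstValue, lastValue] after stringTruncate().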
rowFilterFn(row) {
// Destructuring sets defaults for optional values from this object.
const {
effective = false,
stub = false,
diff = '',
name,
address,
matching
} = row;
if (this.hidePerfect && (effective || matching >= 1)) {
return false;
}
if (this.hideStub && stub) {
return false;
}
if (this.query === '') {
return true;
}
// Name/addr search
if (this.filterType === 1) {
return (
address.includes(this.query) ||
name.toLowerCase().includes(this.query)
);
}
// no diff for review.
if (diff === '') {
return false;
}
// special matcher for combined diff
const anyLineMatch = ([addr, line]) => line.toLowerCase().trim().includes(this.query);
// Flatten all diff groups for the search
const diffs = diff.map(([slug, subgroups]) => subgroups).flat();
for (const subgroup of diffs) {
const { both = [], orig = [], recomp = [] } = subgroup;
// If search includes context
if (this.filterType === 2 && both.some(anyLineMatch)) {
return true;
}
if (orig.some(anyLineMatch) || recomp.some(anyLineMatch)) {
return true;
}
}
return false;
}
rowSortFn(rowA, rowB) {
const valA = this.sortCol === 'matching'
? matchingColAdjustment(rowA)
: rowA[this.sortCol];
const valB = this.sortCol === 'matching'
? matchingColAdjustment(rowB)
: rowB[this.sortCol];
if (valA > valB) {
return this.sortDesc ? -1 : 1;
} else if (valA < valB) {
return this.sortDesc ? 1 : -1;
}
return 0;
}
pageClamp(page) {
return Math.max(0, Math.min(page, this.maxPage()));
}
get page() {
return this._page;
}
set page(page) {
this._page = this.pageClamp(page);
this.callListeners();
}
get filterType() {
return parseInt(this._filterType);
}
set filterType(value) {
value = parseInt(value);
if (value >= 1 && value <= 3) {
this._filterType = value;
}
this.updateResults();
}
get query() {
return this._query;
}
set query(value) {
// Normalize search string
this._query = value.toLowerCase().trim();
this.updateResults();
}
get showRecomp() {
return this._showRecomp;
}
set showRecomp(value) {
// Don't sort by the recomp column we are about to hide
if (!value && this.sortCol === 'recomp') {
this._sortCol = 'address';
}
this._showRecomp = value;
this.callListeners();
}
get sortCol() {
return this._sortCol;
}
set sortCol(column) {
if (column === this._sortCol) {
this._sortDesc = !this._sortDesc;
} else {
this._sortCol = column;
}
this.updateResults();
}
get sortDesc() {
return this._sortDesc;
}
set sortDesc(value) {
this._sortDesc = value;
this.updateResults();
}
get hidePerfect() {
return this._hidePerfect;
}
set hidePerfect(value) {
this._hidePerfect = value;
this.updateResults();
}
get hideStub() {
return this._hideStub;
}
set hideStub(value) {
this._hideStub = value;
this.updateResults();
}
}
const appState = new ListingState();
//
// Custom elements
//
// Sets sort indicator arrow based on element attributes.
class SortIndicator extends window.HTMLElement {
static observedAttributes = ['data-sort'];
attributeChangedCallback(name, oldValue, newValue) {
if (newValue === null) {
// Reserve space for blank indicator so column width stays the same
this.innerHTML = '&nbsp;';
} else {
this.innerHTML = newValue === 'asc' ? '&#9650;' : '&#9660;';
}
}
}
class FuncRow extends window.HTMLElement {
connectedCallback() {
if (this.shadowRoot !== null) {
return;
}
const template = document.querySelector('template#funcrow-template').content;
const shadow = this.attachShadow({ mode: 'open' });
shadow.appendChild(template.cloneNode(true));
shadow.querySelector(':host > div[data-col="name"]').addEventListener('click', evt => {
this.dispatchEvent(new Event('name-click'));
});
}
get address() {
return this.getAttribute('data-address');
}
}
class NoDiffMessage extends window.HTMLElement {
connectedCallback() {
if (this.shadowRoot !== null) {
return;
}
const template = document.querySelector('template#nodiff-template').content;
const shadow = this.attachShadow({ mode: 'open' });
shadow.appendChild(template.cloneNode(true));
}
}
class CanCopy extends window.HTMLElement {
connectedCallback() {
if (this.shadowRoot !== null) {
return;
}
const template = document.querySelector('template#can-copy-template').content;
const shadow = this.attachShadow({ mode: 'open' });
shadow.appendChild(template.cloneNode(true));
const el = shadow.querySelector('slot').assignedNodes()[0];
el.addEventListener('mouseout', evt => { this.copied = false; });
el.addEventListener('click', evt => {
copyToClipboard(evt.target.textContent);
this.copied = true;
});
}
get copied() {
return this.getAttribute('copied');
}
set copied(value) {
if (value) {
setTimeout(() => { this.copied = false; }, 2000);
}
setBooleanAttribute(this, 'copied', value);
}
}
// Displays asm diff for the given @data-address value.
class DiffRow extends window.HTMLElement {
connectedCallback() {
if (this.shadowRoot !== null) {
return;
}
const template = document.querySelector('template#diffrow-template').content;
const shadow = this.attachShadow({ mode: 'open' });
shadow.appendChild(template.cloneNode(true));
}
get address() {
return this.getAttribute('data-address');
}
set address(value) {
this.setAttribute('data-address', value);
}
}
class DiffDisplayOptions extends window.HTMLElement {
static observedAttributes = ['data-option'];
connectedCallback() {
if (this.shadowRoot !== null) {
return;
}
const shadow = this.attachShadow({ mode: 'open' });
shadow.innerHTML = `
<style>
fieldset {
align-items: center;
display: flex;
margin-bottom: 20px;
}
label {
margin-right: 10px;
user-select: none;
}
label, input {
cursor: pointer;
}
</style>
<fieldset>
<legend>Address display:</legend>
<input type="radio" id="showNone" name="addrDisplay" value=0>
<label for="showNone">None</label>
<input type="radio" id="showOrig" name="addrDisplay" value=1>
<label for="showOrig">Original</label>
<input type="radio" id="showBoth" name="addrDisplay" value=2>
<label for="showBoth">Both</label>
</fieldset>`;
shadow.querySelectorAll('input[type=radio]').forEach(radio => {
const checked = this.option === radio.getAttribute('value');
setBooleanAttribute(radio, 'checked', checked);
radio.addEventListener('change', evt => (this.option = evt.target.value));
});
}
set option(value) {
this.setAttribute('data-option', parseInt(value));
}
get option() {
return this.getAttribute('data-option') ?? 1;
}
attributeChangedCallback(name, oldValue, newValue) {
if (name !== 'data-option') {
return;
}
this.dispatchEvent(new Event('change'));
}
}
class DiffDisplay extends window.HTMLElement {
static observedAttributes = ['data-option'];
connectedCallback() {
if (this.querySelector('diff-display-options') !== null) {
return;
}
const optControl = new DiffDisplayOptions();
optControl.option = this.option;
optControl.addEventListener('change', evt => (this.option = evt.target.option));
this.appendChild(optControl);
const div = document.createElement('div');
const obj = getDataByAddr(this.address);
const createHeaderLine = (text, className) => {
const div = document.createElement('div');
div.textContent = text;
div.className = className;
return div;
};
const groups = obj.diff;
groups.forEach(([slug, subgroups]) => {
const secondTable = document.createElement('table');
secondTable.classList.add('diffTable');
const hdr = document.createElement('div');
hdr.appendChild(createHeaderLine('---', 'diffneg'));
hdr.appendChild(createHeaderLine('+++', 'diffpos'));
hdr.appendChild(createHeaderLine(slug, 'diffslug'));
div.appendChild(hdr);
const tbody = document.createElement('tbody');
secondTable.appendChild(tbody);
const diffs = formatAsm(subgroups, this.option);
for (const el of diffs) {
tbody.appendChild(el);
}
div.appendChild(secondTable);
});
this.appendChild(div);
}
get address() {
return this.getAttribute('data-address');
}
set address(value) {
this.setAttribute('data-address', value);
}
get option() {
return this.getAttribute('data-option') ?? 1;
}
set option(value) {
this.setAttribute('data-option', value);
}
}
class ListingOptions extends window.HTMLElement {
constructor() {
super();
// Register to receive updates
appState.addListener(() => this.onUpdate());
const input = this.querySelector('input[type=search]');
input.oninput = evt => (appState.query = evt.target.value);
const hidePerf = this.querySelector('input#cbHidePerfect');
hidePerf.onchange = evt => (appState.hidePerfect = evt.target.checked);
hidePerf.checked = appState.hidePerfect;
const hideStub = this.querySelector('input#cbHideStub');
hideStub.onchange = evt => (appState.hideStub = evt.target.checked);
hideStub.checked = appState.hideStub;
const showRecomp = this.querySelector('input#cbShowRecomp');
showRecomp.onchange = evt => (appState.showRecomp = evt.target.checked);
showRecomp.checked = appState.showRecomp;
this.querySelector('button#pagePrev').addEventListener('click', evt => {
appState.page = appState.page - 1;
});
this.querySelector('button#pageNext').addEventListener('click', evt => {
appState.page = appState.page + 1;
});
this.querySelector('select#pageSelect').addEventListener('change', evt => {
appState.page = evt.target.value;
});
this.querySelectorAll('input[name=filterType]').forEach(radio => {
const checked = appState.filterType === parseInt(radio.getAttribute('value'));
setBooleanAttribute(radio, 'checked', checked);
radio.onchange = evt => (appState.filterType = radio.getAttribute('value'));
});
this.onUpdate();
}
onUpdate() {
// Update input placeholder based on search type
this.querySelector('input[type=search]').placeholder = appState.filterType === 1
? 'Search for offset or function name...'
: 'Search for instruction...';
// Update page number and max page
this.querySelector('fieldset#pageDisplay > legend').textContent = `Page ${appState.page + 1} of ${Math.max(1, appState.pageCount())}`;
// Disable prev/next buttons on first/last page
setBooleanAttribute(this.querySelector('button#pagePrev'), 'disabled', appState.page === 0);
setBooleanAttribute(this.querySelector('button#pageNext'), 'disabled', appState.page === appState.maxPage());
// Update page select dropdown
const pageSelect = this.querySelector('select#pageSelect');
setBooleanAttribute(pageSelect, 'disabled', appState.resultsCount() === 0);
pageSelect.innerHTML = '';
if (appState.resultsCount() === 0) {
const opt = document.createElement('option');
opt.textContent = '- no results -';
pageSelect.appendChild(opt);
} else {
for (const row of appState.pageHeadings()) {
const opt = document.createElement('option');
opt.value = row[0];
if (appState.page === row[0]) {
opt.setAttribute('selected', '');
}
const [start, end] = [row[1], row[2]];
opt.textContent = `${appState.sortCol}: ${start} to ${end}`;
pageSelect.appendChild(opt);
}
}
// Update row count
this.querySelector('#rowcount').textContent = `${appState.resultsCount()}`;
}
}
// Main application.
class ListingTable extends window.HTMLElement {
constructor() {
super();
// Register to receive updates
appState.addListener(() => this.somethingChanged());
}
setDiffRow(address, shouldExpand) {
const tbody = this.querySelector('tbody');
const funcrow = tbody.querySelector(`func-row[data-address="${address}"]`);
if (funcrow === null) {
return;
}
const existing = tbody.querySelector(`diff-row[data-address="${address}"]`);
if (existing !== null) {
if (!shouldExpand) {
tbody.removeChild(existing);
}
return;
}
const diffrow = document.createElement('diff-row');
diffrow.address = address;
// Decide what goes inside the diff row.
const obj = getDataByAddr(address);
if ('stub' in obj) {
const msg = document.createElement('no-diff');
const p = document.createElement('div');
p.innerText = 'Stub. No diff.';
msg.appendChild(p);
diffrow.appendChild(msg);
} else if (obj.diff.length === 0) {
const msg = document.createElement('no-diff');
const p = document.createElement('div');
p.innerText = 'Identical function - no diff';
msg.appendChild(p);
diffrow.appendChild(msg);
} else {
const dd = new DiffDisplay();
dd.option = '1';
dd.address = address;
diffrow.appendChild(dd);
}
// Insert the diff row after the parent func row.
tbody.insertBefore(diffrow, funcrow.nextSibling);
}
connectedCallback() {
const thead = this.querySelector('thead');
const headers = thead.querySelectorAll('th:not([data-no-sort])'); // TODO
headers.forEach(th => {
const col = th.getAttribute('data-col');
if (col) {
const span = th.querySelector('span');
if (span) {
span.addEventListener('click', evt => { appState.sortCol = col; });
}
}
});
this.somethingChanged();
}
somethingChanged() {
// Toggle recomp/diffs column
setBooleanAttribute(this.querySelector('table'), 'show-recomp', appState.showRecomp);
this.querySelectorAll('func-row[data-address]').forEach(row => {
setBooleanAttribute(row, 'show-recomp', appState.showRecomp);
});
const thead = this.querySelector('thead');
const headers = thead.querySelectorAll('th');
// Update sort indicator
headers.forEach(th => {
const col = th.getAttribute('data-col');
const indicator = th.querySelector('sort-indicator');
if (indicator === null) {
return;
}
if (appState.sortCol === col) {
indicator.setAttribute('data-sort', appState.sortDesc ? 'desc' : 'asc');
} else {
indicator.removeAttribute('data-sort');
}
});
// Add the rows
const tbody = this.querySelector('tbody');
tbody.innerHTML = ''; // Clear out the previous page before re-rendering.
for (const obj of appState.pageSlice()) {
const row = document.createElement('func-row');
row.setAttribute('data-address', obj.address); // Pairs the row with its diff-row.
row.addEventListener('name-click', evt => {
appState.toggleExpanded(obj.address);
this.setDiffRow(obj.address, appState.isExpanded(obj.address));
});
setBooleanAttribute(row, 'show-recomp', appState.showRecomp);
setBooleanAttribute(row, 'expanded', appState.isExpanded(obj.address));
const items = [
['address', obj.address],
['recomp', obj.recomp],
['name', obj.name],
['diffs', countDiffs(obj)],
['matching', getMatchPercentText(obj)]
];
items.forEach(([slotName, content]) => {
const div = document.createElement('span');
div.setAttribute('slot', slotName);
div.innerText = content;
row.appendChild(div);
});
tbody.appendChild(row);
if (appState.isExpanded(obj.address)) {
this.setDiffRow(obj.address, true);
}
}
}
}
window.onload = () => {
window.customElements.define('listing-table', ListingTable);
window.customElements.define('listing-options', ListingOptions);
window.customElements.define('diff-display', DiffDisplay);
window.customElements.define('diff-display-options', DiffDisplayOptions);
window.customElements.define('sort-indicator', SortIndicator);
window.customElements.define('func-row', FuncRow);
window.customElements.define('diff-row', DiffRow);
window.customElements.define('no-diff', NoDiffMessage);
window.customElements.define('can-copy', CanCopy);
};

View file

@ -1,344 +0,0 @@
#!/usr/bin/env python3
import argparse
import base64
import json
import logging
import os
from datetime import datetime
from isledecomp import (
Bin,
get_file_in_script_dir,
print_combined_diff,
diff_json,
percent_string,
)
from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType
from pystache import Renderer
import colorama
colorama.just_fix_windows_console()
def gen_json(json_file: str, orig_file: str, data):
"""Create a JSON file that contains the comparison summary"""
# If the structure of the JSON file ever changes, we would run into a problem
# reading an older format file in the CI action. Mark which version we are
# generating so we could potentially address this down the road.
json_format_version = 1
# Remove the diff field
reduced_data = [
{key: value for (key, value) in obj.items() if key != "diff"} for obj in data
]
with open(json_file, "w", encoding="utf-8") as f:
json.dump(
{
"file": os.path.basename(orig_file).lower(),
"format": json_format_version,
"timestamp": datetime.now().timestamp(),
"data": reduced_data,
},
f,
)
def gen_html(html_file, data):
js_path = get_file_in_script_dir("reccmp.js")
with open(js_path, "r", encoding="utf-8") as f:
reccmp_js = f.read()
output_data = Renderer().render_path(
get_file_in_script_dir("template.html"), {"data": data, "reccmp_js": reccmp_js}
)
with open(html_file, "w", encoding="utf-8") as htmlfile:
htmlfile.write(output_data)
def gen_svg(svg_file, name_svg, icon, svg_implemented_funcs, total_funcs, raw_accuracy):
icon_data = None
if icon:
with open(icon, "rb") as iconfile:
icon_data = base64.b64encode(iconfile.read()).decode("utf-8")
total_statistic = raw_accuracy / total_funcs
full_percentbar_width = 127.18422
output_data = Renderer().render_path(
get_file_in_script_dir("template.svg"),
{
"name": name_svg,
"icon": icon_data,
"implemented": f"{(svg_implemented_funcs / total_funcs * 100):.2f}% ({svg_implemented_funcs}/{total_funcs})",
"accuracy": f"{(raw_accuracy / svg_implemented_funcs * 100):.2f}%",
"progbar": total_statistic * full_percentbar_width,
"percent": f"{(total_statistic * 100):.2f}%",
},
)
with open(svg_file, "w", encoding="utf-8") as svgfile:
svgfile.write(output_data)
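# Worked example with hypothetical numbers: 400 implemented of 1000 total
# functions and raw_accuracy of 350.0 renders implemented as
# "40.00% (400/1000)", accuracy as "87.50%", and a progress bar at
# 35.00% of full_percentbar_width.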
def print_match_verbose(match, show_both_addrs: bool = False, is_plain: bool = False):
percenttext = percent_string(
match.effective_ratio, match.is_effective_match, is_plain
)
if show_both_addrs:
addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
else:
addrs = hex(match.orig_addr)
if match.is_stub:
print(f"{addrs}: {match.name} is a stub. No diff.")
return
if match.effective_ratio == 1.0:
ok_text = (
"OK!"
if is_plain
else (colorama.Fore.GREEN + "✨ OK! ✨" + colorama.Style.RESET_ALL)
)
if match.ratio == 1.0:
print(f"{addrs}: {match.name} 100% match.\n\n{ok_text}\n\n")
else:
print(
f"{addrs}: {match.name} Effective 100% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n"
)
else:
print_combined_diff(match.udiff, is_plain, show_both_addrs)
print(
f"\n{match.name} is only {percenttext} similar to the original, diff above"
)
def print_match_oneline(match, show_both_addrs: bool = False, is_plain: bool = False):
percenttext = percent_string(
match.effective_ratio, match.is_effective_match, is_plain
)
if show_both_addrs:
addrs = f"0x{match.orig_addr:x} / 0x{match.recomp_addr:x}"
else:
addrs = hex(match.orig_addr)
if match.is_stub:
print(f" {match.name} ({addrs}) is a stub.")
else:
print(f" {match.name} ({addrs}) is {percenttext} similar to the original")
def parse_args() -> argparse.Namespace:
def virtual_address(value) -> int:
"""Helper method for argparse, verbose parameter"""
return int(value, 16)
parser = argparse.ArgumentParser(
allow_abbrev=False,
description="Recompilation Compare: compare an original EXE with a recompiled EXE + PDB.",
)
parser.add_argument(
"original", metavar="original-binary", help="The original binary"
)
parser.add_argument(
"recompiled", metavar="recompiled-binary", help="The recompiled binary"
)
parser.add_argument(
"pdb", metavar="recompiled-pdb", help="The PDB of the recompiled binary"
)
parser.add_argument(
"decomp_dir", metavar="decomp-dir", help="The decompiled source tree"
)
parser.add_argument(
"--total",
"-T",
metavar="<count>",
help="Total number of expected functions (improves total accuracy statistic)",
)
parser.add_argument(
"--verbose",
"-v",
metavar="<offset>",
type=virtual_address,
help="Print assembly diff for specific function (original file's offset)",
)
parser.add_argument(
"--json",
metavar="<file>",
help="Generate JSON file with match summary",
)
parser.add_argument(
"--diff",
metavar="<file>",
help="Diff against summary in JSON file",
)
parser.add_argument(
"--html",
"-H",
metavar="<file>",
help="Generate searchable HTML summary of status and diffs",
)
parser.add_argument(
"--no-color", "-n", action="store_true", help="Do not color the output"
)
parser.add_argument(
"--svg", "-S", metavar="<file>", help="Generate SVG graphic of progress"
)
parser.add_argument("--svg-icon", metavar="icon", help="Icon to use in SVG (PNG)")
parser.add_argument(
"--print-rec-addr",
action="store_true",
help="Print addresses of recompiled functions too",
)
parser.add_argument(
"--silent",
action="store_true",
help="Don't display text summary of matches",
)
parser.set_defaults(loglevel=logging.INFO)
parser.add_argument(
"--debug",
action="store_const",
const=logging.DEBUG,
dest="loglevel",
help="Print script debug information",
)
args = parser.parse_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
if not os.path.isfile(args.pdb):
parser.error(f"Symbols PDB {args.pdb} does not exist")
if not os.path.isdir(args.decomp_dir):
parser.error(f"Source directory {args.decomp_dir} does not exist")
return args
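# A typical invocation under this interface might look like (illustrative
# paths, not a prescribed layout):
#   reccmp.py ORIG.EXE RECOMP.EXE RECOMP.PDB ./src --html report.html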
def main():
args = parse_args()
logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s")
with Bin(args.original, find_str=True) as origfile, Bin(
args.recompiled
) as recompfile:
if args.verbose is not None:
# Mute logger events from compare engine
logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL)
isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
if args.loglevel == logging.DEBUG:
isle_compare.debug = True
print()
### Compare one or none.
if args.verbose is not None:
match = isle_compare.compare_address(args.verbose)
if match is None:
print(f"Failed to find a match at address 0x{args.verbose:x}")
return
print_match_verbose(
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
)
return
### Compare everything.
function_count = 0
total_accuracy = 0
total_effective_accuracy = 0
htmlinsert = []
for match in isle_compare.compare_all():
if not args.silent and args.diff is None:
print_match_oneline(
match, show_both_addrs=args.print_rec_addr, is_plain=args.no_color
)
if match.match_type == SymbolType.FUNCTION and not match.is_stub:
function_count += 1
total_accuracy += match.ratio
total_effective_accuracy += match.effective_ratio
# If html, record the diffs to an HTML file
html_obj = {
"address": f"0x{match.orig_addr:x}",
"recomp": f"0x{match.recomp_addr:x}",
"name": match.name,
"matching": match.effective_ratio,
}
if match.is_effective_match:
html_obj["effective"] = True
if match.udiff is not None:
html_obj["diff"] = match.udiff
if match.is_stub:
html_obj["stub"] = True
htmlinsert.append(html_obj)
# Compare with saved diff report.
if args.diff is not None:
with open(args.diff, "r", encoding="utf-8") as f:
saved_data = json.load(f)
diff_json(
saved_data,
htmlinsert,
args.original,
show_both_addrs=args.print_rec_addr,
is_plain=args.no_color,
)
## Generate files and show summary.
if args.json is not None:
gen_json(args.json, args.original, htmlinsert)
if args.html is not None:
gen_html(args.html, json.dumps(htmlinsert))
implemented_funcs = function_count
if args.total:
function_count = int(args.total)
if function_count > 0:
effective_accuracy = total_effective_accuracy / function_count * 100
actual_accuracy = total_accuracy / function_count * 100
print(
f"\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)"
)
if args.svg is not None:
gen_svg(
args.svg,
os.path.basename(args.original),
args.svg_icon,
implemented_funcs,
function_count,
total_effective_accuracy,
)
if __name__ == "__main__":
raise SystemExit(main())

View file

@ -1,365 +0,0 @@
<!DOCTYPE html>
<html>
<head>
<title>Decompilation Status</title>
<style>
body {
background: #202020;
color: #f0f0f0;
font-family: sans-serif;
}
h1 {
text-align: center;
}
.main {
width: 800px;
max-width: 100%;
margin: auto;
}
#search {
width: 100%;
box-sizing: border-box;
background: #303030;
color: #f0f0f0;
border: 1px #f0f0f0 solid;
padding: 0.5em;
border-radius: 0.5em;
}
#search::placeholder {
color: #b0b0b0;
}
#listing {
width: 100%;
border-collapse: collapse;
font-family: monospace;
}
func-row:hover {
background: #404040 !important;
}
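/* The "odd of :not([hidden])" form stripes only the rows that remain visible,
   so the zebra pattern survives when the search filters hide rows. */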
func-row:nth-child(odd of :not([hidden])), #listing > thead th {
background: #282828;
}
func-row:nth-child(even of :not([hidden])) {
background: #383838;
}
table#listing {
border: 1px #f0f0f0 solid;
}
#listing > thead th {
padding: 0.5em;
user-select: none;
width: 10%;
text-align: left;
}
#listing:not([show-recomp]) > thead th[data-col="recomp"] {
display: none;
}
#listing > thead th > div {
display: flex;
column-gap: 0.5em;
}
#listing > thead th > div > span {
cursor: pointer;
}
#listing > thead th > div > span:hover {
text-decoration: underline;
text-decoration-style: dotted;
}
#listing > thead th:last-child > div {
justify-content: right;
}
#listing > thead th[data-col="name"] {
width: 60%;
}
.diffneg {
color: #FF8080;
}
.diffpos {
color: #80FF80;
}
.diffslug {
color: #8080FF;
}
.identical {
font-style: italic;
text-align: center;
}
sort-indicator {
user-select: none;
}
.filters {
align-items: flex-start;
display: flex;
font-size: 10pt;
justify-content: space-between;
margin: 0.5em 0 1em 0;
}
.filters > fieldset {
/* checkbox and radio buttons v-aligned with text */
align-items: center;
display: flex;
}
.filters > fieldset > input, .filters > fieldset > label {
cursor: pointer;
}
.filters > fieldset > label {
margin-right: 10px;
}
table.diffTable {
border-collapse: collapse;
}
table.diffTable:not(:last-child) {
/* visual gap *between* diff context groups */
margin-bottom: 40px;
}
table.diffTable td, table.diffTable th {
border: 0 none;
padding: 0 10px 0 0;
}
table.diffTable th {
/* don't break address if asm line is long */
word-break: keep-all;
}
diff-display[data-option="0"] th:nth-child(1) {
display: none;
}
diff-display[data-option="0"] th:nth-child(2),
diff-display[data-option="1"] th:nth-child(2) {
display: none;
}
label {
user-select: none;
}
#pageDisplay > button {
cursor: pointer;
padding: 0.25em 0.5em;
}
#pageDisplay select {
cursor: pointer;
padding: 0.25em;
margin: 0 0.5em;
}
p.rowcount {
align-self: flex-end;
font-size: 1.2em;
margin-bottom: 0;
}
</style>
<script>var data = {{{data}}};</script>
<script>{{{reccmp_js}}}</script>
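<!-- Triple braces are unescaped Mustache interpolation (presumably rendered
     with pystache, which appears in tools/requirements.txt): the report JSON
     and the reccmp.js source are injected verbatim when the report is built. -->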
</head>
<body>
<div class="main">
<h1>Decompilation Status</h1>
<listing-options>
<input id="search" type="search" placeholder="Search for offset or function name...">
<div class="filters">
<fieldset>
<legend>Options:</legend>
<input type="checkbox" id="cbHidePerfect" />
<label for="cbHidePerfect">Hide 100% match</label>
<input type="checkbox" id="cbHideStub" />
<label for="cbHideStub">Hide stubs</label>
<input type="checkbox" id="cbShowRecomp" />
<label for="cbShowRecomp">Show recomp address</label>
</fieldset>
<fieldset>
<legend>Search filters on:</legend>
<input type="radio" name="filterType" id="filterName" value=1 checked />
<label for="filterName">Name/address</label>
<input type="radio" name="filterType" id="filterAsm" value=2 />
<label for="filterAsm">Asm output</label>
<input type="radio" name="filterType" id="filterDiff" value=3 />
<label for="filterDiff">Asm diffs only</label>
</fieldset>
</div>
<div class="filters">
<p class="rowcount">Results: <span id="rowcount"></span></p>
<fieldset id="pageDisplay">
<legend>Page</legend>
<button id="pagePrev">prev</button>
<select id="pageSelect">
</select>
<button id="pageNext">next</button>
</fieldset>
</div>
</listing-options>
<listing-table>
<table id="listing">
<thead>
<tr>
<th data-col="address">
<div>
<span>Address</span>
<sort-indicator></sort-indicator>
</div>
</th>
<th data-col="recomp">
<div>
<span>Recomp</span>
<sort-indicator></sort-indicator>
</div>
</th>
<th data-col="name">
<div>
<span>Name</span>
<sort-indicator></sort-indicator>
</div>
</th>
<th data-col="diffs" data-no-sort></th>
<th data-col="matching">
<div>
<sort-indicator></sort-indicator>
<span>Matching</span>
</div>
</th>
</tr>
</thead>
<tbody>
</tbody>
</table>
</listing-table>
</div>
<template id="funcrow-template">
<style>
:host(:not([hidden])) {
display: table-row;
}
:host(:not([show-recomp])) > div[data-col="recomp"] {
display: none;
}
div[data-col="name"]:hover {
cursor: pointer;
}
div[data-col="name"]:hover > ::slotted(*) {
text-decoration: underline;
text-decoration-style: dotted;
}
::slotted(*:not([slot="name"])) {
white-space: nowrap;
}
:host > div {
border-top: 1px #f0f0f0 solid;
display: table-cell;
padding: 0.5em;
word-break: break-all !important;
}
:host > div:last-child {
text-align: right;
}
</style>
<div data-col="address"><can-copy><slot name="address"></slot></can-copy></div>
<div data-col="recomp"><can-copy><slot name="recomp"></slot></can-copy></div>
<div data-col="name"><slot name="name"></slot></div>
<div data-col="diffs"><slot name="diffs"></slot></div>
<div data-col="matching"><slot name="matching"></slot></div>
</template>
<template id="diffrow-template">
<style>
:host(:not([hidden])) {
display: table-row;
contain: paint;
}
td.singleCell {
border: 1px #f0f0f0 solid;
border-bottom: 0px none;
display: table-cell;
padding: 0.5em;
word-break: break-all !important;
}
</style>
<td class="singleCell" colspan="5">
<slot></slot>
</td>
</template>
<template id="nodiff-template">
<style>
::slotted(*) {
font-style: italic;
text-align: center;
}
</style>
<slot></slot>
</template>
<template id="can-copy-template">
<style>
:host {
position: relative;
}
::slotted(*) {
cursor: pointer;
}
slot::after {
background-color: #fff;
color: #222;
display: none;
font-size: 12px;
padding: 1px 2px;
width: fit-content;
border-radius: 1px;
text-align: center;
bottom: 120%;
box-shadow: 0 4px 14px 0 rgba(0,0,0,.2), 0 0 0 1px rgba(0,0,0,.05);
position: absolute;
white-space: nowrap;
transition: .1s;
content: 'Copy to clipboard';
}
::slotted(*:hover) {
text-decoration: underline;
text-decoration-style: dotted;
}
slot:hover::after {
display: block;
}
:host([copied]) > slot:hover::after {
content: 'Copied!';
}
</style>
<slot></slot>
</template>
</body>
</html>


@@ -1,119 +0,0 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<!-- Created with Inkscape (http://www.inkscape.org/) -->
<svg
width="640"
height="480"
viewBox="0 0 169.33333 127"
version="1.1"
id="svg5"
xml:space="preserve"
sodipodi:docname="template.svg"
inkscape:version="1.2.2 (b0a8486541, 2022-12-01)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns:xlink="http://www.w3.org/1999/xlink"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg"><sodipodi:namedview
id="namedview26"
pagecolor="#505050"
bordercolor="#eeeeee"
borderopacity="1"
inkscape:showpageshadow="0"
inkscape:pageopacity="0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#505050"
showgrid="false"
inkscape:zoom="1.6046875"
inkscape:cx="158.90944"
inkscape:cy="220.6037"
inkscape:window-width="2560"
inkscape:window-height="1379"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
inkscape:current-layer="g1273" /><defs
id="defs5">
<clipPath
id="progBarCutoff">
<rect
width="{{progbar}}"
height="8.6508904"
x="21.118132"
y="134.05507"
id="rect2" />
</clipPath>
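<!-- Assumed intent: the {{progbar}} placeholder sets the clip width, so the
     white bar and the dark percent label below (both clipped to this rect)
     render the filled portion of the progress bar. -->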
</defs><g
id="g1273"
transform="matrix(1.2683581,0,0,1.2683581,-22.720969,-65.913871)"><image
width="53.066437"
height="53.066437"
preserveAspectRatio="none"
style="image-rendering:optimizeSpeed"
xlink:href="data:image/png;base64,{{icon}}"
id="image1060"
x="58.13345"
y="51.967873" /><text
xml:space="preserve"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:12.7px;font-family:monospace;-inkscape-font-specification:mono;text-align:center;text-anchor:middle;fill:#ffffff;stroke:#000000;stroke-width:1.25161812;stroke-opacity:1;stroke-dasharray:none;paint-order:stroke fill markers"
x="84.666656"
y="118.35877"
id="text740"><tspan
id="tspan738"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-family:monospace;-inkscape-font-specification:mono;text-align:center;text-anchor:middle;stroke:#000000;stroke-width:1.25161812;stroke-opacity:1;stroke-dasharray:none;paint-order:stroke fill markers"
x="84.666656"
y="118.35877">{{name}}</tspan></text><g
id="g1250"
transform="translate(-0.04358834,8.1397473)"><rect
style="display:inline;fill:none;fill-opacity:1;stroke:#000000;stroke-width:2.50324;stroke-dasharray:none;stroke-opacity:1"
id="rect1619"
width="127.18422"
height="8.6508904"
x="21.118132"
y="134.05507" /><rect
style="display:inline;fill:#000000;fill-opacity:1;stroke:#ffffff;stroke-width:0.87411;stroke-dasharray:none;stroke-opacity:1"
id="rect1167"
width="127.18422"
height="8.6508904"
x="21.118132"
y="134.05507" /><text
xml:space="preserve"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:4.23333px;font-family:monospace;-inkscape-font-specification:mono;text-align:start;text-anchor:start;fill:#ffffff;fill-opacity:1;stroke:none;stroke-width:1.05833;stroke-dasharray:none;stroke-opacity:1"
x="76.884926"
y="139.89182"
id="text2152"><tspan
style="font-size:4.23333px;fill:#ffffff;fill-opacity:1;stroke-width:1.05833"
x="76.884926"
y="139.89182"
id="tspan2150">{{percent}}</tspan></text><rect
style="display:inline;fill:#ffffff;stroke:none;stroke-width:2.6764"
id="rect1169"
width="127.18422"
height="8.6508904"
x="21.118132"
y="134.05507"
clip-path="url(#progBarCutoff)" /><text
xml:space="preserve"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:4.23333px;font-family:monospace;-inkscape-font-specification:mono;text-align:start;text-anchor:start;fill:#000000;fill-opacity:1;stroke:none;stroke-width:1.05833;stroke-dasharray:none;stroke-opacity:1"
x="76.884926"
y="139.89182"
id="text18"
clip-path="url(#progBarCutoff)"
inkscape:label="text18"><tspan
style="font-size:4.23333px;fill:#000000;fill-opacity:1;stroke-width:1.05833"
x="76.884926"
y="139.89182"
id="tspan16">{{percent}}</tspan></text></g><text
xml:space="preserve"
style="font-style:normal;font-variant:normal;font-weight:bold;font-stretch:normal;font-size:4.23333px;font-family:monospace;-inkscape-font-specification:mono;text-align:start;text-anchor:start;fill:#ffffff;fill-opacity:1;stroke:#000000;stroke-width:0.83441208;stroke-dasharray:none;stroke-opacity:1;opacity:1;stroke-linejoin:miter;stroke-linecap:butt;paint-order:stroke fill markers"
x="46.947659"
y="129.67447"
id="text1260"><tspan
id="tspan1258"
style="font-size:4.23333px;stroke-width:0.83441208;stroke:#000000;stroke-opacity:1;stroke-dasharray:none;stroke-linejoin:miter;stroke-linecap:butt;paint-order:stroke fill markers"
x="46.947659"
y="129.67447">Implemented: {{implemented}}</tspan><tspan
style="font-size:4.23333px;stroke-width:0.83441208;stroke:#000000;stroke-opacity:1;stroke-dasharray:none;stroke-linejoin:miter;stroke-linecap:butt;paint-order:stroke fill markers"
x="46.947659"
y="134.96613"
id="tspan1262">Accuracy: {{accuracy}}</tspan></text></g></svg>



@@ -1,11 +1,3 @@
tools/isledecomp
capstone
reccmp @ git+https://github.com/isledecomp/reccmp
clang==16.*
colorama>=0.4.6
isledecomp
pystache
pyyaml
git+https://github.com/wbenny/pydemangler.git
# requirement of capstone due to python dropping distutils.
# see: https://github.com/capstone-engine/capstone/issues/2223
setuptools ; python_version >= "3.12"


@@ -1,494 +0,0 @@
"""For all addresses matched by code annotations or recomp pdb,
report how "far off" the recomp symbol is from its proper place
in the original binary."""
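# Illustrative example: a symbol at 0001:00001000 in the original that lands
# at 0001:00001040 in the recomp has a displacement of +0x40 (recomp offset
# minus original offset).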
import os
import argparse
import logging
import statistics
import bisect
from typing import Iterator, List, Optional, Tuple
from collections import namedtuple
from isledecomp import Bin as IsleBin
from isledecomp.bin import InvalidVirtualAddressError
from isledecomp.cvdump import Cvdump
from isledecomp.compare import Compare as IsleCompare
from isledecomp.types import SymbolType
# Ignore all compare-db messages.
logging.getLogger("isledecomp.compare").addHandler(logging.NullHandler())
def or_blank(value) -> str:
"""Helper for dealing with potential None values in text output."""
return "" if value is None else str(value)
class ModuleMap:
"""Load a subset of sections from the pdb to allow you to look up the
module number based on the recomp address."""
def __init__(self, pdb, binfile) -> None:
cvdump = Cvdump(pdb).section_contributions().modules().run()
self.module_lookup = {m.id: (m.lib, m.obj) for m in cvdump.modules}
self.library_lookup = {m.obj: m.lib for m in cvdump.modules}
self.section_contrib = [
(
binfile.get_abs_addr(sizeref.section, sizeref.offset),
sizeref.size,
sizeref.module,
)
for sizeref in cvdump.sizerefs
if binfile.is_valid_section(sizeref.section)
]
# For bisect performance enhancement
self.contrib_starts = [start for (start, _, __) in self.section_contrib]
def get_lib_for_module(self, module: str) -> Optional[str]:
return self.library_lookup.get(module)
def get_all_cmake_modules(self) -> List[str]:
return [
obj
for (_, (__, obj)) in self.module_lookup.items()
if obj.startswith("CMakeFiles")
]
def get_module(self, addr: int) -> Optional[str]:
i = bisect.bisect_left(self.contrib_starts, addr)
# If the addr matches the section contribution start, we are in the
# right spot. Otherwise, we need to subtract one here.
# We don't want the insertion point given by bisect, but the
# section contribution that contains the address.
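# Worked example with made-up values: contrib_starts = [0x1000, 0x2000] and
# addr = 0x1800 give i = 1 from bisect_left; since 0x2000 != 0x1800 we step
# back to i = 0 and then test 0x1000 <= addr < 0x1000 + size below.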
(potential_start, _, __) = self.section_contrib[i]
if potential_start != addr:
i -= 1
# Safety catch: clamp to range of indices from section_contrib.
i = max(0, min(i, len(self.section_contrib) - 1))
(start, size, module_id) = self.section_contrib[i]
if start <= addr < start + size:
if (module := self.module_lookup.get(module_id)) is not None:
return module
return None
def print_sections(sections):
print(" name | start | v.size | raw size")
print("---------|----------|----------|----------")
for sect in sections:
name = sect.name
print(
f"{name:>8} | {sect.virtual_address:8x} | {sect.virtual_size:8x} | {sect.size_of_raw_data:8x}"
)
print()
ALLOWED_TYPE_ABBREVIATIONS = ["fun", "dat", "poi", "str", "vta", "flo"]
def match_type_abbreviation(mtype: Optional[SymbolType]) -> str:
"""Return abbreviation of the given SymbolType name"""
if mtype is None:
return ""
return mtype.name.lower()[:3]
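# e.g. match_type_abbreviation(SymbolType.FUNCTION) -> "fun",
# match_type_abbreviation(SymbolType.STRING) -> "str"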
def get_cmakefiles_prefix(module: str) -> str:
"""For the given .obj, get the "CMakeFiles/something.dir/" prefix.
For lack of a better option, this is the library for this module."""
if module.startswith("CMakeFiles"):
return "/".join(module.split("/", 2)[:2]) + "/"
return module
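# e.g. get_cmakefiles_prefix("CMakeFiles/lego1.dir/LEGO1/define.cpp.obj")
# -> "CMakeFiles/lego1.dir/"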
def truncate_module_name(prefix: str, module: str) -> str:
"""Remove the CMakeFiles prefix and the .obj suffix for the given module.
Input: CMakeFiles/lego1.dir/, CMakeFiles/lego1.dir/LEGO1/define.cpp.obj
Output: LEGO1/define.cpp"""
if module.startswith(prefix):
module = module[len(prefix) :]
if module.endswith(".obj"):
module = module[:-4]
return module
def avg_remove_outliers(entries: List[int]) -> int:
"""Compute the average from this list of entries (addresses)
after removing outlier values."""
if len(entries) == 1:
return entries[0]
avg = statistics.mean(entries)
sd = statistics.pstdev(entries)
return int(statistics.mean([e for e in entries if abs(e - avg) <= 2 * sd]))
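# Illustrative example: for nine entries of 100 plus one outlier of 100000,
# the mean is 10090 and the pstdev is 29970; the outlier exceeds two standard
# deviations from the mean, so it is dropped and the function returns 100.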
RoadmapRow = namedtuple(
"RoadmapRow",
[
"orig_sect_ofs",
"recomp_sect_ofs",
"orig_addr",
"recomp_addr",
"displacement",
"sym_type",
"size",
"name",
"module",
],
)
class DeltaCollector:
"""Reads each row of the results and aggregates information about the
placement of each module."""
def __init__(self, match_type: str = "fun") -> None:
# The displacement for each symbol from each module
self.disp_map = {}
# Each address for each module
self.addresses = {}
# The earliest address for each module
self.earliest = {}
# String abbreviation for which symbol type we are checking
self.match_type = "fun"
match_type = str(match_type).strip().lower()[:3]
if match_type in ALLOWED_TYPE_ABBREVIATIONS:
self.match_type = match_type
def read_row(self, row: RoadmapRow):
if row.module is None:
return
if row.sym_type != self.match_type:
return
if row.orig_addr is not None:
if row.module not in self.addresses:
self.addresses[row.module] = []
self.addresses[row.module].append(row.orig_addr)
if row.orig_addr < self.earliest.get(row.module, 0xFFFFFFFFF):
self.earliest[row.module] = row.orig_addr
if row.displacement is not None:
if row.module not in self.disp_map:
self.disp_map[row.module] = []
self.disp_map[row.module].append(row.displacement)
def iter_sorted(self) -> Iterator[Tuple[int, int]]:
"""Compute the average address for each module, then generate them
in ascending order."""
avg_address = {
mod: avg_remove_outliers(values) for mod, values in self.addresses.items()
}
for mod, avg in sorted(avg_address.items(), key=lambda x: x[1]):
yield (avg, mod)
def suggest_order(results: List[RoadmapRow], module_map: ModuleMap, match_type: str):
"""Suggest the order of modules for CMakeLists.txt"""
dc = DeltaCollector(match_type)
for row in results:
dc.read_row(row)
# First, show the order of .obj files for the "CMake Modules"
# Meaning: the modules where the .obj file begins with "CMakeFiles".
# These are the libraries where we directly control the order.
# The library name (from cvdump) doesn't make it obvious that these are
# our libraries so we derive the name based on the CMakeFiles prefix.
leftover_modules = set(module_map.get_all_cmake_modules())
# A little convoluted, but we want to take the first two tokens
# of the string with '/' as the delimiter.
# i.e. CMakeFiles/isle.dir/
# The idea is to print exactly what appears in CMakeLists.txt.
cmake_prefixes = sorted(set(get_cmakefiles_prefix(mod) for mod in leftover_modules))
# Save this off because we'll use it again later.
computed_order = list(dc.iter_sorted())
for prefix in cmake_prefixes:
print(prefix)
last_earliest = 0
# Show modules ordered by the computed average of addresses
for _, module in computed_order:
if not module.startswith(prefix):
continue
leftover_modules.remove(module)
avg_displacement = None
displacements = dc.disp_map.get(module)
if displacements is not None and len(displacements) > 0:
avg_displacement = int(statistics.mean(displacements))
# Call attention to any modules where ordering by earliest
# address is different from the computed order we display.
earliest = dc.earliest.get(module)
ooo_mark = "*" if earliest < last_earliest else " "
last_earliest = earliest
code_file = truncate_module_name(prefix, module)
print(f"0x{earliest:08x}{ooo_mark} {avg_displacement:10} {code_file}")
# These modules are included in the final binary (in some form) but
# don't contribute any symbols of the type we are checking.
# n.b. There could still be other modules that are part of
# CMakeLists.txt but are not included in the pdb for whatever reason.
# In other words: don't take the list we provide as the final word on
# what should or should not be included.
# This is merely a suggestion of the order.
for module in leftover_modules:
if not module.startswith(prefix):
continue
# aligned with previous print
code_file = truncate_module_name(prefix, module)
print(f" no suggestion {code_file}")
print()
# Now display the order of all libraries in the final file.
library_order = {}
for start, module in computed_order:
lib = module_map.get_lib_for_module(module)
if lib is None:
lib = get_cmakefiles_prefix(module)
if start < library_order.get(lib, 0xFFFFFFFFF):
library_order[lib] = start
print("Library order (average address shown):")
for lib, start in sorted(library_order.items(), key=lambda x: x[1]):
# Strip off any OS path for brevity
if not lib.startswith("CMakeFiles"):
lib = os.path.basename(lib)
print(f"{lib:40} {start:08x}")
def print_text_report(results: List[RoadmapRow]):
"""Print the result with original and recomp addresses."""
for row in results:
print(
" ".join(
[
f"{or_blank(row.orig_sect_ofs):14}",
f"{or_blank(row.recomp_sect_ofs):14}",
f"{or_blank(row.displacement):>8}",
f"{row.sym_type:3}",
f"{or_blank(row.size):6}",
or_blank(row.name),
]
)
)
def print_diff_report(results: List[RoadmapRow]):
"""Print only entries where we have the recomp address.
This is intended for generating a file to diff against.
The recomp addresses are always changing so we hide those."""
for row in results:
if row.orig_addr is None or row.recomp_addr is None:
continue
print(
" ".join(
[
f"{or_blank(row.orig_sect_ofs):14}",
f"{or_blank(row.displacement):>8}",
f"{row.sym_type:3}",
f"{or_blank(row.size):6}",
or_blank(row.name),
]
)
)
def export_to_csv(csv_file: str, results: List[RoadmapRow]):
with open(csv_file, "w+", encoding="utf-8") as f:
f.write(
"orig_sect_ofs,recomp_sect_ofs,orig_addr,recomp_addr,displacement,row_type,size,name,module\n"
)
for row in results:
f.write(",".join(map(or_blank, row)))
f.write("\n")
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Show all addresses from original and recomp."
)
parser.add_argument(
"original", metavar="original-binary", help="The original binary"
)
parser.add_argument(
"recompiled", metavar="recompiled-binary", help="The recompiled binary"
)
parser.add_argument(
"pdb", metavar="recompiled-pdb", help="The PDB of the recompiled binary"
)
parser.add_argument(
"decomp_dir", metavar="decomp-dir", help="The decompiled source tree"
)
parser.add_argument("--csv", metavar="<file>", help="If set, export to CSV")
parser.add_argument(
"--verbose", "-v", action="store_true", help="Show recomp addresses in output"
)
parser.add_argument(
"--order",
const="fun",
nargs="?",
type=str,
help="Show suggested order of modules (using the specified symbol type)",
)
(args, _) = parser.parse_known_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
if not os.path.isfile(args.pdb):
parser.error(f"Symbols PDB {args.pdb} does not exist")
if not os.path.isdir(args.decomp_dir):
parser.error(f"Source directory {args.decomp_dir} does not exist")
return args
def main():
args = parse_args()
with IsleBin(args.original, find_str=True) as orig_bin, IsleBin(
args.recompiled
) as recomp_bin:
engine = IsleCompare(orig_bin, recomp_bin, args.pdb, args.decomp_dir)
module_map = ModuleMap(args.pdb, recomp_bin)
def is_same_section(orig: int, recomp: int) -> bool:
"""Compare the section name instead of the index.
LEGO1.dll adds extra sections for some reason. (Smacker library?)"""
try:
orig_name = orig_bin.sections[orig - 1].name
recomp_name = recomp_bin.sections[recomp - 1].name
return orig_name == recomp_name
except IndexError:
return False
def to_roadmap_row(match):
orig_sect = None
orig_ofs = None
orig_sect_ofs = None
recomp_sect = None
recomp_ofs = None
recomp_sect_ofs = None
orig_addr = None
recomp_addr = None
displacement = None
module_name = None
if match.recomp_addr is not None and recomp_bin.is_valid_vaddr(
match.recomp_addr
):
if (module_ref := module_map.get_module(match.recomp_addr)) is not None:
(_, module_name) = module_ref
row_type = match_type_abbreviation(match.compare_type)
name = (
repr(match.name)
if match.compare_type == SymbolType.STRING
else match.name
)
if match.orig_addr is not None:
orig_addr = match.orig_addr
(orig_sect, orig_ofs) = orig_bin.get_relative_addr(match.orig_addr)
orig_sect_ofs = f"{orig_sect:04}:{orig_ofs:08x}"
if match.recomp_addr is not None:
recomp_addr = match.recomp_addr
(recomp_sect, recomp_ofs) = recomp_bin.get_relative_addr(
match.recomp_addr
)
recomp_sect_ofs = f"{recomp_sect:04}:{recomp_ofs:08x}"
if (
orig_sect is not None
and recomp_sect is not None
and is_same_section(orig_sect, recomp_sect)
):
displacement = recomp_ofs - orig_ofs
return RoadmapRow(
orig_sect_ofs,
recomp_sect_ofs,
orig_addr,
recomp_addr,
displacement,
row_type,
match.size,
name,
module_name,
)
def roadmap_row_generator(matches):
for match in matches:
try:
yield to_roadmap_row(match)
except InvalidVirtualAddressError:
# This is here to work around the fact that we have RVA
# values (i.e. not real virtual addrs) in our compare db.
pass
results = list(roadmap_row_generator(engine.get_all()))
if args.order is not None:
suggest_order(results, module_map, args.order)
return
if args.csv is None:
if args.verbose:
print("ORIG sections:")
print_sections(orig_bin.sections)
print("RECOMP sections:")
print_sections(recomp_bin.sections)
print_text_report(results)
else:
print_diff_report(results)
if args.csv is not None:
export_to_csv(args.csv, results)
if __name__ == "__main__":
main()


@@ -1,364 +0,0 @@
from dataclasses import dataclass
import re
import logging
import os
import argparse
import struct
from typing import Dict, List, NamedTuple, Optional, Set, Tuple
from isledecomp import Bin
from isledecomp.compare import Compare as IsleCompare
from isledecomp.compare.diff import CombinedDiffOutput
from isledecomp.cvdump.symbols import SymbolsEntry
import colorama
# pylint: disable=duplicate-code # misdetects a code duplication with reccmp
colorama.just_fix_windows_console()
CHECK_ICON = f"{colorama.Fore.GREEN}✓{colorama.Style.RESET_ALL}"
SWAP_ICON = f"{colorama.Fore.YELLOW}⇄{colorama.Style.RESET_ALL}"
ERROR_ICON = f"{colorama.Fore.RED}✗{colorama.Style.RESET_ALL}"
UNCLEAR_ICON = f"{colorama.Fore.BLUE}?{colorama.Style.RESET_ALL}"
STACK_ENTRY_REGEX = re.compile(
r"(?P<register>e[sb]p)\s(?P<sign>[+-])\s(?P<offset>(0x)?[0-9a-f]+)(?![0-9a-f])"
)
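# Illustrative matches (assumed disassembly shapes, not real diff output):
#   "mov eax, dword ptr [ebp - 0x10]" -> register "ebp", offset -0x10
#   "lea ecx, [esp + 8]"              -> register "esp", offset 8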
@dataclass
class StackSymbol:
name: str
data_type: str
@dataclass
class StackRegisterOffset:
register: str
offset: int
symbol: Optional[StackSymbol] = None
def __str__(self) -> str:
first_part = (
f"{self.register} + {self.offset:#04x}"
if self.offset > 0
else f"{self.register} - {-self.offset:#04x}"
)
second_part = f" {self.symbol.name}" if self.symbol else ""
return first_part + second_part
def __hash__(self) -> int:
return hash(self.register) + self.offset
def copy(self) -> "StackRegisterOffset":
return StackRegisterOffset(self.register, self.offset, self.symbol)
def __eq__(self, other: "StackRegisterOffset"):
return self.register == other.register and self.offset == other.offset
class StackPair(NamedTuple):
orig: StackRegisterOffset
recomp: StackRegisterOffset
StackPairs = Set[StackPair]
@dataclass
class Warnings:
structural_mismatches_present: bool = False
error_map_not_bijective: bool = False
def extract_stack_offset_from_instruction(
instruction: str,
) -> StackRegisterOffset | None:
match = STACK_ENTRY_REGEX.search(instruction)
if not match:
return None
offset = int(match.group("sign") + match.group("offset"), 16)
return StackRegisterOffset(match.group("register"), offset)
def analyze_diff(
diff: Dict[str, List[Tuple[str, ...]]], warnings: Warnings
) -> StackPairs:
stack_pairs: StackPairs = set()
if "both" in diff:
# get the matching stack entries
for line in diff["both"]:
# 0 = orig addr, 1 = instruction, 2 = reccmp addr
instruction = line[1]
if match := extract_stack_offset_from_instruction(instruction):
logging.debug("stack match: %s", match)
# need a copy for recomp because we might add a debug symbol to it
stack_pairs.add(StackPair(match, match.copy()))
elif any(x in instruction for x in ["ebp", "esp"]):
logging.debug("not a stack offset: %s", instruction)
else:
orig = diff["orig"]
recomp = diff["recomp"]
if len(orig) != len(recomp):
if orig:
mismatch_location = f"orig={orig[0][0]}"
else:
mismatch_location = f"recomp={recomp[0][0]}"
logging.error(
"Structural mismatch at %s:\n%s",
mismatch_location,
print_structural_mismatch(orig, recomp),
)
warnings.structural_mismatches_present = True
return set()
for orig_line, recomp_line in zip(orig, recomp):
if orig_match := extract_stack_offset_from_instruction(orig_line[1]):
recomp_match = extract_stack_offset_from_instruction(recomp_line[1])
if not recomp_match:
logging.error(
"Mismatching line structure at orig=%s:\n%s",
orig_line[0],
print_structural_mismatch(orig, recomp),
)
# not recoverable, whole block has a structural mismatch
warnings.structural_mismatches_present = True
return set()
stack_pair = StackPair(orig_match, recomp_match)
logging.debug(
"stack match, wrong order: %s vs %s", stack_pair[0], stack_pair[1]
)
stack_pairs.add(stack_pair)
elif any(x in orig_line[1] for x in ["ebp", "esp"]):
logging.debug("not a stack offset: %s", orig_line[1])
return stack_pairs
def print_bijective_match(left: str, right: str, exact: bool):
icon = CHECK_ICON if exact else SWAP_ICON
print(f"{icon}{colorama.Style.RESET_ALL} {left}: {right}")
def print_non_bijective_match(left: str, right: str):
print(f"{ERROR_ICON} {left}: {right}")
def print_structural_mismatch(
orig: List[Tuple[str, ...]], recomp: List[Tuple[str, ...]]
) -> str:
orig_str = "\n".join(f"-{x[1]}" for x in orig) if orig else "-"
recomp_str = "\n".join(f"+{x[1]}" for x in recomp) if recomp else "+"
return f"{colorama.Fore.RED}{orig_str}\n{colorama.Fore.GREEN}{recomp_str}\n{colorama.Style.RESET_ALL}"
def format_list_of_offsets(offsets: List[StackRegisterOffset]) -> str:
return str([str(x) for x in offsets])
def compare_function_stacks(udiff: CombinedDiffOutput, fn_symbol: SymbolsEntry):
warnings = Warnings()
# consists of pairs (orig, recomp)
# don't use a dict because we can have m:n relations
stack_pairs: StackPairs = set()
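# e.g. orig "ebp - 0x10" may pair with both "ebp - 0x10" and "ebp - 0x14" in
# the recomp when variables were reordered; a set of pairs keeps every edge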
for block in udiff:
# block[0] is e.g. "@@ -0x10071662,60 +0x10031368,60 @@"
for diff in block[1]:
stack_pairs = stack_pairs.union(analyze_diff(diff, warnings))
# Note that the 'Frame Ptr Present' property is not relevant to the stack below `ebp`,
# but only to entries above (i.e. the function arguments on the stack).
# See also pdb_extraction.py.
stack_symbols: Dict[int, StackSymbol] = {}
for symbol in fn_symbol.stack_symbols:
if symbol.symbol_type == "S_BPREL32":
# convert hex to signed 32 bit integer
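# e.g. an assumed location string of "[fffffff8]" unpacks to offset -8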
hex_bytes = bytes.fromhex(symbol.location[1:-1])
stack_offset = struct.unpack(">l", hex_bytes)[0]
stack_symbols[stack_offset] = StackSymbol(
symbol.name,
symbol.data_type,
)
for _, recomp in stack_pairs:
if recomp.register == "ebp":
recomp.symbol = stack_symbols.get(recomp.offset)
elif recomp.register == "esp":
logging.debug(
"Matching esp offsets to debug symbols is not implemented right now"
)
print("\nOrdered by original stack (left=orig, right=recomp):")
all_orig_offsets = set(x.orig.offset for x in stack_pairs)
for orig_offset in sorted(all_orig_offsets):
orig = next(x.orig for x in stack_pairs if x.orig.offset == orig_offset)
recomps = [x.recomp for x in stack_pairs if x.orig == orig]
if len(recomps) == 1:
recomp = recomps[0]
print_bijective_match(str(orig), str(recomp), exact=orig == recomp)
else:
print_non_bijective_match(str(orig), format_list_of_offsets(recomps))
warnings.error_map_not_bijective = True
# Show offsets from the debug symbols that we have not encountered in the diff
all_recomp_offsets = set(x.recomp.offset for x in stack_pairs).union(
stack_symbols.keys()
)
print("\nOrdered by recomp stack (left=orig, right=recomp):")
for recomp_offset in sorted(all_recomp_offsets):
recomp = next(
(x.recomp for x in stack_pairs if x.recomp.offset == recomp_offset), None
)
if recomp is None:
# The offset only appears in the debug symbols.
# The legend below explains why this can happen.
stack_offset = StackRegisterOffset(
"ebp", recomp_offset, stack_symbols[recomp_offset]
)
print(f"{UNCLEAR_ICON} not seen: {stack_offset}")
continue
origs = [x.orig for x in stack_pairs if x.recomp == recomp]
if len(origs) == 1:
# 1:1 clean match
print_bijective_match(str(origs[0]), str(recomp), origs[0] == recomp)
else:
print_non_bijective_match(format_list_of_offsets(origs), str(recomp))
warnings.error_map_not_bijective = True
print(
"\nLegend:\n"
+ f"{SWAP_ICON} : This stack variable matches 1:1, but the order of variables is not correct.\n"
+ f"{ERROR_ICON} : This stack variable matches multiple variables in the other binary.\n"
+ f"{UNCLEAR_ICON} : This stack variable did not appear in the diff. It either matches or only appears in structural mismatches.\n"
)
if warnings.error_map_not_bijective:
print(
"ERROR: The stack variables of original and recomp are not in a 1:1 correspondence, "
+ "suggesting that the logic in the recomp is incorrect."
)
elif warnings.structural_mismatches_present:
print(
"WARNING: Original and recomp have at least one structural discrepancy, "
+ "so the comparison of stack variables might be incomplete. "
+ "The structural mismatches above need to be checked manually."
)
def parse_args() -> argparse.Namespace:
def virtual_address(value) -> int:
"""Helper method for argparse, verbose parameter"""
return int(value, 16)
parser = argparse.ArgumentParser(
allow_abbrev=False,
description="Recompilation Compare: compare an original EXE with a recompiled EXE + PDB.",
)
parser.add_argument(
"original", metavar="original-binary", help="The original binary"
)
parser.add_argument(
"recompiled", metavar="recompiled-binary", help="The recompiled binary"
)
parser.add_argument(
"pdb", metavar="recompiled-pdb", help="The PDB of the recompiled binary"
)
parser.add_argument(
"decomp_dir", metavar="decomp-dir", help="The decompiled source tree"
)
parser.add_argument(
"address",
metavar="<offset>",
type=virtual_address,
help="The original file's offset of the function to be analyzed",
)
parser.set_defaults(loglevel=logging.INFO)
parser.add_argument(
"--debug",
action="store_const",
const=logging.DEBUG,
dest="loglevel",
help="Print script debug information",
)
args = parser.parse_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
if not os.path.isfile(args.pdb):
parser.error(f"Symbols PDB {args.pdb} does not exist")
if not os.path.isdir(args.decomp_dir):
parser.error(f"Source directory {args.decomp_dir} does not exist")
return args
def main():
args = parse_args()
logging.basicConfig(level=args.loglevel, format="[%(levelname)s] %(message)s")
with Bin(args.original, find_str=True) as origfile, Bin(
args.recompiled
) as recompfile:
if args.loglevel != logging.DEBUG:
# Mute logger events from compare engine
logging.getLogger("isledecomp.compare.core").setLevel(logging.CRITICAL)
logging.getLogger("isledecomp.compare.db").setLevel(logging.CRITICAL)
logging.getLogger("isledecomp.compare.lines").setLevel(logging.CRITICAL)
isle_compare = IsleCompare(origfile, recompfile, args.pdb, args.decomp_dir)
if args.loglevel == logging.DEBUG:
isle_compare.debug = True
print()
match = isle_compare.compare_address(args.address)
if match is None:
print(f"Failed to find a match at address 0x{args.address:x}")
return
assert match.udiff is not None
function_data = next(
(
y
for y in isle_compare.cvdump_analysis.nodes
if y.addr == match.recomp_addr
),
None,
)
assert function_data is not None
assert function_data.symbol_entry is not None
compare_function_stacks(match.udiff, function_data.symbol_entry)
if __name__ == "__main__":
raise SystemExit(main())


@@ -1,75 +0,0 @@
#!/usr/bin/env python3
import argparse
import difflib
import subprocess
import os
from isledecomp.lib import lib_path_join
from isledecomp.utils import print_diff
def main():
parser = argparse.ArgumentParser(
allow_abbrev=False,
description="Verify Exports: Compare the exports of two DLLs.",
)
parser.add_argument(
"original", metavar="original-binary", help="The original binary"
)
parser.add_argument(
"recompiled", metavar="recompiled-binary", help="The recompiled binary"
)
parser.add_argument(
"--no-color", "-n", action="store_true", help="Do not color the output"
)
args = parser.parse_args()
if not os.path.isfile(args.original):
parser.error(f"Original binary file {args.original} does not exist")
if not os.path.isfile(args.recompiled):
parser.error(f"Recompiled binary {args.recompiled} does not exist")
def get_exports(file):
call = [lib_path_join("DUMPBIN.EXE"), "/EXPORTS"]
if os.name != "nt":
call.insert(0, "wine")
file = (
subprocess.check_output(["winepath", "-w", file])
.decode("utf-8")
.strip()
)
call.append(file)
raw = subprocess.check_output(call).decode("utf-8").split("\r\n")
exports = []
start = False
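# Assumed DUMPBIN /EXPORTS layout (not verified against a real dump): after
# the " ordinal hint name" header line, each row ends with a parenthesized
# part, so the name is sliced from a fixed column up to the last " (".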
for line in raw:
if not start:
if line == " ordinal hint name":
start = True
else:
if line:
exports.append(line[27 : line.rindex(" (")])
elif exports:
break
return exports
og_exp = get_exports(args.original)
re_exp = get_exports(args.recompiled)
udiff = difflib.unified_diff(og_exp, re_exp)
has_diff = print_diff(udiff, args.no_color)
return 1 if has_diff else 0
if __name__ == "__main__":
raise SystemExit(main())

Some files were not shown because too many files have changed in this diff.