Support stubs in function import (#1034)

* Refactor returned data structure for extensibility

* feature: Import stub functions but don't overwrite their argument list

Ghidra might have auto-detected some arguments, so we don't want to overwrite that if the stub's argument list has not been verified

Closes #1009

---------

Co-authored-by: jonschz <jonschz@users.noreply.github.com>
This commit is contained in:
jonschz 2024-06-16 13:13:19 +02:00 committed by GitHub
parent a6644801f1
commit c8dc77cbf4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 56 additions and 38 deletions

View file

@ -125,16 +125,15 @@ def add_python_path(path: str):
# We need to quote the types here because they might not exist when running without Ghidra # We need to quote the types here because they might not exist when running without Ghidra
def import_function_into_ghidra( def import_function_into_ghidra(
api: "FlatProgramAPI", api: "FlatProgramAPI",
match_info: "MatchInfo", pdb_function: "PdbFunction",
signature: "FunctionSignature",
type_importer: "PdbTypeImporter", type_importer: "PdbTypeImporter",
): ):
hex_original_address = f"{match_info.orig_addr:x}" hex_original_address = f"{pdb_function.match_info.orig_addr:x}"
# Find the Ghidra function at that address # Find the Ghidra function at that address
ghidra_address = getAddressFactory().getAddress(hex_original_address) ghidra_address = getAddressFactory().getAddress(hex_original_address)
# pylint: disable=possibly-used-before-assignment # pylint: disable=possibly-used-before-assignment
function_importer = PdbFunctionImporter(api, match_info, signature, type_importer) function_importer = PdbFunctionImporter(api, pdb_function, type_importer)
ghidra_function = getFunctionAt(ghidra_address) ghidra_function = getFunctionAt(ghidra_address)
if ghidra_function is None: if ghidra_function is None:
@ -165,7 +164,7 @@ def import_function_into_ghidra(
def process_functions(extraction: "PdbFunctionExtractor"): def process_functions(extraction: "PdbFunctionExtractor"):
func_signatures = extraction.get_function_list() pdb_functions = extraction.get_function_list()
if not GLOBALS.running_from_ghidra: if not GLOBALS.running_from_ghidra:
logger.info("Completed the dry run outside Ghidra.") logger.info("Completed the dry run outside Ghidra.")
@ -175,12 +174,13 @@ def process_functions(extraction: "PdbFunctionExtractor"):
# pylint: disable=possibly-used-before-assignment # pylint: disable=possibly-used-before-assignment
type_importer = PdbTypeImporter(api, extraction) type_importer = PdbTypeImporter(api, extraction)
for match_info, signature in func_signatures: for pdb_func in pdb_functions:
func_name = pdb_func.match_info.name
try: try:
import_function_into_ghidra(api, match_info, signature, type_importer) import_function_into_ghidra(api, pdb_func, type_importer)
GLOBALS.statistics.successes += 1 GLOBALS.statistics.successes += 1
except Lego1Exception as e: except Lego1Exception as e:
log_and_track_failure(match_info.name, e) log_and_track_failure(func_name, e)
except RuntimeError as e: except RuntimeError as e:
cause = e.args[0] cause = e.args[0]
if CancelledException is not None and isinstance(cause, CancelledException): if CancelledException is not None and isinstance(cause, CancelledException):
@ -188,10 +188,10 @@ def process_functions(extraction: "PdbFunctionExtractor"):
logging.critical("Import aborted by the user.") logging.critical("Import aborted by the user.")
return return
log_and_track_failure(match_info.name, cause, unexpected=True) log_and_track_failure(func_name, cause, unexpected=True)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
except Exception as e: # pylint: disable=broad-exception-caught except Exception as e: # pylint: disable=broad-exception-caught
log_and_track_failure(match_info.name, e, unexpected=True) log_and_track_failure(func_name, e, unexpected=True)
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
@ -257,7 +257,6 @@ def main():
from isledecomp.compare import Compare as IsleCompare from isledecomp.compare import Compare as IsleCompare
reload_module("isledecomp.compare.db") reload_module("isledecomp.compare.db")
from isledecomp.compare.db import MatchInfo
reload_module("lego_util.exceptions") reload_module("lego_util.exceptions")
from lego_util.exceptions import Lego1Exception from lego_util.exceptions import Lego1Exception
@ -265,7 +264,7 @@ def main():
reload_module("lego_util.pdb_extraction") reload_module("lego_util.pdb_extraction")
from lego_util.pdb_extraction import ( from lego_util.pdb_extraction import (
PdbFunctionExtractor, PdbFunctionExtractor,
FunctionSignature, PdbFunction,
) )
if GLOBALS.running_from_ghidra: if GLOBALS.running_from_ghidra:

View file

@ -11,10 +11,8 @@
from ghidra.program.model.listing import ParameterImpl from ghidra.program.model.listing import ParameterImpl
from ghidra.program.model.symbol import SourceType from ghidra.program.model.symbol import SourceType
from isledecomp.compare.db import MatchInfo
from lego_util.pdb_extraction import ( from lego_util.pdb_extraction import (
FunctionSignature, PdbFunction,
CppRegisterSymbol, CppRegisterSymbol,
CppStackSymbol, CppStackSymbol,
) )
@ -37,28 +35,28 @@ class PdbFunctionImporter:
def __init__( def __init__(
self, self,
api: FlatProgramAPI, api: FlatProgramAPI,
match_info: MatchInfo, func: PdbFunction,
signature: FunctionSignature,
type_importer: "PdbTypeImporter", type_importer: "PdbTypeImporter",
): ):
self.api = api self.api = api
self.match_info = match_info self.match_info = func.match_info
self.signature = signature self.signature = func.signature
self.is_stub = func.is_stub
self.type_importer = type_importer self.type_importer = type_importer
if signature.class_type is not None: if self.signature.class_type is not None:
# Import the base class so the namespace exists # Import the base class so the namespace exists
self.type_importer.import_pdb_type_into_ghidra(signature.class_type) self.type_importer.import_pdb_type_into_ghidra(self.signature.class_type)
assert match_info.name is not None assert self.match_info.name is not None
colon_split = sanitize_name(match_info.name).split("::") colon_split = sanitize_name(self.match_info.name).split("::")
self.name = colon_split.pop() self.name = colon_split.pop()
namespace_hierachy = colon_split namespace_hierachy = colon_split
self.namespace = get_ghidra_namespace(api, namespace_hierachy) self.namespace = get_ghidra_namespace(api, namespace_hierachy)
self.return_type = type_importer.import_pdb_type_into_ghidra( self.return_type = type_importer.import_pdb_type_into_ghidra(
signature.return_type self.signature.return_type
) )
self.arguments = [ self.arguments = [
ParameterImpl( ParameterImpl(
@ -66,7 +64,7 @@ def __init__(
type_importer.import_pdb_type_into_ghidra(type_name), type_importer.import_pdb_type_into_ghidra(type_name),
api.getCurrentProgram(), api.getCurrentProgram(),
) )
for (index, type_name) in enumerate(signature.arglist) for (index, type_name) in enumerate(self.signature.arglist)
] ]
@property @property
@ -90,7 +88,10 @@ def matches_ghidra_function(self, ghidra_function: Function) -> bool:
self.signature.call_type == ghidra_function.getCallingConventionName() self.signature.call_type == ghidra_function.getCallingConventionName()
) )
if thiscall_matches: if self.is_stub:
# We do not import the argument list for stubs, so it should be excluded in matches
args_match = True
elif thiscall_matches:
if self.signature.call_type == "__thiscall": if self.signature.call_type == "__thiscall":
args_match = self._matches_thiscall_parameters(ghidra_function) args_match = self._matches_thiscall_parameters(ghidra_function)
else: else:
@ -104,7 +105,7 @@ def matches_ghidra_function(self, ghidra_function: Function) -> bool:
name_match, name_match,
return_type_match, return_type_match,
thiscall_matches, thiscall_matches,
args_match, "ignored" if self.is_stub else args_match,
) )
return ( return (
@ -165,16 +166,25 @@ def overwrite_ghidra_function(self, ghidra_function: Function):
ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED) ghidra_function.setReturnType(self.return_type, SourceType.USER_DEFINED)
ghidra_function.setCallingConvention(self.call_type) ghidra_function.setCallingConvention(self.call_type)
if self.is_stub:
logger.debug(
"%s is a stub, skipping parameter import", self.get_full_name()
)
return
ghidra_function.replaceParameters( ghidra_function.replaceParameters(
Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS, Function.FunctionUpdateType.DYNAMIC_STORAGE_ALL_PARAMS,
True, True, # force
SourceType.USER_DEFINED, SourceType.USER_DEFINED,
self.arguments, self.arguments,
) )
# When we set the parameters, Ghidra will generate the layout. self._import_parameter_names(ghidra_function)
# Now we read them again and match them against the stack layout in the PDB,
# both to verify and to set the parameter names. def _import_parameter_names(self, ghidra_function: Function):
# When we call `ghidra_function.replaceParameters`, Ghidra will generate the layout.
# Now we read the parameters again and match them against the stack layout in the PDB,
# both to verify the layout and to set the parameter names.
ghidra_parameters: list[Parameter] = ghidra_function.getParameters() ghidra_parameters: list[Parameter] = ghidra_function.getParameters()
# Try to add Ghidra function names # Try to add Ghidra function names
@ -188,7 +198,9 @@ def overwrite_ghidra_function(self, ghidra_function: Function):
# Appears to never happen - could in theory be relevant to __fastcall__ functions, # Appears to never happen - could in theory be relevant to __fastcall__ functions,
# which we haven't seen yet # which we haven't seen yet
logger.warning("Unhandled register variable in %s", self.get_full_name) logger.warning(
"Unhandled register variable in %s", self.get_full_name()
)
continue continue
def _rename_stack_parameter(self, index: int, param: Parameter): def _rename_stack_parameter(self, index: int, param: Parameter):

View file

@ -38,6 +38,13 @@ class FunctionSignature:
stack_symbols: list[CppStackOrRegisterSymbol] stack_symbols: list[CppStackOrRegisterSymbol]
@dataclass
class PdbFunction:
match_info: MatchInfo
signature: FunctionSignature
is_stub: bool
class PdbFunctionExtractor: class PdbFunctionExtractor:
""" """
Extracts all information on a given function from the parsed PDB Extracts all information on a given function from the parsed PDB
@ -121,20 +128,18 @@ def get_func_signature(self, fn: SymbolsEntry) -> Optional[FunctionSignature]:
stack_symbols=stack_symbols, stack_symbols=stack_symbols,
) )
def get_function_list(self) -> list[tuple[MatchInfo, FunctionSignature]]: def get_function_list(self) -> list[PdbFunction]:
handled = ( handled = (
self.handle_matched_function(match) self.handle_matched_function(match)
for match in self.compare.get_functions() for match in self.compare.get_functions()
) )
return [signature for signature in handled if signature is not None] return [signature for signature in handled if signature is not None]
def handle_matched_function( def handle_matched_function(self, match_info: MatchInfo) -> Optional[PdbFunction]:
self, match_info: MatchInfo
) -> Optional[tuple[MatchInfo, FunctionSignature]]:
assert match_info.orig_addr is not None assert match_info.orig_addr is not None
match_options = self.compare.get_match_options(match_info.orig_addr) match_options = self.compare.get_match_options(match_info.orig_addr)
assert match_options is not None assert match_options is not None
if match_options.get("skip", False) or match_options.get("stub", False): if match_options.get("skip", False):
return None return None
function_data = next( function_data = next(
@ -163,4 +168,6 @@ def handle_matched_function(
if function_signature is None: if function_signature is None:
return None return None
return match_info, function_signature is_stub = match_options.get("stub", False)
return PdbFunction(match_info, function_signature, is_stub)