diff --git a/tools/ghidra_scripts/lego_util/function_importer.py b/tools/ghidra_scripts/lego_util/function_importer.py index bf99f7f8..80176fc5 100644 --- a/tools/ghidra_scripts/lego_util/function_importer.py +++ b/tools/ghidra_scripts/lego_util/function_importer.py @@ -17,6 +17,7 @@ CppStackSymbol, ) from lego_util.ghidra_helper import ( + add_pointer_type, get_ghidra_namespace, sanitize_name, ) @@ -82,7 +83,26 @@ def matches_ghidra_function(self, ghidra_function: Function) -> bool: """Checks whether this function declaration already matches the description in Ghidra""" name_match = self.name == ghidra_function.getName(False) namespace_match = self.namespace == ghidra_function.getParentNamespace() - return_type_match = self.return_type == ghidra_function.getReturnType() + ghidra_return_type = ghidra_function.getReturnType() + return_type_match = self.return_type == ghidra_return_type + + # Handle edge case: Return type X that is larger than the return register. + # In that case, the function returns `X*` and has another argument `X* __return_storage_ptr`. + if ( + (not return_type_match) + and (self.return_type.getLength() > 4) + and (add_pointer_type(self.api, self.return_type) == ghidra_return_type) + and any( + param + for param in ghidra_function.getParameters() + if param.getName() == "__return_storage_ptr__" + ) + ): + logger.debug( + "%s has a return type larger than 4 bytes", self.get_full_name() + ) + return_type_match = True + # match arguments: decide if thiscall or not thiscall_matches = ( self.signature.call_type == ghidra_function.getCallingConventionName() @@ -128,6 +148,14 @@ def _matches_thiscall_parameters(self, ghidra_function: Function) -> bool: return self._parameter_lists_match(ghidra_params) def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool: + # Remove return storage pointer from comparison if present. + # This is relevant to returning values larger than 4 bytes, and is not mentioned in the PDB + ghidra_params = [ + param + for param in ghidra_params + if param.getName() != "__return_storage_ptr__" + ] + if len(self.arguments) != len(ghidra_params): logger.info("Mismatching argument count") return False @@ -146,11 +174,16 @@ def _parameter_lists_match(self, ghidra_params: "list[Parameter]") -> bool: if stack_match is None: logger.debug("Not found on stack: %s", ghidra_arg) return False - # "__formal" is the placeholder for arguments without a name - if ( - stack_match.name != ghidra_arg.getName() - and not stack_match.name.startswith("__formal") - ): + + if stack_match.name.startswith("__formal"): + # "__formal" is the placeholder for arguments without a name + continue + + if stack_match.name == "__$ReturnUdt": + # These appear in templates and cannot be set automatically, as they are a NOTYPE + continue + + if stack_match.name != ghidra_arg.getName(): logger.debug( "Argument name mismatch: expected %s, found %s", stack_match.name, diff --git a/tools/ghidra_scripts/lego_util/type_importer.py b/tools/ghidra_scripts/lego_util/type_importer.py index 0d3ee5df..c645ebf8 100644 --- a/tools/ghidra_scripts/lego_util/type_importer.py +++ b/tools/ghidra_scripts/lego_util/type_importer.py @@ -1,5 +1,5 @@ import logging -from typing import Any +from typing import Any, Callable, TypeVar # Disable spurious warnings in vscode / pylance # pyright: reportMissingModuleSource=false @@ -29,6 +29,7 @@ CategoryPath, DataType, DataTypeConflictHandler, + Enum, EnumDataType, StructureDataType, StructureInternal, @@ -47,7 +48,9 @@ def __init__(self, api: FlatProgramAPI, extraction: PdbFunctionExtractor): self.extraction = extraction # tracks the structs/classes we have already started to import, otherwise we run into infinite recursion self.handled_structs: set[str] = set() - self.struct_call_stack: list[str] = [] + + # tracks the enums we have already handled for the sake of efficiency + self.handled_enums: dict[str, Enum] = {} @property def types(self): @@ -166,9 +169,13 @@ def _import_enum(self, type_pdb: dict[str, Any]) -> DataType: field_list = self.extraction.compare.cv.types.keys.get(type_pdb["field_type"]) assert field_list is not None, f"Failed to find field list for enum {type_pdb}" - result = EnumDataType( - CategoryPath("/imported"), type_pdb["name"], underlying_type.getLength() + result = self._get_or_create_enum_data_type( + type_pdb["name"], underlying_type.getLength() ) + # clear existing variant if there are any + for existing_variant in result.getNames(): + result.remove(existing_variant) + variants: list[dict[str, Any]] = field_list["variants"] for variant in variants: result.add(variant["name"], variant["value"]) @@ -259,30 +266,74 @@ def _get_or_create_namespace(self, class_name_with_namespace: str): parent_namespace = create_ghidra_namespace(self.api, colon_split) self.api.createClass(parent_namespace, class_name) + def _get_or_create_enum_data_type( + self, enum_type_name: str, enum_type_size: int + ) -> Enum: + if (known_enum := self.handled_enums.get(enum_type_name, None)) is not None: + return known_enum + + result = self._get_or_create_data_type( + enum_type_name, + "enum", + Enum, + lambda: EnumDataType( + CategoryPath("/imported"), enum_type_name, enum_type_size + ), + ) + self.handled_enums[enum_type_name] = result + return result + def _get_or_create_struct_data_type( self, class_name_with_namespace: str, class_size: int ) -> StructureInternal: + return self._get_or_create_data_type( + class_name_with_namespace, + "class/struct", + StructureInternal, + lambda: StructureDataType( + CategoryPath("/imported"), class_name_with_namespace, class_size + ), + ) + + T = TypeVar("T", bound=DataType) + + def _get_or_create_data_type( + self, + type_name: str, + readable_name_of_type_category: str, + expected_type: type[T], + new_instance_callback: Callable[[], T], + ) -> T: + """ + Checks if a data type provided under the given name exists in Ghidra. + Creates one using `new_instance_callback` if there is not. + Also verifies the data type. + + Note that the return value of `addDataType()` is not the same instance as the input + even if there is no name collision. + """ try: - data_type = get_ghidra_type(self.api, class_name_with_namespace) + data_type = get_ghidra_type(self.api, type_name) logger.debug( - "Found existing data type %s under category path %s", - class_name_with_namespace, + "Found existing %s type %s under category path %s", + readable_name_of_type_category, + type_name, data_type.getCategoryPath(), ) except TypeNotFoundInGhidraError: - # Create a new struct data type - data_type = StructureDataType( - CategoryPath("/imported"), class_name_with_namespace, class_size - ) data_type = ( self.api.getCurrentProgram() .getDataTypeManager() - .addDataType(data_type, DataTypeConflictHandler.KEEP_HANDLER) + .addDataType( + new_instance_callback(), DataTypeConflictHandler.KEEP_HANDLER + ) + ) + logger.info( + "Created new %s data type %s", readable_name_of_type_category, type_name ) - logger.info("Created new data type %s", class_name_with_namespace) assert isinstance( - data_type, StructureInternal - ), f"Found type sharing its name with a class/struct, but is not a struct: {class_name_with_namespace}" + data_type, expected_type + ), f"Found existing type named {type_name} that is not a {readable_name_of_type_category}" return data_type def _delete_and_recreate_struct_data_type(