Improve Python tools (#273)

* Use Python 3 features

* Use `with` statement for file access
* Use f-strings instead of modulo string formatting
* Single quotes in most places

Fix typo in 'with' statement

* Include the offending file path in "does not exist" error messages

* Fix can_resolve_register_differences and round percentages

* Return modified value instead of relying on in-place modification
This commit is contained in:
Thomas Phillips 2023-11-08 22:47:11 +13:00 committed by GitHub
parent 42c47a6540
commit bd85abaf2a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 190 additions and 201 deletions

View file

@ -50,19 +50,19 @@
original = args.original original = args.original
if not os.path.isfile(original): if not os.path.isfile(original):
parser.error('Original binary does not exist') parser.error(f'Original binary {original} does not exist')
recomp = args.recompiled recomp = args.recompiled
if not os.path.isfile(recomp): if not os.path.isfile(recomp):
parser.error('Recompiled binary does not exist') parser.error(f'Recompiled binary {recomp} does not exist')
syms = args.pdb syms = args.pdb
if not os.path.isfile(syms): if not os.path.isfile(syms):
parser.error('Symbols PDB does not exist') parser.error(f'Symbols PDB {syms} does not exist')
source = args.decomp_dir source = args.decomp_dir
if not os.path.isdir(source): if not os.path.isdir(source):
parser.error('Source directory does not exist') parser.error(f'Source directory {source} does not exist')
svg = args.svg svg = args.svg
@ -70,7 +70,7 @@
# to file addresses # to file addresses
class Bin: class Bin:
def __init__(self, filename): def __init__(self, filename):
logger.debug('Parsing headers of "%s"... ', filename) logger.debug(f'Parsing headers of "{filename}"... ')
self.file = open(filename, 'rb') self.file = open(filename, 'rb')
#HACK: Strictly, we should be parsing the header, but we know where #HACK: Strictly, we should be parsing the header, but we know where
@ -153,8 +153,8 @@ def __init__(self, pdb, file, wine_path_converter):
else: else:
call.append(pdb) call.append(pdb)
logger.info('Parsing %s ...', pdb) logger.info(f'Parsing {pdb} ...')
logger.debug('Command = %r', call) logger.debug(f'Command = {call}')
line_dump = subprocess.check_output(call).decode('utf-8').split('\r\n') line_dump = subprocess.check_output(call).decode('utf-8').split('\r\n')
current_section = None current_section = None
@ -219,7 +219,7 @@ def get_recompiled_address(self, filename, line):
addr = None addr = None
found = False found = False
logger.debug('Looking for %s:%d', filename, line) logger.debug(f'Looking for {filename}:{line}')
filename_basename = os.path.basename(filename).lower() filename_basename = os.path.basename(filename).lower()
for fn in self.lines: for fn in self.lines:
@ -239,9 +239,9 @@ def get_recompiled_address(self, filename, line):
if addr in self.funcs: if addr in self.funcs:
return self.funcs[addr] return self.funcs[addr]
else: else:
logger.error('Failed to find function symbol with address: 0x%x', addr) logger.error(f'Failed to find function symbol with address: 0x{addr:x}')
else: else:
logger.error('Failed to find function symbol with filename and line: %s:%d', filename, line) logger.error(f'Failed to find function symbol with filename and line: {filename}:{line}')
def get_recompiled_address_from_name(self, name): def get_recompiled_address_from_name(self, name):
logger.debug('Looking for %s', name) logger.debug('Looking for %s', name)
@ -249,7 +249,7 @@ def get_recompiled_address_from_name(self, name):
if name in self.names: if name in self.names:
return self.names[name] return self.names[name]
else: else:
logger.error('Failed to find function symbol with name: %s', name) logger.error(f'Failed to find function symbol with name: {name}')
wine_path_converter = None wine_path_converter = None
if os.name != 'nt': if os.name != 'nt':
@ -272,7 +272,7 @@ def get(self, addr):
return self.replacements[addr] return self.replacements[addr]
else: else:
self.counter += 1 self.counter += 1
replacement = '<OFFSET%d>' % self.counter replacement = f'<OFFSET{self.counter}>'
self.replacements[addr] = replacement self.replacements[addr] = replacement
return replacement return replacement
@ -332,12 +332,26 @@ def parse_asm(file, addr, size):
if op_str is None: if op_str is None:
asm.append(mnemonic) asm.append(mnemonic)
else: else:
asm.append("%s %s" % (mnemonic, op_str)) asm.append(f'{mnemonic} {op_str}')
return asm return asm
REGISTER_LIST = set([ REGISTER_LIST = set([
'eax', 'ebx', 'ecx', 'edx', 'edi', 'esi', 'ebp', 'esp', 'ax',
'ax', 'bx', 'cx', 'dx', 'di', 'si', 'bp', 'sp', 'bp',
'bx',
'cx',
'di',
'dx',
'eax',
'ebp',
'ebx',
'ecx',
'edi',
'edx',
'esi',
'esp',
'si',
'sp',
]) ])
WORDS = re.compile(r'\w+') WORDS = re.compile(r'\w+')
@ -350,9 +364,8 @@ def get_registers(line: str):
to_replace.append((reg, match.start())) to_replace.append((reg, match.start()))
return to_replace return to_replace
def replace_register(lines: list[str], start_line: int, reg: str, replacement: str): def replace_register(lines: list[str], start_line: int, reg: str, replacement: str) -> list[str]:
for i in range(start_line, len(lines)): return [line.replace(reg, replacement) if i >= start_line else line for i, line in enumerate(lines)]
lines[i] = lines[i].replace(reg, replacement)
# Is it possible to make new_asm the same as original_asm by swapping registers? # Is it possible to make new_asm the same as original_asm by swapping registers?
def can_resolve_register_differences(original_asm, new_asm): def can_resolve_register_differences(original_asm, new_asm):
@ -379,10 +392,10 @@ def can_resolve_register_differences(original_asm, new_asm):
if replacing_reg in REGISTER_LIST: if replacing_reg in REGISTER_LIST:
if replacing_reg != reg: if replacing_reg != reg:
# Do a three-way swap replacing in all the subsequent lines # Do a three-way swap replacing in all the subsequent lines
temp_reg = "&" * len(reg) temp_reg = '&' * len(reg)
replace_register(new_asm, i, replacing_reg, temp_reg) new_asm = replace_register(new_asm, i, replacing_reg, temp_reg)
replace_register(new_asm, i, reg, replacing_reg) new_asm = replace_register(new_asm, i, reg, replacing_reg)
replace_register(new_asm, i, temp_reg, reg) new_asm = replace_register(new_asm, i, temp_reg, reg)
else: else:
# No replacement to do, different code, bail out # No replacement to do, different code, bail out
return False return False
@ -404,185 +417,171 @@ def can_resolve_register_differences(original_asm, new_asm):
for subdir, dirs, files in os.walk(source): for subdir, dirs, files in os.walk(source):
for file in files: for file in files:
srcfilename = os.path.join(os.path.abspath(subdir), file) srcfilename = os.path.join(os.path.abspath(subdir), file)
srcfile = open(srcfilename, 'r') with open(srcfilename, 'r') as srcfile:
line_no = 0 line_no = 0
while True: while True:
try: try:
line = srcfile.readline() line = srcfile.readline()
line_no += 1 line_no += 1
if not line: if not line:
break break
line = line.strip() line = line.strip()
if line.startswith(pattern) and not line.endswith("STUB"): if line.startswith(pattern) and not line.endswith('STUB'):
par = line[len(pattern):].strip().split() par = line[len(pattern):].strip().split()
module = par[0] module = par[0]
if module != basename: if module != basename:
continue
addr = int(par[1], 16)
# Verbose flag handling
if verbose:
if addr == verbose:
found_verbose_target = True
else:
continue continue
if line.endswith("TEMPLATE"): addr = int(par[1], 16)
line = srcfile.readline()
line_no += 1
# Name comes after // comment
name = line.strip()[2:].strip()
recinfo = syminfo.get_recompiled_address_from_name(name) # Verbose flag handling
if not recinfo:
continue
else:
find_open_bracket = line
while '{' not in find_open_bracket:
find_open_bracket = srcfile.readline()
line_no += 1
recinfo = syminfo.get_recompiled_address(srcfilename, line_no)
if not recinfo:
continue
# The effective_ratio is the ratio when ignoring differing register
# allocation vs the ratio is the true ratio.
ratio = 0.0
effective_ratio = 0.0
if recinfo.size:
origasm = parse_asm(origfile, addr + recinfo.start, recinfo.size)
recompasm = parse_asm(recompfile, recinfo.addr + recinfo.start, recinfo.size)
diff = difflib.SequenceMatcher(None, origasm, recompasm)
ratio = diff.ratio()
effective_ratio = ratio
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
if can_resolve_register_differences(origasm, recompasm):
effective_ratio = 1.0
else:
ratio = 0
percenttext = "%.2f%%" % (effective_ratio * 100)
if not plain:
if effective_ratio == 1.0:
percenttext = colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL
elif effective_ratio > 0.8:
percenttext = colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL
else:
percenttext = colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL
if effective_ratio == 1.0 and ratio != 1.0:
if plain:
percenttext += "*"
else:
percenttext += colorama.Fore.RED + "*" + colorama.Style.RESET_ALL
if args.print_rec_addr:
addrs = '%s / %s' % (hex(addr), hex(recinfo.addr))
else:
addrs = hex(addr)
if not verbose:
print(' %s (%s) is %s similar to the original' % (recinfo.name, addrs, percenttext))
function_count += 1
total_accuracy += ratio
total_effective_accuracy += effective_ratio
if recinfo.size:
udiff = difflib.unified_diff(origasm, recompasm, n=10)
# If verbose, print the diff for that function to the output
if verbose: if verbose:
if effective_ratio == 1.0: if addr == verbose:
ok_text = "OK!" if plain else (colorama.Fore.GREEN + "✨ OK! ✨" + colorama.Style.RESET_ALL) found_verbose_target = True
if ratio == 1.0:
print("%s: %s 100%% match.\n\n%s\n\n" %
(addrs, recinfo.name, ok_text))
else:
print("%s: %s Effective 100%% match. (Differs in register allocation only)\n\n%s (still differs in register allocation)\n\n" %
(addrs, recinfo.name, ok_text))
else: else:
for line in udiff: continue
if line.startswith("++") or line.startswith("@@") or line.startswith("--"):
# Skip unneeded parts of the diff for the brief view if line.endswith('TEMPLATE'):
pass line = srcfile.readline()
elif line.startswith("+"): line_no += 1
if plain: # Name comes after // comment
print(line) name = line.strip()[2:].strip()
else:
print(colorama.Fore.GREEN + line) recinfo = syminfo.get_recompiled_address_from_name(name)
elif line.startswith("-"): if not recinfo:
if plain: continue
print(line) else:
else: find_open_bracket = line
print(colorama.Fore.RED + line) while '{' not in find_open_bracket:
find_open_bracket = srcfile.readline()
line_no += 1
recinfo = syminfo.get_recompiled_address(srcfilename, line_no)
if not recinfo:
continue
# The effective_ratio is the ratio when ignoring differing register
# allocation vs the ratio is the true ratio.
ratio = 0.0
effective_ratio = 0.0
if recinfo.size:
origasm = parse_asm(origfile, addr + recinfo.start, recinfo.size)
recompasm = parse_asm(recompfile, recinfo.addr + recinfo.start, recinfo.size)
diff = difflib.SequenceMatcher(None, origasm, recompasm)
ratio = diff.ratio()
effective_ratio = ratio
if ratio != 1.0:
# Check whether we can resolve register swaps which are actually
# perfect matches modulo compiler entropy.
if can_resolve_register_differences(origasm, recompasm):
effective_ratio = 1.0
else:
ratio = 0
percenttext = f'{(effective_ratio * 100):.2f}%'
if not plain:
if effective_ratio == 1.0:
percenttext = colorama.Fore.GREEN + percenttext + colorama.Style.RESET_ALL
elif effective_ratio > 0.8:
percenttext = colorama.Fore.YELLOW + percenttext + colorama.Style.RESET_ALL
else:
percenttext = colorama.Fore.RED + percenttext + colorama.Style.RESET_ALL
if effective_ratio == 1.0 and ratio != 1.0:
if plain:
percenttext += '*'
else:
percenttext += colorama.Fore.RED + '*' + colorama.Style.RESET_ALL
if args.print_rec_addr:
addrs = f'0x{addr:x} / 0x{recinfo.addr:x}'
else:
addrs = hex(addr)
if not verbose:
print(f' {recinfo.name} ({addrs}) is {percenttext} similar to the original')
function_count += 1
total_accuracy += ratio
total_effective_accuracy += effective_ratio
if recinfo.size:
udiff = difflib.unified_diff(origasm, recompasm, n=10)
# If verbose, print the diff for that function to the output
if verbose:
if effective_ratio == 1.0:
ok_text = 'OK!' if plain else (colorama.Fore.GREEN + '✨ OK! ✨' + colorama.Style.RESET_ALL)
if ratio == 1.0:
print(f'{addrs}: {recinfo.name} 100% match.\n\n{ok_text}\n\n')
else: else:
print(line) print(f'{addrs}: {recinfo.name} Effective 100%% match. (Differs in register allocation only)\n\n{ok_text} (still differs in register allocation)\n\n')
if not plain: else:
print(colorama.Style.RESET_ALL, end='') for line in udiff:
if line.startswith('++') or line.startswith('@@') or line.startswith('--'):
# Skip unneeded parts of the diff for the brief view
pass
elif line.startswith('+'):
if plain:
print(line)
else:
print(colorama.Fore.GREEN + line)
elif line.startswith('-'):
if plain:
print(line)
else:
print(colorama.Fore.RED + line)
else:
print(line)
if not plain:
print(colorama.Style.RESET_ALL, end='')
print("\n%s is only %s similar to the original, diff above" % (recinfo.name, percenttext)) print(f'\n{recinfo.name} is only {percenttext} similar to the original, diff above')
# If html, record the diffs to an HTML file # If html, record the diffs to an HTML file
if html_path: if html_path:
escaped = '\\n'.join(udiff).replace('"', '\\"').replace('\n', '\\n').replace('<', '&lt;').replace('>', '&gt;') escaped = html.escape('\\n'.join(udiff).replace('"', '\\"').replace('\n', '\\n'))
htmlinsert.append('{address: "%s", name: "%s", matching: %s, diff: "%s"}' % (hex(addr), html.escape(recinfo.name), str(effective_ratio), escaped)) htmlinsert.append(f'{{address: "0x{addr:x}", name: "{html.escape(recinfo.name)}", matching: {effective_ratio}, diff: "{escaped}"}}')
except UnicodeDecodeError: except UnicodeDecodeError:
break break
def gen_html(html_path, data): def gen_html(html_path, data):
templatefile = open(get_file_in_script_dir('template.html'), 'r') templatedata = None
if not templatefile: with open(get_file_in_script_dir('template.html')) as templatefile:
print('Failed to find HTML template file, can\'t generate HTML summary') templatedata = templatefile.read()
return
templatedata = templatefile.read()
templatefile.close()
templatedata = templatedata.replace('/* INSERT DATA HERE */', ','.join(data), 1) templatedata = templatedata.replace('/* INSERT DATA HERE */', ','.join(data), 1)
htmlfile = open(html_path, 'w') with open(html_path, 'w') as htmlfile:
if not htmlfile: htmlfile.write(templatedata)
print('Failed to write to HTML file %s' % html_path)
return
htmlfile.write(templatedata)
htmlfile.close()
def gen_svg(svg, name, icon, implemented_funcs, total_funcs, raw_accuracy): def gen_svg(svg, name, icon, implemented_funcs, total_funcs, raw_accuracy):
templatefile = open(get_file_in_script_dir('template.svg'), 'r') templatedata = None
if not templatefile: with open(get_file_in_script_dir('template.svg')) as templatefile:
print('Failed to find SVG template file, can\'t generate SVG summary') templatedata = templatefile.read()
return
templatedata = templatefile.read()
templatefile.close()
# TODO: Use templating engine (e.g. pystache)
# Replace icon # Replace icon
if args.svg_icon: if args.svg_icon:
iconfile = open(args.svg_icon, 'rb') with open(args.svg_icon, 'rb') as iconfile:
templatedata = templatedata.replace('{icon}', base64.b64encode(iconfile.read()).decode('utf-8'), 1) templatedata = templatedata.replace('{icon}', base64.b64encode(iconfile.read()).decode('utf-8'), 1)
iconfile.close()
# Replace name # Replace name
templatedata = templatedata.replace('{name}', name, 1) templatedata = templatedata.replace('{name}', name, 1)
# Replace implemented statistic # Replace implemented statistic
templatedata = templatedata.replace('{implemented}', '%.2f%% (%i/%i)' % (implemented_funcs / total_funcs * 100, implemented_funcs, total_funcs), 1) templatedata = templatedata.replace('{implemented}', f'{(implemented_funcs / total_funcs * 100):.2f}% ({implemented_funcs}/{total_funcs})', 1)
# Replace accuracy statistic # Replace accuracy statistic
templatedata = templatedata.replace('{accuracy}', '%.2f%%' % (raw_accuracy / implemented_funcs * 100), 1) templatedata = templatedata.replace('{accuracy}', f'{(raw_accuracy / implemented_funcs * 100):.2f}%', 1)
# Generate progress bar width # Generate progress bar width
total_statistic = raw_accuracy / total_funcs total_statistic = raw_accuracy / total_funcs
@ -593,22 +592,17 @@ def gen_svg(svg, name, icon, implemented_funcs, total_funcs, raw_accuracy):
templatedata = templatedata[0:percentstart] + str(progwidth) + templatedata[percentend + 1:] templatedata = templatedata[0:percentstart] + str(progwidth) + templatedata[percentend + 1:]
# Replace percentage statistic # Replace percentage statistic
templatedata = templatedata.replace('{percent}', '%.2f%%' % (total_statistic * 100), 2) templatedata = templatedata.replace('{percent}', f'{(total_statistic * 100):.2f}%', 2)
svgfile = open(svg, 'w') with open(svg, 'w') as svgfile:
if not svgfile: svgfile.write(templatedata)
print('Failed to write to SVG file %s' % svg)
return
svgfile.write(templatedata)
svgfile.close()
if html_path: if html_path:
gen_html(html_path, htmlinsert) gen_html(html_path, htmlinsert)
if verbose: if verbose:
if not found_verbose_target: if not found_verbose_target:
print('Failed to find the function with address %s' % hex(verbose)) print(f'Failed to find the function with address 0x{verbose:x}')
else: else:
implemented_funcs = function_count implemented_funcs = function_count
@ -616,8 +610,9 @@ def gen_svg(svg, name, icon, implemented_funcs, total_funcs, raw_accuracy):
function_count = int(args.total) function_count = int(args.total)
if function_count > 0: if function_count > 0:
print('\nTotal effective accuracy %.2f%% across %i functions (%.2f%% actual accuracy)' % effective_accuracy = total_effective_accuracy / function_count * 100
(total_effective_accuracy / function_count * 100, function_count, total_accuracy / function_count * 100)) actual_accuracy = total_accuracy / function_count * 100
print(f'\nTotal effective accuracy {effective_accuracy:.2f}% across {function_count} functions ({actual_accuracy:.2f}% actual accuracy)')
if svg: if svg:
gen_svg(svg, os.path.basename(original), args.svg_icon, implemented_funcs, function_count, total_effective_accuracy) gen_svg(svg, os.path.basename(original), args.svg_icon, implemented_funcs, function_count, total_effective_accuracy)

View file

@ -16,10 +16,10 @@
args = parser.parse_args() args = parser.parse_args()
if not os.path.isfile(args.original): if not os.path.isfile(args.original):
parser.error('Original binary does not exist') parser.error(f'Original binary file {args.original} does not exist')
if not os.path.isfile(args.recompiled): if not os.path.isfile(args.recompiled):
parser.error('Recompiled binary does not exist') parser.error(f'Recompiled binary {args.recompiled} does not exist')
def get_file_in_script_dir(fn): def get_file_in_script_dir(fn):
return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn) return os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), fn)
@ -58,25 +58,19 @@ def get_exports(file):
for line in udiff: for line in udiff:
has_diff = True has_diff = True
if line.startswith("++") or line.startswith("@@") or line.startswith("--"): color = ''
if line.startswith('++') or line.startswith('@@') or line.startswith('--'):
# Skip unneeded parts of the diff for the brief view # Skip unneeded parts of the diff for the brief view
pass continue
elif line.startswith("+"): # Work out color if we are printing color
if args.no_color: if not args.no_color:
print(line) if line.startswith('+'):
else: color = colorama.Fore.GREEN
print(colorama.Fore.GREEN + line) elif line.startswith('-'):
elif line.startswith("-"): color = colorama.Fore.RED
if args.no_color: print(color + line)
print(line) # Reset color if we're printing in color
else:
print(colorama.Fore.RED + line)
else:
print(line)
if not args.no_color: if not args.no_color:
print(colorama.Style.RESET_ALL, end='') print(colorama.Style.RESET_ALL, end='')
if has_diff: sys.exit(1 if has_diff else 0)
sys.exit(1)
else:
sys.exit(0)