sh-dis/ast_transformers.py

242 lines
7.8 KiB
Python

from pprint import pprint
from collections import defaultdict
from parser import Tree
from lexer import Identifier, Punctuator, IntegerConstant
import identifier_substitution
def find_locals__walk_assignment_lhs(tree):
    """Yield candidate local-variable names found on an assignment's LHS.

    Walks ``tree`` depth-first. For every ``Identifier`` whose token is
    entirely lower case and is not one of ``m``, ``n``, ``i`` or ``d``, the
    token text is yielded. Upper/mixed-case identifiers and any other leaf
    nodes are ignored.

    Raises:
        AssertionError: if a yielded candidate is also a key of
            ``identifier_substitution.mapping``.
    """
    if type(tree) is Tree:
        for node in tree.children:
            yield from find_locals__walk_assignment_lhs(node)
        return
    if type(tree) is not Identifier:
        return
    name = tree.token
    # m/n/i/d are never treated as locals (presumably instruction operand
    # fields -- verify); non-lower-case names are not locals either.
    if name in {'m', 'n', 'i', 'd'} or name != name.lower():
        return
    assert name not in identifier_substitution.mapping, name
    yield name
def find_locals__walk_assignment(tree):
    """Yield ``(name, rhs)`` for every local assigned anywhere inside ``tree``.

    For each ``assignment`` node, every local name found on its left-hand side
    (via ``find_locals__walk_assignment_lhs``) is paired with the node's
    right-hand side subtree. Searching then continues into the assignment's
    remaining children; the LHS itself is not searched for nested assignments.
    Non-``Tree`` nodes yield nothing.
    """
    if type(tree) is not Tree:
        return
    if tree.operation == "assignment":
        for local_name in find_locals__walk_assignment_lhs(tree.children[0]):
            yield local_name, tree.children[1]
        to_visit = tree.children[1:]
    else:
        to_visit = tree.children
    for subtree in to_visit:
        yield from find_locals__walk_assignment(subtree)
def transform_assignment_list__collect_identifiers(tree, operation):
    """Yield the token text of every ``Identifier`` leaf under ``tree``.

    Every interior ``Tree`` node is required to have ``operation`` as its
    operation (e.g. ``"assignment_list"`` or ``"argument_list"``). Leaves that
    are neither ``Tree`` nor ``Identifier`` are skipped silently.

    Raises:
        AssertionError: if an interior node has a different operation. The
            offending subtree is pretty-printed to stdout first. The assertion
            message is ``pprint``'s return value, i.e. ``None``.
    """
    if type(tree) is Tree:
        assert tree.operation == operation, pprint(tree)
        for subtree in tree.children:
            yield from transform_assignment_list__collect_identifiers(subtree, operation)
        return
    if type(tree) is Identifier:
        yield tree.token
def transform_assignment_list__assignment(tree):
    """Rewrite a multi-target ``assignment`` whose RHS is a function call.

    Only assignments whose LHS is an ``assignment_list`` are rewritten; any
    other assignment is returned unchanged. For ``a, b = F(a, c)``-shaped
    input:

    * LHS names that also appear among the call's arguments are wrapped in a
      ``unary_reference`` node where they occur in the argument list.
    * If no LHS name is missing from the arguments, the rewritten
      ``function_call`` is returned on its own.
    * If exactly one LHS name is missing, an ``assignment`` of the rewritten
      call to that name is returned (with ``line=-1``).

    Raises:
        AssertionError: if ``tree`` is not an assignment, if the RHS is not a
            ``function_call``, if the LHS/argument lists contain unexpected
            interior nodes, or if more than one LHS name is not an argument.
            (With ``-O`` the last case returns ``None``.)
    """
    assert tree.operation == "assignment", tree
    target = tree.children[0]
    if type(target) is not Tree or target.operation != 'assignment_list':
        return tree

    call = tree.children[1]
    assert call.operation == "function_call", (call.operation, pprint(tree))
    arg_names = list(
        transform_assignment_list__collect_identifiers(call.children[1], "argument_list")
    )

    # Split LHS names into those passed by reference and those assigned.
    by_reference = []
    assigned = []
    for lhs_name in transform_assignment_list__collect_identifiers(target, "assignment_list"):
        if lhs_name in arg_names:
            by_reference.append(lhs_name)
        else:
            assigned.append(lhs_name)

    def reference_args(node):
        # Copy the argument list, wrapping by-reference identifiers.
        if type(node) is Tree:
            assert node.operation == 'argument_list'
            return Tree(
                operation=node.operation,
                children=[reference_args(sub) for sub in node.children],
            )
        if type(node) is Identifier and node.token in by_reference:
            return Tree(operation="unary_reference", children=[node])
        return node

    def rebuilt_call():
        return Tree(
            operation="function_call",
            children=[call.children[0], reference_args(call.children[1])],
        )

    if not assigned:
        return rebuilt_call()
    if len(assigned) == 1:
        return Tree(
            operation="assignment",
            children=[Identifier(line=-1, token=assigned[0]), rebuilt_call()],
        )
    assert False, (assigned, by_reference, pprint(tree))
def transform_assignment_list(tree):
    """Return a copy of ``tree`` with multi-target assignments rewritten.

    Each ``assignment`` node is handed to
    ``transform_assignment_list__assignment`` and its result is used as-is;
    the walk does not descend further into it. Other ``Tree`` nodes are
    rebuilt with transformed children. Non-``Tree`` leaves are returned
    unchanged.
    """
    if type(tree) is not Tree:
        return tree
    if tree.operation == "assignment":
        return transform_assignment_list__assignment(tree)
    new_children = [transform_assignment_list(subtree) for subtree in tree.children]
    return Tree(operation=tree.operation, children=new_children)
# Result type of known helper functions, keyed by function name. guess_type
# consults this when a local is assigned from a call to one of these.
function_types = dict(
    FloatValue32="float32_t",
    FloatValue64="float64_t",
    FLOAT_LS="float32_t",
    FLOAT_LD="float64_t",
    FCNV_DS="uint32_t",
    FCNV_SD="float64_t",
)
# Fixed types for specific local names. guess_type checks this before
# looking at the assigned expression.
name_types = dict(
    fps="uint32_t",
    sr="uint32_t",
)
def guess_type(name, tree, declared):
    """Guess the declared type (a type-name string) for local ``name``.

    Args:
        name: the local variable's name.
        tree: the right-hand side expression assigned to it.
        declared: mapping of already-typed local names to their type strings.

    Resolution order:
      1. ``name_types`` entry for ``name``;
      2. for a ``function_call``, the ``function_types`` entry of the callee;
      3. for a bare ``Identifier`` already in ``declared``, that local's type;
      4. for the integer constants ``0x00000000`` / ``0x3f800000``
         (case-insensitive), ``"float32_t"``;
      5. otherwise ``"int64_t"``.

    Raises:
        AssertionError: if a ``function_call``'s callee is not an Identifier.
    """
    if name in name_types:
        return name_types[name]

    fallback = 'int64_t'
    node_type = type(tree)

    if node_type is Tree:
        if tree.operation != 'function_call':
            return fallback
        callee = tree.children[0]
        assert type(callee) is Identifier, tree
        return function_types.get(callee.token, fallback)

    if node_type is Identifier:
        if tree.token in declared:
            return declared[tree.token]
        return fallback

    if node_type is IntegerConstant and tree.token.lower() in {"0x00000000", "0x3f800000"}:
        # hack for fldi0/fldi1: these are the IEEE-754 single-precision bit
        # patterns of 0.0 and 1.0
        return "float32_t"

    return fallback
def transform_local_declarations(statements):
    """Yield declaration statements for every local assigned in ``statements``.

    Locals are discovered with ``find_locals__walk_assignment``. The type of
    each local is guessed from its first assignment only; later assignments to
    the same name are ignored. Names are grouped by guessed type, and one
    ``expression_statement`` wrapping a ``declaration`` (type identifier
    followed by the names) is yielded per type, in first-seen order.
    Generated identifiers carry ``line=-1``.
    """
    declared = {}
    names_by_type = defaultdict(list)
    assignments = (
        pair
        for statement in statements
        for pair in find_locals__walk_assignment(statement)
    )
    for local_name, rhs in assignments:
        if local_name not in declared:
            local_type = guess_type(local_name, rhs, declared)
            declared[local_name] = local_type
            names_by_type[local_type].append(Identifier(line=-1, token=local_name))
    for local_type, names in names_by_type.items():
        declaration = Tree(
            operation="declaration",
            children=[Identifier(line=-1, token=local_type), *names],
        )
        yield Tree(operation="expression_statement", children=[declaration])
def transform_identifiers(tree, parent):
    """Return a copy of ``tree`` with identifiers renamed via the substitution map.

    Every ``Identifier`` whose token is a key of
    ``identifier_substitution.mapping`` is replaced by a new ``Identifier``
    (same line) carrying the mapped name. ``FPSCR`` is special-cased: when its
    direct parent is a ``member`` node it becomes ``state->fpscr.bits``
    instead. Other leaves are returned unchanged.

    Args:
        tree: node to transform.
        parent: the ``Tree`` that directly contains ``tree`` (``None`` at the
            root).

    Raises:
        AssertionError: if ``FPSCR`` appears with no ``Tree`` parent.
    """
    if type(tree) is Tree:
        rewritten = [transform_identifiers(subtree, tree) for subtree in tree.children]
        return Tree(operation=tree.operation, children=rewritten)
    if type(tree) is not Identifier:
        return tree

    source_name = tree.token
    substitutions = identifier_substitution.mapping
    if source_name not in substitutions:
        return tree

    replacement = substitutions[source_name]
    if source_name == 'FPSCR':
        assert type(parent) is Tree, parent
        if parent.operation == 'member':
            replacement = 'state->fpscr.bits'
    return Identifier(line=tree.line, token=replacement)
# Functions whose calls get an extra leading argument, mapped to the name of
# the identifier injected by transform_function_arguments.
require_extra_arguments = {
    **dict.fromkeys(("IsDelaySlot", "SLEEP", "OCBP"), "state"),
    **dict.fromkeys(
        (
            "WriteMemory8",
            "WriteMemory16",
            "WriteMemory32",
            "ReadMemory8",
            "ReadMemory16",
            "ReadMemory32",
            "WriteMemoryPair32",
            "ReadMemoryPair32",
        ),
        "map",
    ),
}
def transform_function_arguments(tree):
    """Return a copy of ``tree`` with extra leading arguments injected into calls.

    For every ``function_call`` whose callee name is a key of
    ``require_extra_arguments``, an ``Identifier`` named by the mapped value
    (on the callee's line) becomes the first argument:

    * a call with no arguments gets that identifier as its sole argument;
    * otherwise a new ``argument_list`` of (identifier, transformed original
      arguments) replaces the original argument child.

    All other ``Tree`` nodes are rebuilt with transformed children; non-``Tree``
    leaves are returned unchanged.

    Raises:
        AssertionError: if a call's callee is not an ``Identifier``, or if an
            augmented call has more than two children.
    """
    if type(tree) is not Tree:
        return tree

    if tree.operation == "function_call":
        callee = tree.children[0]
        assert type(callee) is Identifier
        if callee.token in require_extra_arguments:
            injected = Identifier(line=callee.line, token=require_extra_arguments[callee.token])
            if len(tree.children) == 1:
                new_arguments = injected
            else:
                assert len(tree.children) == 2, tree
                new_arguments = Tree(
                    operation='argument_list',
                    children=[injected, transform_function_arguments(tree.children[1])],
                )
            return Tree(operation=tree.operation, children=[callee, new_arguments])

    rebuilt = [transform_function_arguments(subtree) for subtree in tree.children]
    return Tree(operation=tree.operation, children=rebuilt)
def transform_statements(statements):
    """Yield local declarations followed by each transformed statement.

    First yields the declarations produced by ``transform_local_declarations``
    for every local assigned anywhere in ``statements``. Then, for each
    statement in order, yields it after applying, in sequence:
    ``transform_assignment_list``, ``transform_function_arguments`` and
    ``transform_identifiers``.

    Args:
        statements: any iterable of statement trees. It is copied into a list
            up front because it is traversed twice (once to collect locals,
            once to transform). Without the copy, a one-shot iterator such as
            a generator would be exhausted by the first pass and no statements
            would be emitted.
    """
    statements = list(statements)
    yield from transform_local_declarations(statements)
    for statement in statements:
        statement = transform_assignment_list(statement)
        statement = transform_function_arguments(statement)
        statement = transform_identifiers(statement, None)
        yield statement