assembler: implement fs emitter and frontend

This commit is contained in:
Zack Buhman 2025-10-20 18:21:41 -05:00
parent 72666a8c1f
commit efecb277c8
6 changed files with 253 additions and 15 deletions

View File

@ -0,0 +1,38 @@
import sys
from assembler.lexer import Lexer, LexerError
from assembler.fs.parser import Parser, ParserError
from assembler.fs.keywords import find_keyword
from assembler.fs.validator import validate_instruction, ValidatorError
from assembler.fs.emitter import emit_instruction
from assembler.error import print_error
def frontend_inner(buf):
    """Assemble every instruction found in ``buf`` and print its code words.

    ``buf`` is lexed with ``find_keyword`` (newline tokens disabled), parsed
    into instruction ASTs, and each AST is validated and emitted into a fresh
    list of six zero-initialised words.  The words are printed one per line
    as ``0x%08x,`` and each instruction is followed by a blank line.

    Errors raised by the lexer, parser or validator propagate to the caller.
    """
    token_list = list(Lexer(buf, find_keyword, emit_newlines=False).lex_tokens())
    for ast_node in Parser(token_list).instructions():
        validated = validate_instruction(ast_node)
        # emit_instruction fills the word list in place
        words = [0] * 6
        emit_instruction(words, validated)
        print("\n".join(f"0x{word:08x}," for word in words))
        print()
def frontend(filename, buf):
    """Run ``frontend_inner`` on ``buf``, reporting assembler errors.

    When lexing, parsing or validation fails, the error is handed to
    ``print_error`` together with ``filename`` and ``buf`` and then
    re-raised unchanged so the caller still sees the original exception.
    """
    try:
        frontend_inner(buf)
    except (LexerError, ParserError, ValidatorError) as err:
        print_error(filename, buf, err)
        raise
if __name__ == "__main__":
    # Script entry point: assemble the file named by the first argument.
    if len(sys.argv) < 2:
        # Without this guard a missing argument surfaces as a bare IndexError.
        sys.exit(f"usage: {sys.argv[0]} <input-file>")
    input_filename = sys.argv[1]
    # The source is read as bytes and passed to frontend() unmodified.
    with open(input_filename, 'rb') as f:
        buf = f.read()
    frontend(input_filename, buf)

View File

@ -0,0 +1,178 @@
from os import path
from pprint import pprint
from functools import partial
import parse_bits
from assembler.fs.validator import SrcAddrType
class BaseRegister:
    """Base class for register objects that pack field values into code words."""

    def set(self, code, value, *, code_ix, descriptor):
        """OR ``value`` into the bit field described by ``descriptor``.

        Args:
            code: mutable sequence of integer words; modified in place.
            value: integer to store in the field.
            code_ix: index into ``code`` of the word that holds the field.
            descriptor: object whose ``bits`` attribute is either a single
                bit position (an int) or a ``(high, low)`` pair of inclusive
                bit positions.

        Raises:
            AssertionError: if ``high`` is not greater than ``low``, if the
                field already holds a non-zero value, or if ``value`` does
                not fit in the field.
        """
        bits = descriptor.bits
        if isinstance(bits, int):
            # single-bit field
            low = bits
            mask = 1
        else:
            high, low = bits
            assert high > low, f"malformed bit range {bits!r}"
            mask = (1 << (high - low + 1)) - 1

        current = code[code_ix]
        # a field may only be written once per word
        assert (current >> low) & mask == 0, (
            f"bits {bits!r} of code[{code_ix}] are already set "
            f"(word is 0x{current:08x})"
        )
        assert value & mask == value, (
            f"value {value!r} does not fit in bits {bits!r}"
        )
        code[code_ix] |= (value & mask) << low
# Maps each register name to the index of the word it occupies in the
# per-instruction ``code`` list; parse_register binds this index as the
# ``code_ix`` argument of every field setter it creates.
# Several names share an index.
# NOTE(review): presumably the ALU, TEX and FC registers belong to alternative
# instruction encodings that reuse the same word slots -- confirm.
_descriptor_indicies = {
    "US_CMN_INST": 0,
    "US_ALU_RGB_ADDR": 1,
    "US_ALU_ALPHA_ADDR": 2,
    "US_ALU_RGB_INST": 3,
    "US_ALU_ALPHA_INST": 4,
    "US_ALU_RGBA_INST": 5,
    "US_TEX_INST": 1,
    "US_TEX_ADDR": 2,
    "US_TEX_ADDR_DXDY": 3,
    "US_FC_INST": 2,
    "US_FC_ADDR": 3,
}
def parse_register(register_name):
    """Build a register object from its bit-description file.

    The description is read from ``../../bits/<register_name lowercased>.txt``
    relative to this module.  The returned object is an instance of a new
    ``BaseRegister`` subclass named ``register_name``.  For every field
    descriptor it carries an attribute named after the field: a partial of
    ``set`` with ``code_ix`` and ``descriptor`` already bound, so a field is
    written with ``REGISTER.FIELD(code, value)``.  Each named possible value
    of a field is attached to that setter as an attribute holding the
    numeric value.  The descriptor list is stored as ``descriptors``.

    Raises:
        KeyError: if ``register_name`` has no entry in ``_descriptor_indicies``.
        AssertionError: if a field is itself named ``descriptors``.
    """
    bits_path = path.join(path.dirname(__file__), "..", "..", "bits",
                          register_name.lower() + ".txt")
    raw_fields = list(parse_bits.parse_file_fields(bits_path))
    register = type(register_name, (BaseRegister,), {})()
    descriptors = list(parse_bits.aggregate(raw_fields))
    word_index = _descriptor_indicies[register_name]

    for descriptor in descriptors:
        setter = partial(register.set, code_ix=word_index, descriptor=descriptor)
        setattr(register, descriptor.field_name, setter)
        for pv_value, (pv_name, _) in descriptor.possible_values.items():
            if pv_name is None:
                continue
            setattr(setter, pv_name, pv_value)

    # a field called "descriptors" would collide with the attribute set below
    assert getattr(register, "descriptors", None) is None
    register.descriptors = descriptors
    return register
# Module-level register objects, one per bit-description file.  Each
# parse_register call reads its file when this module is imported; the
# returned object exposes one setter per field, called as
# REGISTER.FIELD(code, value).
US_CMN_INST = parse_register("US_CMN_INST")
US_ALU_RGB_ADDR = parse_register("US_ALU_RGB_ADDR")
US_ALU_ALPHA_ADDR = parse_register("US_ALU_ALPHA_ADDR")
US_ALU_RGB_INST = parse_register("US_ALU_RGB_INST")
US_ALU_ALPHA_INST = parse_register("US_ALU_ALPHA_INST")
US_ALU_RGBA_INST = parse_register("US_ALU_RGBA_INST")
US_TEX_INST = parse_register("US_TEX_INST")
US_TEX_ADDR = parse_register("US_TEX_ADDR")
US_TEX_ADDR_DXDY = parse_register("US_TEX_ADDR_DXDY")
US_FC_INST = parse_register("US_FC_INST")
US_FC_ADDR = parse_register("US_FC_ADDR")
def emit_alpha_op(code, alpha_op):
    """Write the alpha operation's fields into ``code``.

    Emits the destination masks (``ALPHA_WMASK`` / ``ALPHA_OMASK``), the
    opcode (``ALPHA_OP``) and, for each entry of ``alpha_op.sels`` in order,
    the source select, swizzle and modifier fields of operand A, B and C.
    Every alpha sel must carry exactly one swizzle component.
    """
    US_CMN_INST.ALPHA_WMASK(code, alpha_op.dest.wmask.value)
    US_CMN_INST.ALPHA_OMASK(code, alpha_op.dest.omask.value)
    US_ALU_ALPHA_INST.ALPHA_OP(code, alpha_op.opcode.value)

    # One row per operand (A, B, C): (source setter, swizzle setters, mod setter).
    # Operand C's fields are taken from US_ALU_RGBA_INST.
    operand_fields = (
        (US_ALU_ALPHA_INST.ALPHA_SEL_A,
         (US_ALU_ALPHA_INST.ALPHA_SWIZ_A,),
         US_ALU_ALPHA_INST.ALPHA_MOD_A),
        (US_ALU_ALPHA_INST.ALPHA_SEL_B,
         (US_ALU_ALPHA_INST.ALPHA_SWIZ_B,),
         US_ALU_ALPHA_INST.ALPHA_MOD_B),
        (US_ALU_RGBA_INST.ALPHA_SEL_C,
         (US_ALU_RGBA_INST.ALPHA_SWIZ_C,),
         US_ALU_RGBA_INST.ALPHA_MOD_C),
    )
    # zip stops at the shorter sequence, so missing operands are left as zero
    for sel, (emit_src, emit_swizzles, emit_mod) in zip(alpha_op.sels, operand_fields):
        emit_src(code, sel.src.value)
        assert len(sel.swizzle) == 1
        assert len(emit_swizzles) == 1
        for emit_swizzle, component in zip(emit_swizzles, sel.swizzle):
            emit_swizzle(code, component.value)
        emit_mod(code, sel.mod.value)
def emit_rgb_op(code, rgb_op):
    """Write the RGB operation's fields into ``code``.

    Emits the destination masks (``RGB_WMASK`` / ``RGB_OMASK``), the opcode
    (``RGB_OP``) and, for each entry of ``rgb_op.sels`` in order, the source
    select, the red/green/blue swizzles and the modifier of operand A, B
    and C.  Every RGB sel must carry exactly three swizzle components.
    """
    # dest
    US_CMN_INST.RGB_WMASK(code, rgb_op.dest.wmask.value)
    US_CMN_INST.RGB_OMASK(code, rgb_op.dest.omask.value)
    # opcode
    US_ALU_RGBA_INST.RGB_OP(code, rgb_op.opcode.value)
    # sels: operand C's fields are taken from US_ALU_RGBA_INST
    srcs = [
        US_ALU_RGB_INST.RGB_SEL_A,
        US_ALU_RGB_INST.RGB_SEL_B,
        US_ALU_RGBA_INST.RGB_SEL_C,
    ]
    swizzles = [
        [US_ALU_RGB_INST.RED_SWIZ_A, US_ALU_RGB_INST.GREEN_SWIZ_A, US_ALU_RGB_INST.BLUE_SWIZ_A],
        [US_ALU_RGB_INST.RED_SWIZ_B, US_ALU_RGB_INST.GREEN_SWIZ_B, US_ALU_RGB_INST.BLUE_SWIZ_B],
        [US_ALU_RGBA_INST.RED_SWIZ_C, US_ALU_RGBA_INST.GREEN_SWIZ_C, US_ALU_RGBA_INST.BLUE_SWIZ_C],
    ]
    mods = [
        US_ALU_RGB_INST.RGB_MOD_A,
        US_ALU_RGB_INST.RGB_MOD_B,
        US_ALU_RGBA_INST.RGB_MOD_C,
    ]
    # zip stops at the shortest sequence, so missing operands are left as zero
    for sel, src_func, swizzle_funcs, mod_func in zip(rgb_op.sels,
                                                      srcs, swizzles, mods):
        src_func(code, sel.src.value)
        assert len(sel.swizzle) == 3
        assert len(swizzle_funcs) == 3
        for swizzle_func, swizzle in zip(swizzle_funcs, sel.swizzle):
            swizzle_func(code, swizzle.value)
        # Pass the numeric value, as emit_alpha_op does: the setter performs
        # bit arithmetic on its argument, which a raw enum member need not support.
        mod_func(code, sel.mod.value)
def _emit_src_addrs(code, register, srcs):
    """Emit the ADDRn, ADDRn_CONST and SRCP_OP fields of one address register.

    ``srcs`` supplies ``src0``..``src2`` and ``srcp``; entries that are None
    are skipped and their fields are left untouched.
    """
    slots = (
        ("src0", "ADDR0", "ADDR0_CONST"),
        ("src1", "ADDR1", "ADDR1_CONST"),
        ("src2", "ADDR2", "ADDR2_CONST"),
    )
    for src_name, addr_field, const_field in slots:
        src = getattr(srcs, src_name)
        if src is None:
            continue
        is_const = int(src.type is SrcAddrType.const)
        is_float = int(src.type is SrcAddrType.float)
        # float sources set bit 7 of the address value
        getattr(register, addr_field)(code, (is_float << 7) | src.value)
        getattr(register, const_field)(code, is_const)
    if srcs.srcp is not None:
        register.SRCP_OP(code, srcs.srcp.value)


def emit_addr(code, addr):
    """Write the source-address fields of ``addr`` into ``code``.

    The alpha sources (``addr.alpha``) go to ``US_ALU_ALPHA_ADDR`` first,
    followed by the RGB sources (``addr.rgb``) to ``US_ALU_RGB_ADDR``.
    """
    _emit_src_addrs(code, US_ALU_ALPHA_ADDR, addr.alpha)
    _emit_src_addrs(code, US_ALU_RGB_ADDR, addr.rgb)
def emit_instruction(code, ins):
    """Emit one validated instruction into ``code``, modifying it in place.

    Writes the source addresses (``ins.addr``), then the alpha operation
    (``ins.alpha_op``), then the RGB operation (``ins.rgb_op``).
    """
    emit_addr(code, ins.addr)
    emit_alpha_op(code, ins.alpha_op)
    emit_rgb_op(code, ins.rgb_op)

View File

@ -181,6 +181,10 @@ class Parser(BaseParser):
operations, operations,
) )
def instructions(self):
    """Yield the result of ``self.instruction()`` repeatedly.

    Iteration stops once ``self.match(TT.eof)`` is truthy.
    """
    while not self.match(TT.eof):
        yield self.instruction()
if __name__ == "__main__": if __name__ == "__main__":
from assembler.lexer import Lexer from assembler.lexer import Lexer
buf = b""" buf = b"""

View File

@ -181,14 +181,26 @@ def validate_instruction_let_expressions(let_expressions):
(KW.FLOAT, SrcAddrType.float), (KW.FLOAT, SrcAddrType.float),
]) ])
src_addr_type_strs = keywords_to_string(keyword_to_src_addr_type.keys()) src_addr_type_strs = keywords_to_string(keyword_to_src_addr_type.keys())
type = expr.addr_keyword.keyword type_kw = expr.addr_keyword.keyword
if type not in keyword_to_src_addr_type: if type_kw not in keyword_to_src_addr_type:
raise ValidatorError(f"invalid src addr type, expected one of {src_addr_type_strs}", expr.addr_keyword) raise ValidatorError(f"invalid src addr type, expected one of {src_addr_type_strs}", expr.addr_keyword)
type = keyword_to_src_addr_type[type_kw]
value = validate_identifier_number(expr.addr_value_identifier) value = validate_identifier_number(expr.addr_value_identifier)
if type is SrcAddrType.float:
if value >= 128:
raise ValidatorError(f"invalid float value", expr.addr_value_identifier)
elif type is SrcAddrType.temp:
if value >= 128:
raise ValidatorError(f"invalid temp value", expr.addr_value_identifier)
elif type is SrcAddrType.const:
if value >= 256:
raise ValidatorError(f"invalid const value", expr.addr_value_identifier)
else:
assert False, (id(type), id(SrcAddrType.float))
return SrcAddr( return SrcAddr(
keyword_to_src_addr_type[type], type,
value, value,
) )
elif src == KW.SRCP: elif src == KW.SRCP:
@ -403,7 +415,7 @@ swizzle_kws = OrderedDict([
(ord("_"), Swizzle.unused), (ord("_"), Swizzle.unused),
]) ])
def validate_instruction_operation_sels(swizzle_sels): def validate_instruction_operation_sels(swizzle_sels, is_alpha):
if len(swizzle_sels) > 3: if len(swizzle_sels) > 3:
raise ValidatorError("too many swizzle sels", swizzle_sels[-1].sel_keyword) raise ValidatorError("too many swizzle sels", swizzle_sels[-1].sel_keyword)
@ -414,10 +426,11 @@ def validate_instruction_operation_sels(swizzle_sels):
src = swizzle_sel_src_kws[swizzle_sel.sel_keyword.keyword] src = swizzle_sel_src_kws[swizzle_sel.sel_keyword.keyword]
swizzle_lexeme = swizzle_sel.swizzle_identifier.lexeme.lower() swizzle_lexeme = swizzle_sel.swizzle_identifier.lexeme.lower()
if len(swizzle_lexeme) > 4: swizzles_length = 1 if is_alpha else 3
raise ValidatorError("invalid swizzle", swizzle_sel.swizzle_identifier) if len(swizzle_lexeme) != swizzles_length:
raise ValidatorError("invalid swizzle length", swizzle_sel.swizzle_identifier)
if not all(c in swizzle_kws for c in swizzle_lexeme): if not all(c in swizzle_kws for c in swizzle_lexeme):
raise ValidatorError("invalid swizzle", swizzle_sel.swizzle_identifier) raise ValidatorError("invalid swizzle characters", swizzle_sel.swizzle_identifier)
swizzle = [ swizzle = [
swizzle_kws[c] for c in swizzle_lexeme swizzle_kws[c] for c in swizzle_lexeme
] ]
@ -431,7 +444,7 @@ def validate_alpha_instruction_operation(operation):
mask_lookup=alpha_masks, mask_lookup=alpha_masks,
type_cls=AlphaDest) type_cls=AlphaDest)
opcode = alpha_op_kws[operation.opcode_keyword.keyword] opcode = alpha_op_kws[operation.opcode_keyword.keyword]
sels = validate_instruction_operation_sels(operation.swizzle_sels) sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=True)
return AlphaOperation( return AlphaOperation(
dest, dest,
opcode, opcode,
@ -443,7 +456,7 @@ def validate_rgb_instruction_operation(operation):
mask_lookup=rgb_masks, mask_lookup=rgb_masks,
type_cls=RGBDest) type_cls=RGBDest)
opcode = rgb_op_kws[operation.opcode_keyword.keyword] opcode = rgb_op_kws[operation.opcode_keyword.keyword]
sels = validate_instruction_operation_sels(operation.swizzle_sels) sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=False)
return RGBOperation( return RGBOperation(
dest, dest,
opcode, opcode,
@ -472,9 +485,10 @@ def validate_instruction(ins):
return instruction return instruction
if __name__ == "__main__": if __name__ == "__main__":
from assembler.lexer import Lexer, LexerError
from assembler.fs.parser import Parser, ParserError from assembler.fs.parser import Parser, ParserError
from assembler.fs.keywords import find_keyword from assembler.fs.keywords import find_keyword
from assembler.lexer import Lexer
buf = b""" buf = b"""
src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg : src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg :
out[0].none = temp[0].none = MAD src0.r src0.r src0.r , out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
@ -484,8 +498,11 @@ src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg :
tokens = list(lexer.lex_tokens()) tokens = list(lexer.lex_tokens())
parser = Parser(tokens) parser = Parser(tokens)
try: try:
ins = parser.instruction() ins_ast = parser.instruction()
pprint(validate_instruction(ins)) pprint(validate_instruction(ins_ast))
except LexerError as e:
print_error(None, buf, e)
raise
except ParserError as e: except ParserError as e:
print_error(None, buf, e) print_error(None, buf, e)
raise raise

View File

@ -5,6 +5,7 @@ from assembler.vs.keywords import find_keyword
from assembler.vs.parser import Parser, ParserError from assembler.vs.parser import Parser, ParserError
from assembler.vs.emitter import emit_instruction from assembler.vs.emitter import emit_instruction
from assembler.vs.validator import validate_instruction from assembler.vs.validator import validate_instruction
from assembler.error import print_error
sample = b""" sample = b"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000

View File

@ -5,7 +5,7 @@ from collections import OrderedDict
from functools import partial from functools import partial
from pprint import pprint from pprint import pprint
VERBOSE = True VERBOSE = False
class BaseRegister: class BaseRegister:
def get(self, code, *, code_ix, descriptor): def get(self, code, *, code_ix, descriptor):
@ -323,10 +323,10 @@ def disassemble_alu(code, is_output):
print(", ".join([*a_addr_strs, *rgb_addr_strs]), ":") print(", ".join([*a_addr_strs, *rgb_addr_strs]), ":")
#print(", ".join(a_addr_strs), ":") #print(", ".join(a_addr_strs), ":")
print(f" {a_out_str} = {a_temp_str} = {a_op.ljust(6)} {' '.join(a_swizzle_sel)}", ",") print(f" {a_out_str} = {a_temp_str} = {a_op.removeprefix('OP_').ljust(3)} {' '.join(a_swizzle_sel)}", ",")
#print(", ".join(rgb_addr_strs), ":") #print(", ".join(rgb_addr_strs), ":")
print(f" {rgb_out_str} = {rgb_temp_str} = {rgb_op.ljust(6)} {' '.join(rgb_swizzle_sel)}", ";") print(f" {rgb_out_str} = {rgb_temp_str} = {rgb_op.removeprefix('OP_').ljust(3)} {' '.join(rgb_swizzle_sel)}", ";")
def disassemble(code): def disassemble(code):
assert len(code) == 6, len(code) assert len(code) == 6, len(code)