From efecb277c84200e6ae92f754247471e6e44b2bfe Mon Sep 17 00:00:00 2001 From: Zack Buhman Date: Mon, 20 Oct 2025 18:21:41 -0500 Subject: [PATCH] assembler: implement fs emitter and frontend --- regs/assembler/fs/__main__.py | 38 +++++++ regs/assembler/fs/emitter.py | 178 +++++++++++++++++++++++++++++++++ regs/assembler/fs/parser.py | 4 + regs/assembler/fs/validator.py | 41 +++++--- regs/assembler/vs/__main__.py | 1 + regs/us_disassemble2.py | 6 +- 6 files changed, 253 insertions(+), 15 deletions(-) create mode 100644 regs/assembler/fs/__main__.py create mode 100644 regs/assembler/fs/emitter.py diff --git a/regs/assembler/fs/__main__.py b/regs/assembler/fs/__main__.py new file mode 100644 index 0000000..60284ae --- /dev/null +++ b/regs/assembler/fs/__main__.py @@ -0,0 +1,38 @@ +import sys + +from assembler.lexer import Lexer, LexerError +from assembler.fs.parser import Parser, ParserError +from assembler.fs.keywords import find_keyword +from assembler.fs.validator import validate_instruction, ValidatorError +from assembler.fs.emitter import emit_instruction +from assembler.error import print_error + +def frontend_inner(buf): + lexer = Lexer(buf, find_keyword, emit_newlines=False) + tokens = list(lexer.lex_tokens()) + parser = Parser(tokens) + for ins_ast in parser.instructions(): + ins = validate_instruction(ins_ast) + code = [0] * 6 + emit_instruction(code, ins) + print("\n".join(f"0x{code[i]:08x}," for i in range(6))) + print() + +def frontend(filename, buf): + try: + frontend_inner(buf) + except LexerError as e: + print_error(filename, buf, e) + raise + except ParserError as e: + print_error(filename, buf, e) + raise + except ValidatorError as e: + print_error(filename, buf, e) + raise + +if __name__ == "__main__": + input_filename = sys.argv[1] + with open(input_filename, 'rb') as f: + buf = f.read() + frontend(input_filename, buf) diff --git a/regs/assembler/fs/emitter.py b/regs/assembler/fs/emitter.py new file mode 100644 index 0000000..47d81a3 --- /dev/null +++ b/regs/assembler/fs/emitter.py @@ -0,0 +1,178 @@ +from os import path +from pprint import pprint +from functools import partial + +import parse_bits +from assembler.fs.validator import SrcAddrType + +class BaseRegister: + def set(self, code, value, *, code_ix, descriptor): + if type(descriptor.bits) is int: + mask = 1 + low = descriptor.bits + else: + high, low = descriptor.bits + assert high > low + mask_length = (high - low) + 1 + mask = (1 << mask_length) - 1 + + code_value = code[code_ix] + assert (code_value >> low) & mask == 0 + assert value & mask == value + code[code_ix] |= (value & mask) << low + +_descriptor_indicies = { + "US_CMN_INST": 0, + "US_ALU_RGB_ADDR": 1, + "US_ALU_ALPHA_ADDR": 2, + "US_ALU_RGB_INST": 3, + "US_ALU_ALPHA_INST": 4, + "US_ALU_RGBA_INST": 5, + + "US_TEX_INST": 1, + "US_TEX_ADDR": 2, + "US_TEX_ADDR_DXDY": 3, + + "US_FC_INST": 2, + "US_FC_ADDR": 3, +} + +def parse_register(register_name): + base = path.dirname(__file__) + + filename = path.join(base, "..", "..", "bits", register_name.lower() + ".txt") + l = list(parse_bits.parse_file_fields(filename)) + cls = type(register_name, (BaseRegister,), {}) + instance = cls() + descriptors = list(parse_bits.aggregate(l)) + code_ix = _descriptor_indicies[register_name] + for descriptor in descriptors: + setattr(instance, descriptor.field_name, + partial(instance.set, code_ix=code_ix, descriptor=descriptor)) + func = getattr(instance, descriptor.field_name) + for pv_value, (pv_name, _) in descriptor.possible_values.items(): + if pv_name is not None: + setattr(func, pv_name, pv_value) + assert getattr(instance, "descriptors", None) is None + instance.descriptors = descriptors + + return instance + +US_CMN_INST = parse_register("US_CMN_INST") +US_ALU_RGB_ADDR = parse_register("US_ALU_RGB_ADDR") +US_ALU_ALPHA_ADDR = parse_register("US_ALU_ALPHA_ADDR") +US_ALU_RGB_INST = parse_register("US_ALU_RGB_INST") +US_ALU_ALPHA_INST = parse_register("US_ALU_ALPHA_INST") +US_ALU_RGBA_INST = parse_register("US_ALU_RGBA_INST") +US_TEX_INST = parse_register("US_TEX_INST") +US_TEX_ADDR = parse_register("US_TEX_ADDR") +US_TEX_ADDR_DXDY = parse_register("US_TEX_ADDR_DXDY") +US_FC_INST = parse_register("US_FC_INST") +US_FC_ADDR = parse_register("US_FC_ADDR") + +def emit_alpha_op(code, alpha_op): + # dest + US_CMN_INST.ALPHA_WMASK(code, alpha_op.dest.wmask.value) + US_CMN_INST.ALPHA_OMASK(code, alpha_op.dest.omask.value) + + # opcode + US_ALU_ALPHA_INST.ALPHA_OP(code, alpha_op.opcode.value) + + # sels + srcs = [ + US_ALU_ALPHA_INST.ALPHA_SEL_A, + US_ALU_ALPHA_INST.ALPHA_SEL_B, + US_ALU_RGBA_INST.ALPHA_SEL_C, + ] + swizzles = [ + [US_ALU_ALPHA_INST.ALPHA_SWIZ_A], + [US_ALU_ALPHA_INST.ALPHA_SWIZ_B], + [US_ALU_RGBA_INST.ALPHA_SWIZ_C], + ] + mods = [ + US_ALU_ALPHA_INST.ALPHA_MOD_A, + US_ALU_ALPHA_INST.ALPHA_MOD_B, + US_ALU_RGBA_INST.ALPHA_MOD_C, + ] + for sel, src_func, swizzle_funcs, mod_func in zip(alpha_op.sels, + srcs, swizzles, mods): + src_func(code, sel.src.value) + assert len(sel.swizzle) == 1 + assert len(swizzle_funcs) == 1 + for swizzle_func, swizzle in zip(swizzle_funcs, sel.swizzle): + swizzle_func(code, swizzle.value) + mod_func(code, sel.mod.value) + +def emit_rgb_op(code, rgb_op): + # dest + US_CMN_INST.RGB_WMASK(code, rgb_op.dest.wmask.value) + US_CMN_INST.RGB_OMASK(code, rgb_op.dest.omask.value) + + # opcode + US_ALU_RGBA_INST.RGB_OP(code, rgb_op.opcode.value) + + # sels + srcs = [ + US_ALU_RGB_INST.RGB_SEL_A, + US_ALU_RGB_INST.RGB_SEL_B, + US_ALU_RGBA_INST.RGB_SEL_C, + ] + swizzles = [ + [US_ALU_RGB_INST.RED_SWIZ_A, US_ALU_RGB_INST.GREEN_SWIZ_A, US_ALU_RGB_INST.BLUE_SWIZ_A], + [US_ALU_RGB_INST.RED_SWIZ_B, US_ALU_RGB_INST.GREEN_SWIZ_B, US_ALU_RGB_INST.BLUE_SWIZ_B], + [US_ALU_RGBA_INST.RED_SWIZ_C, US_ALU_RGBA_INST.GREEN_SWIZ_C, US_ALU_RGBA_INST.BLUE_SWIZ_C], + ] + mods = [ + US_ALU_RGB_INST.RGB_MOD_A, + US_ALU_RGB_INST.RGB_MOD_B, + US_ALU_RGBA_INST.RGB_MOD_C, + ] + for sel, src_func, swizzle_funcs, mod_func in zip(rgb_op.sels, + srcs, swizzles, mods): + src_func(code, sel.src.value) + assert len(sel.swizzle) == 3 + assert len(swizzle_funcs) == 3 + for swizzle_func, swizzle in zip(swizzle_funcs, sel.swizzle): + swizzle_func(code, swizzle.value) + mod_func(code, sel.mod) + +def emit_addr(code, addr): + if addr.alpha.src0 is not None: + is_const = int(addr.alpha.src0.type is SrcAddrType.const) + is_float = int(addr.alpha.src0.type is SrcAddrType.float) + US_ALU_ALPHA_ADDR.ADDR0(code, (is_float << 7) | addr.alpha.src0.value) + US_ALU_ALPHA_ADDR.ADDR0_CONST(code, is_const) + if addr.alpha.src1 is not None: + is_const = int(addr.alpha.src1.type is SrcAddrType.const) + is_float = int(addr.alpha.src1.type is SrcAddrType.float) + US_ALU_ALPHA_ADDR.ADDR1(code, (is_float << 7) | addr.alpha.src1.value) + US_ALU_ALPHA_ADDR.ADDR1_CONST(code, is_const) + if addr.alpha.src2 is not None: + is_const = int(addr.alpha.src2.type is SrcAddrType.const) + is_float = int(addr.alpha.src2.type is SrcAddrType.float) + US_ALU_ALPHA_ADDR.ADDR2(code, (is_float << 7) | addr.alpha.src2.value) + US_ALU_ALPHA_ADDR.ADDR2_CONST(code, is_const) + if addr.alpha.srcp is not None: + US_ALU_ALPHA_ADDR.SRCP_OP(code, addr.alpha.srcp.value) + if addr.rgb.src0 is not None: + is_const = int(addr.rgb.src0.type is SrcAddrType.const) + is_float = int(addr.rgb.src0.type is SrcAddrType.float) + US_ALU_RGB_ADDR.ADDR0(code, (is_float << 7) | addr.rgb.src0.value) + US_ALU_RGB_ADDR.ADDR0_CONST(code, is_const) + if addr.rgb.src1 is not None: + is_const = int(addr.rgb.src1.type is SrcAddrType.const) + is_float = int(addr.rgb.src1.type is SrcAddrType.float) + US_ALU_RGB_ADDR.ADDR1(code, (is_float << 7) | addr.rgb.src1.value) + US_ALU_RGB_ADDR.ADDR1_CONST(code, is_const) + if addr.rgb.src2 is not None: + is_const = int(addr.rgb.src2.type is SrcAddrType.const) + is_float = int(addr.rgb.src2.type is SrcAddrType.float) + US_ALU_RGB_ADDR.ADDR2(code, (is_float << 7) | addr.rgb.src2.value) + US_ALU_RGB_ADDR.ADDR2_CONST(code, is_const) + if addr.rgb.srcp is not None: + US_ALU_RGB_ADDR.SRCP_OP(code, addr.rgb.srcp.value) + +def emit_instruction(code, ins): + emit_addr(code, ins.addr) + emit_alpha_op(code, ins.alpha_op) + emit_rgb_op(code, ins.rgb_op) diff --git a/regs/assembler/fs/parser.py b/regs/assembler/fs/parser.py index 6de8405..7532785 100644 --- a/regs/assembler/fs/parser.py +++ b/regs/assembler/fs/parser.py @@ -181,6 +181,10 @@ class Parser(BaseParser): operations, ) + def instructions(self): + while not self.match(TT.eof): + yield self.instruction() + if __name__ == "__main__": from assembler.lexer import Lexer buf = b""" diff --git a/regs/assembler/fs/validator.py b/regs/assembler/fs/validator.py index f71c6a5..646a950 100644 --- a/regs/assembler/fs/validator.py +++ b/regs/assembler/fs/validator.py @@ -181,14 +181,26 @@ def validate_instruction_let_expressions(let_expressions): (KW.FLOAT, SrcAddrType.float), ]) src_addr_type_strs = keywords_to_string(keyword_to_src_addr_type.keys()) - type = expr.addr_keyword.keyword - if type not in keyword_to_src_addr_type: + type_kw = expr.addr_keyword.keyword + if type_kw not in keyword_to_src_addr_type: raise ValidatorError(f"invalid src addr type, expected one of {src_addr_type_strs}", expr.addr_keyword) + type = keyword_to_src_addr_type[type_kw] value = validate_identifier_number(expr.addr_value_identifier) + if type is SrcAddrType.float: + if value >= 128: + raise ValidatorError(f"invalid float value", expr.addr_value_identifier) + elif type is SrcAddrType.temp: + if value >= 128: + raise ValidatorError(f"invalid temp value", expr.addr_value_identifier) + elif type is SrcAddrType.const: + if value >= 256: + raise ValidatorError(f"invalid const value", expr.addr_value_identifier) + else: + assert False, (id(type), id(SrcAddrType.float)) return SrcAddr( - keyword_to_src_addr_type[type], + type, value, ) elif src == KW.SRCP: @@ -403,7 +415,7 @@ swizzle_kws = OrderedDict([ (ord("_"), Swizzle.unused), ]) -def validate_instruction_operation_sels(swizzle_sels): +def validate_instruction_operation_sels(swizzle_sels, is_alpha): if len(swizzle_sels) > 3: raise ValidatorError("too many swizzle sels", swizzle_sels[-1].sel_keyword) @@ -414,10 +426,11 @@ def validate_instruction_operation_sels(swizzle_sels): src = swizzle_sel_src_kws[swizzle_sel.sel_keyword.keyword] swizzle_lexeme = swizzle_sel.swizzle_identifier.lexeme.lower() - if len(swizzle_lexeme) > 4: - raise ValidatorError("invalid swizzle", swizzle_sel.swizzle_identifier) + swizzles_length = 1 if is_alpha else 3 + if len(swizzle_lexeme) != swizzles_length: + raise ValidatorError("invalid swizzle length", swizzle_sel.swizzle_identifier) if not all(c in swizzle_kws for c in swizzle_lexeme): - raise ValidatorError("invalid swizzle", swizzle_sel.swizzle_identifier) + raise ValidatorError("invalid swizzle characters", swizzle_sel.swizzle_identifier) swizzle = [ swizzle_kws[c] for c in swizzle_lexeme ] @@ -431,7 +444,7 @@ def validate_alpha_instruction_operation(operation): mask_lookup=alpha_masks, type_cls=AlphaDest) opcode = alpha_op_kws[operation.opcode_keyword.keyword] - sels = validate_instruction_operation_sels(operation.swizzle_sels) + sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=True) return AlphaOperation( dest, opcode, @@ -443,7 +456,7 @@ def validate_rgb_instruction_operation(operation): mask_lookup=rgb_masks, type_cls=RGBDest) opcode = rgb_op_kws[operation.opcode_keyword.keyword] - sels = validate_instruction_operation_sels(operation.swizzle_sels) + sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=False) return RGBOperation( dest, opcode, @@ -472,9 +485,10 @@ def validate_instruction(ins): return instruction if __name__ == "__main__": + from assembler.lexer import Lexer, LexerError from assembler.fs.parser import Parser, ParserError from assembler.fs.keywords import find_keyword - from assembler.lexer import Lexer + buf = b""" src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg : out[0].none = temp[0].none = MAD src0.r src0.r src0.r , @@ -484,8 +498,11 @@ src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg : tokens = list(lexer.lex_tokens()) parser = Parser(tokens) try: - ins = parser.instruction() - pprint(validate_instruction(ins)) + ins_ast = parser.instruction() + pprint(validate_instruction(ins_ast)) + except LexerError as e: + print_error(None, buf, e) + raise except ParserError as e: print_error(None, buf, e) raise diff --git a/regs/assembler/vs/__main__.py b/regs/assembler/vs/__main__.py index fb0bb99..4d5dab5 100644 --- a/regs/assembler/vs/__main__.py +++ b/regs/assembler/vs/__main__.py @@ -5,6 +5,7 @@ from assembler.vs.keywords import find_keyword from assembler.vs.parser import Parser, ParserError from assembler.vs.emitter import emit_instruction from assembler.vs.validator import validate_instruction +from assembler.error import print_error sample = b""" temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 diff --git a/regs/us_disassemble2.py b/regs/us_disassemble2.py index 6d8482a..940c84e 100644 --- a/regs/us_disassemble2.py +++ b/regs/us_disassemble2.py @@ -5,7 +5,7 @@ from collections import OrderedDict from functools import partial from pprint import pprint -VERBOSE = True +VERBOSE = False class BaseRegister: def get(self, code, *, code_ix, descriptor): @@ -323,10 +323,10 @@ def disassemble_alu(code, is_output): print(", ".join([*a_addr_strs, *rgb_addr_strs]), ":") #print(", ".join(a_addr_strs), ":") - print(f" {a_out_str} = {a_temp_str} = {a_op.ljust(6)} {' '.join(a_swizzle_sel)}", ",") + print(f" {a_out_str} = {a_temp_str} = {a_op.removeprefix('OP_').ljust(3)} {' '.join(a_swizzle_sel)}", ",") #print(", ".join(rgb_addr_strs), ":") - print(f" {rgb_out_str} = {rgb_temp_str} = {rgb_op.ljust(6)} {' '.join(rgb_swizzle_sel)}", ";") + print(f" {rgb_out_str} = {rgb_temp_str} = {rgb_op.removeprefix('OP_').ljust(3)} {' '.join(rgb_swizzle_sel)}", ";") def disassemble(code): assert len(code) == 6, len(code)