diff --git a/regs/assembler/fs/__main__.py b/regs/assembler/fs/__main__.py index 048fb9f..9dd9519 100644 --- a/regs/assembler/fs/__main__.py +++ b/regs/assembler/fs/__main__.py @@ -1,9 +1,11 @@ import sys from assembler.lexer import Lexer, LexerError -from assembler.fs.parser import Parser, ParserError +from assembler.parser import ParserError +from assembler.validator import ValidatorError +from assembler.fs.parser import Parser from assembler.fs.keywords import find_keyword -from assembler.fs.validator import validate_instruction, ValidatorError +from assembler.fs.validator import validate_instruction from assembler.fs.emitter import emit_instruction from assembler.error import print_error diff --git a/regs/assembler/fs/parser.py b/regs/assembler/fs/parser.py index 49a2745..13027e5 100644 --- a/regs/assembler/fs/parser.py +++ b/regs/assembler/fs/parser.py @@ -124,7 +124,7 @@ class Parser(BaseParser): swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier") if abs: - self.consume(TT.bar, "expected bar") + self.consume(TT.bar, "expected vertical bar") mod_table = { # (neg, abs) diff --git a/regs/assembler/fs/validator.py b/regs/assembler/fs/validator.py index cf8690e..fdcd18e 100644 --- a/regs/assembler/fs/validator.py +++ b/regs/assembler/fs/validator.py @@ -6,9 +6,7 @@ from collections import OrderedDict from assembler.fs.parser import Mod from assembler.fs.keywords import _keyword_to_string, KW from assembler.error import print_error - -class ValidatorError(Exception): - pass +from assembler.validator import ValidatorError class SrcAddrType(Enum): temp = auto() diff --git a/regs/assembler/validator.py b/regs/assembler/validator.py new file mode 100644 index 0000000..79ce89a --- /dev/null +++ b/regs/assembler/validator.py @@ -0,0 +1,2 @@ +class ValidatorError(Exception): + pass diff --git a/regs/assembler/vs/__main__.py b/regs/assembler/vs/__main__.py index 4d5dab5..3847b51 100644 --- a/regs/assembler/vs/__main__.py +++ 
b/regs/assembler/vs/__main__.py @@ -24,9 +24,9 @@ def frontend_inner(buf): lexer = Lexer(buf, find_keyword) tokens = list(lexer.lex_tokens()) parser = Parser(tokens) - for ins, start_end in parser.instructions(): + for ins in parser.instructions(): ins = validate_instruction(ins) - yield list(emit_instruction(ins)), start_end + yield list(emit_instruction(ins)) def frontend(filename, buf): try: @@ -43,10 +43,5 @@ if __name__ == "__main__": #output_filename = sys.argv[2] with open(input_filename, 'rb') as f: buf = f.read() - output = list(frontend(input_filename, buf)) - for cw, (start_ix, end_ix) in output: - if True: - print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},") - else: - source = buf[start_ix:end_ix] - print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}") + for cw in frontend(input_filename, buf): + print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},") diff --git a/regs/assembler/vs/emitter.py b/regs/assembler/vs/emitter.py index b9598d5..2be1ebf 100644 --- a/regs/assembler/vs/emitter.py +++ b/regs/assembler/vs/emitter.py @@ -1,92 +1,59 @@ -from assembler.vs.keywords import ME, VE, MVE, KW -from assembler.vs.parser import Instruction, DestinationOp, Source +from typing import Union + +from assembler.vs.opcodes import ME, VE, MVE +from assembler.vs.validator import Destination, Source, Instruction + import pvs_dst import pvs_src -import pvs_dst_bits -import pvs_src_bits -def we_x(s): - return int(0 in s) +def emit_destination_opcode(destination: Destination, + opcode: Union[ME, VE, MVE], + saturation: bool): + assert type(opcode) in {ME, VE, MVE} + math_inst = int(type(opcode) is ME) + macro_inst = int(type(opcode) is MVE) -def we_y(s): - return int(1 in s) - -def we_z(s): - return int(2 in s) - -def we_w(s): - return int(3 in s) - -def dst_reg_type(kw): - if kw == KW.temporary: - return pvs_dst_bits.PVS_DST_REG_gen["TEMPORARY"] - elif kw == KW.a0: - return 
pvs_dst_bits.PVS_DST_REG_gen["A0"] - elif kw == KW.out: - return pvs_dst_bits.PVS_DST_REG_gen["OUT"] - elif kw == KW.out_repl_x: - return pvs_dst_bits.PVS_DST_REG_gen["OUT_REPL_X"] - elif kw == KW.alt_temporary: - return pvs_dst_bits.PVS_DST_REG_gen["ALT_TEMPORARY"] - elif kw == KW.input: - return pvs_dst_bits.PVS_DST_REG_gen["INPUT"] - else: - assert not "Invalid PVS_DST_REG", kw - -def emit_destination_op(dst_op: DestinationOp): - assert type(dst_op.opcode) in {ME, VE, MVE} - math_inst = int(type(dst_op.opcode) is ME) - if dst_op.macro: - assert dst_op.opcode.value in {0, 1} - ve_sat = int((not math_inst) and dst_op.sat) - me_sat = int(math_inst and dst_op.sat) + ve_sat = int((not math_inst) and saturation) + me_sat = int(math_inst and saturation) value = ( - pvs_dst.OPCODE_gen(dst_op.opcode.value) + pvs_dst.OPCODE_gen(opcode.value) | pvs_dst.MATH_INST_gen(math_inst) - | pvs_dst.MACRO_INST_gen(int(dst_op.macro)) - | pvs_dst.REG_TYPE_gen(dst_reg_type(dst_op.type)) - | pvs_dst.OFFSET_gen(dst_op.offset) - | pvs_dst.WE_X_gen(we_x(dst_op.write_enable)) - | pvs_dst.WE_Y_gen(we_y(dst_op.write_enable)) - | pvs_dst.WE_Z_gen(we_z(dst_op.write_enable)) - | pvs_dst.WE_W_gen(we_w(dst_op.write_enable)) + | pvs_dst.MACRO_INST_gen(int(macro_inst)) + | pvs_dst.REG_TYPE_gen(destination.type.value) + | pvs_dst.OFFSET_gen(destination.offset) + | pvs_dst.WE_X_gen(int(destination.write_enable[0])) + | pvs_dst.WE_Y_gen(int(destination.write_enable[1])) + | pvs_dst.WE_Z_gen(int(destination.write_enable[2])) + | pvs_dst.WE_W_gen(int(destination.write_enable[3])) | pvs_dst.VE_SAT_gen(ve_sat) | pvs_dst.ME_SAT_gen(me_sat) ) yield value -def src_reg_type(kw): - if kw == KW.temporary: - return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_TEMPORARY"] - elif kw == KW.input: - return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_INPUT"] - elif kw == KW.constant: - return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_CONSTANT"] - elif kw == KW.alt_temporary: - return 
pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_ALT_TEMPORARY"] - else: - assert not "Invalid PVS_SRC_REG", kw - def emit_source(src: Source, prev: Source): if src is not None: + assert src.offset <= 255 value = ( - pvs_src.REG_TYPE_gen(src_reg_type(src.type)) + pvs_src.REG_TYPE_gen(src.type.value) | pvs_src.OFFSET_gen(src.offset) - | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0]) - | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1]) - | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2]) - | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3]) - | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0])) - | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1])) - | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2])) - | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3])) + | pvs_src.ABS_XYZW_gen(int(src.absolute)) + | pvs_src.SWIZZLE_X_gen(src.swizzle_selects[0].value) + | pvs_src.SWIZZLE_Y_gen(src.swizzle_selects[1].value) + | pvs_src.SWIZZLE_Z_gen(src.swizzle_selects[2].value) + | pvs_src.SWIZZLE_W_gen(src.swizzle_selects[3].value) + | pvs_src.MODIFIER_X_gen(int(src.modifiers[0])) + | pvs_src.MODIFIER_Y_gen(int(src.modifiers[1])) + | pvs_src.MODIFIER_Z_gen(int(src.modifiers[2])) + | pvs_src.MODIFIER_W_gen(int(src.modifiers[3])) ) else: assert prev is not None + assert prev.offset <= 255 value = ( - pvs_src.REG_TYPE_gen(src_reg_type(prev.type)) + pvs_src.REG_TYPE_gen(prev.type.value) | pvs_src.OFFSET_gen(prev.offset) + | pvs_src.ABS_XYZW_gen(0) | pvs_src.SWIZZLE_X_gen(7) | pvs_src.SWIZZLE_Y_gen(7) | pvs_src.SWIZZLE_Z_gen(7) @@ -99,22 +66,32 @@ def emit_source(src: Source, prev: Source): yield value def prev_source(ins, ix): + assert ins.sources[0] is not None if ix == 0: - assert ins.source0 is not None - return ins.source0 + return ins.sources[0] elif ix == 1: - return ins.source0 + return ins.sources[0] elif ix == 2: - if ins.source1 is not None: - return ins.source1 + if ins.sources[1] is not None: + return ins.sources[1] else: - return ins.source0 + return ins.sources[0] else: assert 
False, ix def emit_instruction(ins: Instruction): - yield from emit_destination_op(ins.destination_op) + yield from emit_destination_opcode(ins.destination, + ins.opcode, + ins.saturation) - yield from emit_source(ins.source0, prev_source(ins, 0)) - yield from emit_source(ins.source1, prev_source(ins, 1)) - yield from emit_source(ins.source2, prev_source(ins, 2)) + if len(ins.sources) >= 1: + yield from emit_source(ins.sources[0], prev_source(ins, 0)) + + source1 = ins.sources[1] if len(ins.sources) >= 2 else None + source2 = ins.sources[2] if len(ins.sources) >= 3 else None + yield from emit_source(source1, prev_source(ins, 1)) + yield from emit_source(source2, prev_source(ins, 2)) + else: + yield 0 + yield 0 + yield 0 diff --git a/regs/assembler/vs/keywords.py b/regs/assembler/vs/keywords.py index 3b3ae46..b50d05f 100644 --- a/regs/assembler/vs/keywords.py +++ b/regs/assembler/vs/keywords.py @@ -2,93 +2,68 @@ from dataclasses import dataclass from typing import Optional from enum import Enum, auto -@dataclass -class MVE: - name: str - synonym: Optional[str] - value: int +from assembler.vs import opcodes -@dataclass -class VE: - name: str - synonym: Optional[str] - value: int +operations = [ + opcodes.VECTOR_NO_OP, + opcodes.VE_DOT_PRODUCT, + opcodes.VE_MULTIPLY, + opcodes.VE_ADD, + opcodes.VE_MULTIPLY_ADD, + opcodes.VE_DISTANCE_VECTOR, + opcodes.VE_FRACTION, + opcodes.VE_MAXIMUM, + opcodes.VE_MINIMUM, + opcodes.VE_SET_GREATER_THAN_EQUAL, + opcodes.VE_SET_LESS_THAN, + opcodes.VE_MULTIPLYX2_ADD, + opcodes.VE_MULTIPLY_CLAMP, + opcodes.VE_FLT2FIX_DX, + opcodes.VE_FLT2FIX_DX_RND, + opcodes.VE_PRED_SET_EQ_PUSH, + opcodes.VE_PRED_SET_GT_PUSH, + opcodes.VE_PRED_SET_GTE_PUSH, + opcodes.VE_PRED_SET_NEQ_PUSH, + opcodes.VE_COND_WRITE_EQ, + opcodes.VE_COND_WRITE_GT, + opcodes.VE_COND_WRITE_GTE, + opcodes.VE_COND_WRITE_NEQ, + opcodes.VE_COND_MUX_EQ, + opcodes.VE_COND_MUX_GT, + opcodes.VE_COND_MUX_GTE, + opcodes.VE_SET_GREATER_THAN, + opcodes.VE_SET_EQUAL, + 
opcodes.VE_SET_NOT_EQUAL, -@dataclass -class ME: - name: str - synonym: Optional[str] - value: int - -macro_vector_operations = [ - MVE(b"MACRO_OP_2CLK_MADD" , None , 0), - MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1), -] - -vector_operations = [ - # name synonym value - VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0), - VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1), - VE(b"VE_MULTIPLY" , b"VE_MUL" , 2), - VE(b"VE_ADD" , None , 3), - VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4), - VE(b"VE_DISTANCE_VECTOR" , None , 5), - VE(b"VE_FRACTION" , b"VE_FRC" , 6), - VE(b"VE_MAXIMUM" , b"VE_MAX" , 7), - VE(b"VE_MINIMUM" , b"VE_MIN" , 8), - VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9), - VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10), - VE(b"VE_MULTIPLYX2_ADD" , None , 11), - VE(b"VE_MULTIPLY_CLAMP" , None , 12), - VE(b"VE_FLT2FIX_DX" , None , 13), - VE(b"VE_FLT2FIX_DX_RND" , None , 14), - VE(b"VE_PRED_SET_EQ_PUSH" , None , 15), - VE(b"VE_PRED_SET_GT_PUSH" , None , 16), - VE(b"VE_PRED_SET_GTE_PUSH" , None , 17), - VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18), - VE(b"VE_COND_WRITE_EQ" , None , 19), - VE(b"VE_COND_WRITE_GT" , None , 20), - VE(b"VE_COND_WRITE_GTE" , None , 21), - VE(b"VE_COND_WRITE_NEQ" , None , 22), - VE(b"VE_COND_MUX_EQ" , None , 23), - VE(b"VE_COND_MUX_GT" , None , 24), - VE(b"VE_COND_MUX_GTE" , None , 25), - VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26), - VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27), - VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28), -] - -math_operations = [ - # name synonym value - ME(b"MATH_NO_OP" , b"ME_NOP" , 0), - ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1), - ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2), - ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3), - ME(b"ME_LIGHT_COEFF_DX" , None , 4), - ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5), - ME(b"ME_RECIP_DX" , b"ME_RCP" , 6), - ME(b"ME_RECIP_FF" , None , 7), - ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8), - ME(b"ME_RECIP_SQRT_FF" , None , 9), - ME(b"ME_MULTIPLY" , b"ME_MUL" , 10), - ME(b"ME_EXP_BASE2_FULL_DX" , None , 11), - ME(b"ME_LOG_BASE2_FULL_DX" , None , 12), 
- ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13), - ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14), - ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15), - ME(b"ME_SIN" , None , 16), - ME(b"ME_COS" , None , 17), - ME(b"ME_LOG_BASE2_IEEE" , None , 18), - ME(b"ME_RECIP_IEEE" , None , 19), - ME(b"ME_RECIP_SQRT_IEEE" , None , 20), - ME(b"ME_PRED_SET_EQ" , None , 21), - ME(b"ME_PRED_SET_GT" , None , 22), - ME(b"ME_PRED_SET_GTE" , None , 23), - ME(b"ME_PRED_SET_NEQ" , None , 24), - ME(b"ME_PRED_SET_CLR" , None , 25), - ME(b"ME_PRED_SET_INV" , None , 26), - ME(b"ME_PRED_SET_POP" , None , 27), - ME(b"ME_PRED_SET_RESTORE" , None , 28), + opcodes.MATH_NO_OP, + opcodes.ME_EXP_BASE2_DX, + opcodes.ME_LOG_BASE2_DX, + opcodes.ME_EXP_BASEE_FF, + opcodes.ME_LIGHT_COEFF_DX, + opcodes.ME_POWER_FUNC_FF, + opcodes.ME_RECIP_DX, + opcodes.ME_RECIP_FF, + opcodes.ME_RECIP_SQRT_DX, + opcodes.ME_RECIP_SQRT_FF, + opcodes.ME_MULTIPLY, + opcodes.ME_EXP_BASE2_FULL_DX, + opcodes.ME_LOG_BASE2_FULL_DX, + opcodes.ME_POWER_FUNC_FF_CLAMP_B, + opcodes.ME_POWER_FUNC_FF_CLAMP_B1, + opcodes.ME_POWER_FUNC_FF_CLAMP_01, + opcodes.ME_SIN, + opcodes.ME_COS, + opcodes.ME_LOG_BASE2_IEEE, + opcodes.ME_RECIP_IEEE, + opcodes.ME_RECIP_SQRT_IEEE, + opcodes.ME_PRED_SET_EQ, + opcodes.ME_PRED_SET_GT, + opcodes.ME_PRED_SET_GTE, + opcodes.ME_PRED_SET_NEQ, + opcodes.ME_PRED_SET_CLR, + opcodes.ME_PRED_SET_INV, + opcodes.ME_PRED_SET_POP, + opcodes.ME_PRED_SET_RESTORE, ] class KW(Enum): @@ -120,12 +95,9 @@ keywords = [ def find_keyword(b: memoryview): b = bytes(b) - for vector_op in vector_operations: - if vector_op.name == b.upper() or (vector_op.synonym is not None and vector_op.synonym == b.upper()): - return vector_op - for math_op in math_operations: - if math_op.name == b.upper() or (math_op.synonym is not None and math_op.synonym == b.upper()): - return math_op + for op in operations: + if op.name == b.upper() or (op.synonym is not None and op.synonym == b.upper()): + return op for keyword, name, synonym in keywords: if name == 
b.lower() or (synonym is not None and synonym == b.lower()): return keyword diff --git a/regs/assembler/vs/opcodes.py b/regs/assembler/vs/opcodes.py new file mode 100644 index 0000000..cb85980 --- /dev/null +++ b/regs/assembler/vs/opcodes.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import Optional + +@dataclass +class MVE: + name: str + synonym: Optional[str] + value: int + operand_count: int + +@dataclass +class VE: + name: str + synonym: Optional[str] + value: int + operand_count: int + +@dataclass +class ME: + name: str + synonym: Optional[str] + value: int + operand_count: int + +MACRO_OP_2CLK_MADD = MVE(b"MACRO_OP_2CLK_MADD" , None , 0, 3) +MACRO_OP_2CLK_M2X_ADD = MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1, 3) + +VECTOR_NO_OP = VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0, 0) +VE_DOT_PRODUCT = VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1, 2) +VE_MULTIPLY = VE(b"VE_MULTIPLY" , b"VE_MUL" , 2, 2) +VE_ADD = VE(b"VE_ADD" , None , 3, 2) +VE_MULTIPLY_ADD = VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4, 3) +VE_DISTANCE_VECTOR = VE(b"VE_DISTANCE_VECTOR" , None , 5, 2) +VE_FRACTION = VE(b"VE_FRACTION" , b"VE_FRC" , 6, 1) +VE_MAXIMUM = VE(b"VE_MAXIMUM" , b"VE_MAX" , 7, 2) +VE_MINIMUM = VE(b"VE_MINIMUM" , b"VE_MIN" , 8, 2) +VE_SET_GREATER_THAN_EQUAL = VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9, 2) +VE_SET_LESS_THAN = VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10, 2) +VE_MULTIPLYX2_ADD = VE(b"VE_MULTIPLYX2_ADD" , None , 11, 3) +VE_MULTIPLY_CLAMP = VE(b"VE_MULTIPLY_CLAMP" , None , 12, 3) +VE_FLT2FIX_DX = VE(b"VE_FLT2FIX_DX" , None , 13, 1) +VE_FLT2FIX_DX_RND = VE(b"VE_FLT2FIX_DX_RND" , None , 14, 1) +VE_PRED_SET_EQ_PUSH = VE(b"VE_PRED_SET_EQ_PUSH" , None , 15, 2) +VE_PRED_SET_GT_PUSH = VE(b"VE_PRED_SET_GT_PUSH" , None , 16, 2) +VE_PRED_SET_GTE_PUSH = VE(b"VE_PRED_SET_GTE_PUSH" , None , 17, 2) +VE_PRED_SET_NEQ_PUSH = VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18, 2) +VE_COND_WRITE_EQ = VE(b"VE_COND_WRITE_EQ" , None , 19, 2) +VE_COND_WRITE_GT = VE(b"VE_COND_WRITE_GT" , None , 20, 2) 
+VE_COND_WRITE_GTE = VE(b"VE_COND_WRITE_GTE" , None , 21, 2) +VE_COND_WRITE_NEQ = VE(b"VE_COND_WRITE_NEQ" , None , 22, 2) +VE_COND_MUX_EQ = VE(b"VE_COND_MUX_EQ" , None , 23, 3) +VE_COND_MUX_GT = VE(b"VE_COND_MUX_GT" , None , 24, 3) +VE_COND_MUX_GTE = VE(b"VE_COND_MUX_GTE" , None , 25, 3) +VE_SET_GREATER_THAN = VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26, 2) +VE_SET_EQUAL = VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27, 2) +VE_SET_NOT_EQUAL = VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28, 2) + +MATH_NO_OP = ME(b"MATH_NO_OP" , b"ME_NOP" , 0, 0) +ME_EXP_BASE2_DX = ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1, 1) +ME_LOG_BASE2_DX = ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2, 1) +ME_EXP_BASEE_FF = ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3, 1) +ME_LIGHT_COEFF_DX = ME(b"ME_LIGHT_COEFF_DX" , None , 4, 3) +ME_POWER_FUNC_FF = ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5, 2) +ME_RECIP_DX = ME(b"ME_RECIP_DX" , b"ME_RCP" , 6, 1) +ME_RECIP_FF = ME(b"ME_RECIP_FF" , None , 7, 1) +ME_RECIP_SQRT_DX = ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8, 1) +ME_RECIP_SQRT_FF = ME(b"ME_RECIP_SQRT_FF" , None , 9, 1) +ME_MULTIPLY = ME(b"ME_MULTIPLY" , b"ME_MUL" , 10, 2) +ME_EXP_BASE2_FULL_DX = ME(b"ME_EXP_BASE2_FULL_DX" , None , 11, 1) +ME_LOG_BASE2_FULL_DX = ME(b"ME_LOG_BASE2_FULL_DX" , None , 12, 1) +ME_POWER_FUNC_FF_CLAMP_B = ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13, 3) +ME_POWER_FUNC_FF_CLAMP_B1 = ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14, 3) +ME_POWER_FUNC_FF_CLAMP_01 = ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15, 2) +ME_SIN = ME(b"ME_SIN" , None , 16, 1) +ME_COS = ME(b"ME_COS" , None , 17, 1) +ME_LOG_BASE2_IEEE = ME(b"ME_LOG_BASE2_IEEE" , None , 18, 1) +ME_RECIP_IEEE = ME(b"ME_RECIP_IEEE" , None , 19, 1) +ME_RECIP_SQRT_IEEE = ME(b"ME_RECIP_SQRT_IEEE" , None , 20, 1) +ME_PRED_SET_EQ = ME(b"ME_PRED_SET_EQ" , None , 21, 1) +ME_PRED_SET_GT = ME(b"ME_PRED_SET_GT" , None , 22, 1) +ME_PRED_SET_GTE = ME(b"ME_PRED_SET_GTE" , None , 23, 1) +ME_PRED_SET_NEQ = ME(b"ME_PRED_SET_NEQ" , None , 24, 1) +ME_PRED_SET_CLR = ME(b"ME_PRED_SET_CLR" , 
None , 25, 0) +ME_PRED_SET_INV = ME(b"ME_PRED_SET_INV" , None , 26, 1) +ME_PRED_SET_POP = ME(b"ME_PRED_SET_POP" , None , 27, 1) +ME_PRED_SET_RESTORE = ME(b"ME_PRED_SET_RESTORE" , None , 28, 1) diff --git a/regs/assembler/vs/parser.py b/regs/assembler/vs/parser.py index a887cbd..79c59ca 100644 --- a/regs/assembler/vs/parser.py +++ b/regs/assembler/vs/parser.py @@ -3,193 +3,106 @@ from dataclasses import dataclass from typing import Union from assembler.parser import BaseParser, ParserError -from assembler.lexer import TT -from assembler.vs.keywords import KW, ME, VE, find_keyword - -""" -temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 -temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 -temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___ -temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000 -temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___ -temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000 -temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000 -temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000 -out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_ -out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1 -""" +from assembler.lexer import TT, Token +from assembler.error import print_error @dataclass -class DestinationOp: - type: KW - offset: int - write_enable: set[int] - opcode: Union[VE, ME] - sat: bool - macro: bool - -@dataclass -class SourceSwizzle: - select: tuple[int, int, int, int] - modifier: tuple[bool, bool, bool, bool] +class Destination: + type_keyword: Token + offset_identifier: Token + write_enable_identifier: Token @dataclass class Source: - type: KW - offset: int - swizzle: SourceSwizzle + absolute: bool + type_keyword: Token + offset_identifier: Token + swizzle_identifier: Token + +@dataclass +class Operation: + destination: Destination + opcode_keyword: Token + opcode_suffix_keyword: Token + sources: list[Source] @dataclass class Instruction: - destination_op: 
DestinationOp - source0: Source - source1: Source - source2: Source - -def identifier_to_number(token): - digits = set(b"0123456789") - - assert token.type is TT.identifier - if not all(d in digits for d in token.lexeme): - raise ParserError("expected number", token) - return int(bytes(token.lexeme), 10) - -def we_ord(c): - if c == ord("w"): - return 3 - else: - return c - ord("x") - -def parse_dest_write_enable(token): - we_chars = set(b"xyzw") - assert token.type is TT.identifier - we = bytes(token.lexeme).lower() - if not all(c in we_chars for c in we): - raise ParserError("expected destination write enable", token) - if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we): - raise ParserError("misleading non-sequential write enable", token) - return set(we_ord(c) for c in we) - -def parse_source_swizzle(token): - select_mapping = { - ord('x'): 0, - ord('y'): 1, - ord('z'): 2, - ord('w'): 3, - ord('0'): 4, - ord('1'): 5, - ord('h'): 6, - ord('_'): 7, - ord('u'): 7, - } - state = 0 - ix = 0 - swizzle_selects = [None] * 4 - swizzle_modifiers = [None] * 4 - lexeme = bytes(token.lexeme).lower() - while state < 4: - if ix >= len(token.lexeme): - raise ParserError("invalid source swizzle", token) - c = lexeme[ix] - if c == ord('-'): - if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None): - raise ParserError("invalid source swizzle modifier", token) - swizzle_modifiers[state] = True - elif c in select_mapping: - if swizzle_selects[state] is not None: - raise ParserError("invalid source swizzle select", token) - swizzle_selects[state] = select_mapping[c] - if swizzle_modifiers[state] is None: - swizzle_modifiers[state] = False - state += 1 - else: - raise ParserError("invalid source swizzle", token) - ix += 1 - if ix != len(lexeme): - raise ParserError("invalid source swizzle", token) - return SourceSwizzle(swizzle_selects, swizzle_modifiers) + operations: list[Operation] class Parser(BaseParser): - def 
destination_type(self): - token = self.consume(TT.keyword, "expected destination type") - destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input} - if token.keyword not in destination_keywords: - raise ParserError("expected destination type", token) - return token.keyword + def destination(self): + type_keyword = self.consume(TT.keyword, "expected destination type keyword") + self.consume(TT.left_square, "expected left square") + offset_identifier = self.consume(TT.identifier, "expected destination offset identifier") + self.consume(TT.right_square, "expected right square") + self.consume(TT.dot, "expected dot") + write_enable_identifier = self.consume(TT.identifier, "expected destination write enable identifier") - def offset(self): - self.consume(TT.left_square, "expected offset") - identifier_token = self.consume(TT.identifier, "expected offset") - value = identifier_to_number(identifier_token) - self.consume(TT.right_square, "expected offset") - return value + return Destination( + type_keyword, + offset_identifier, + write_enable_identifier, + ) - def opcode(self): - token = self.consume(TT.keyword, "expected opcode") - if type(token.keyword) != VE and type(token.keyword) != ME: - raise ParserError("expected opcode", token) - return token.keyword - - def destination_op(self): - destination_type = self.destination_type() - offset_value = self.offset() - self.consume(TT.dot, "expected write enable") - write_enable_token = self.consume(TT.identifier, "expected write enable token") - write_enable = parse_dest_write_enable(write_enable_token) - self.consume(TT.equal, "expected equals") - opcode = self.opcode() - sat = False - if self.match(TT.dot): + def is_absolute(self): + result = self.match(TT.bar) + if result: self.advance() - suffix = self.consume(TT.keyword, "expected saturation suffix") - if suffix.keyword is not KW.saturation: - raise ParserError("expected saturation suffix", token) - sat = True - - macro = False - 
return DestinationOp(type=destination_type, - offset=offset_value, - write_enable=write_enable, - opcode=opcode, - sat=sat, - macro=macro) - - def source_type(self): - token = self.consume(TT.keyword, "expected source type") - source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary} - if token.keyword not in source_keywords: - raise ParserError("expected source type", token) - return token.keyword - - def source_swizzle(self): - token = self.consume(TT.identifier, "expected source swizzle") - return parse_source_swizzle(token) + return result def source(self): - "input[0].-y-_-0-_" - source_type = self.source_type() - offset = self.offset() - self.consume(TT.dot, "expected source swizzle") - source_swizzle = self.source_swizzle() - return Source(source_type, offset, source_swizzle) + absolute = self.is_absolute() + + type_keyword = self.consume(TT.keyword, "expected source type keyword") + self.consume(TT.left_square, "expected left square") + offset_identifier = self.consume(TT.identifier, "expected source offset identifier") + self.consume(TT.right_square, "expected right square") + self.consume(TT.dot, "expected dot") + swizzle_identifier = self.consume(TT.identifier, "expected source swizzle identifier") + + if absolute: + self.consume(TT.bar, "expected vertical bar") + + return Source( + absolute, + type_keyword, + offset_identifier, + swizzle_identifier, + ) + + def operation(self): + destination = self.destination() + + self.consume(TT.equal, "expected equal") + + opcode_keyword = self.consume(TT.keyword, "expected opcode keyword") + opcode_suffix_keyword = None + if self.match(TT.dot): + self.advance() + opcode_suffix_keyword = self.consume(TT.keyword, "expected opcode suffix keyword") + + sources = [] + while not (self.match(TT.comma) or self.match(TT.semicolon)): + sources.append(self.source()) + + return Operation( + destination, + opcode_keyword, + opcode_suffix_keyword, + sources, + ) def instruction(self): - first_token = self.peek() 
- destination_op = self.destination_op() - source0 = self.source() - if self.match(TT.semicolon) or self.match(TT.eof): - source1 = None - else: - source1 = self.source() - if self.match(TT.semicolon) or self.match(TT.eof): - source2 = None - else: - source2 = self.source() - last_token = self.peek(-1) + operations = [] + while not self.match(TT.semicolon): + operations.append(self.operation()) + if not self.match(TT.semicolon): + self.consume(TT.comma, "expected comma") + self.consume(TT.semicolon, "expected semicolon") - return ( - Instruction(destination_op, source0, source1, source2), - (first_token.start_ix, last_token.start_ix + len(last_token.lexeme)) + return Instruction( + operations, ) def instructions(self): @@ -198,9 +111,18 @@ class Parser(BaseParser): if __name__ == "__main__": from assembler.lexer import Lexer - buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_" - lexer = Lexer(buf, find_keyword) + from assembler.vs.keywords import find_keyword + buf = b""" + out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ , +atemp[0].xz = ME_SIN input[0].-y-_-0-_ ; +""" + lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False) tokens = list(lexer.lex_tokens()) parser = Parser(tokens) from pprint import pprint - pprint(parser.instruction()) + try: + pprint(parser.instruction()) + except ParserError as e: + print_error(None, buf, e) + raise + print(parser.peek()) diff --git a/regs/assembler/vs/validator.py b/regs/assembler/vs/validator.py index 5b6e118..948d60d 100644 --- a/regs/assembler/vs/validator.py +++ b/regs/assembler/vs/validator.py @@ -1,25 +1,274 @@ -from assembler.vs.keywords import KW, ME, VE, macro_vector_operations +from dataclasses import dataclass +from enum import IntEnum +from itertools import pairwise +from typing import Union -class ValidatorError(Exception): - pass +from assembler.lexer import Token +from assembler.validator import ValidatorError +from assembler.vs.keywords import KW +from assembler.vs 
import opcodes +from collections import OrderedDict + + +class SourceType(IntEnum): + temporary = 0 + input = 1 + constant = 2 + alt_temporary = 3 + +class SwizzleSelect(IntEnum): + x = 0 + y = 1 + z = 2 + w = 3 + zero = 4 + one = 5 + half = 6 + unused = 7 + +@dataclass +class Source: + type: SourceType + absolute: bool + offset: int + swizzle_selects: tuple[SwizzleSelect, SwizzleSelect, SwizzleSelect, SwizzleSelect] + modifiers: tuple[bool, bool, bool, bool] + +class DestinationType(IntEnum): + temporary = 0 + a0 = 1 + out = 2 + out_repl_x = 3 + alt_temporary = 4 + input = 5 + +@dataclass +class Destination: + type: DestinationType + offset: int + write_enable: tuple[bool, bool, bool, bool] + +@dataclass +class OpcodeDestination: + macro_inst: bool + reg_type: int + +@dataclass +class Instruction: + destination: Destination + saturation: bool + sources: list[Source] + opcode: Union[opcodes.VE, opcodes.ME, opcodes.MVE] + +def validate_opcode(opcode_keyword: Token): + if type(opcode_keyword.keyword) is opcodes.ME: + return opcode_keyword.keyword + elif type(opcode_keyword.keyword) is opcodes.VE: + return opcode_keyword.keyword + else: + raise ValidatorError("invalid opcode keyword", opcode_keyword) + +def validate_identifier_number(token): + try: + return int(token.lexeme, 10) + except ValueError: + raise ValidatorError("invalid number", token) + +def validate_destination_write_enable(write_enable_identifier): + we_chars_s = b"xyzw" + we_chars = {c: i for i, c in enumerate(we_chars_s)} + we = bytes(write_enable_identifier.lexeme).lower() + + if not all(c in we_chars for c in we): + raise ValidatorError("invalid character in destination write enable", write_enable_identifier) + if not all(we_chars[a] < we_chars[b] for a, b in pairwise(we)) or len(set(we)) != len(we): + raise ValidatorError("misleading non-sequential destination write enable", write_enable_identifier) + + we = set(we) + return tuple(c in we for c in we_chars_s) + +def validate_destination(destination):
 destination_type_keywords = OrderedDict([ + (KW.temporary , (DestinationType.temporary , 128)), # 32 + (KW.a0 , (DestinationType.a0 , 1 )), # ?? + (KW.out , (DestinationType.out , 128)), # ? + (KW.out_repl_x , (DestinationType.out_repl_x , 128)), # ? + (KW.alt_temporary , (DestinationType.alt_temporary, 20 )), # 20 + (KW.input , (DestinationType.input , 128)), # 32 + ]) + if destination.type_keyword.keyword not in destination_type_keywords: + raise ValidatorError("invalid destination type keyword", destination.type_keyword) + type, max_offset = destination_type_keywords[destination.type_keyword.keyword] + offset = validate_identifier_number(destination.offset_identifier) + if offset >= max_offset: + raise ValidatorError("invalid offset value", destination.offset_identifier) + + write_enable = validate_destination_write_enable(destination.write_enable_identifier) + + return Destination( + type, + offset, + write_enable + ) + +def parse_swizzle_lexeme(token): + swizzle_select_characters = OrderedDict([ + (ord(b"x"), SwizzleSelect.x), + (ord(b"y"), SwizzleSelect.y), + (ord(b"z"), SwizzleSelect.z), + (ord(b"w"), SwizzleSelect.w), + (ord(b"0"), SwizzleSelect.zero), + (ord(b"1"), SwizzleSelect.one), + (ord(b"h"), SwizzleSelect.half), + (ord(b"_"), SwizzleSelect.unused), + ]) + lexeme = bytes(token.lexeme).lower() + swizzles = [] + modifier = False + for c in lexeme: + if c == ord(b"-"): + modifier = True + else: + if c not in swizzle_select_characters: + raise ValueError(c) + swizzles.append((swizzle_select_characters[c], modifier)) + modifier = False + return tuple(zip(*swizzles)) + +def validate_source(source): + source_type_keywords = OrderedDict([ + (KW.temporary , (SourceType.temporary , 128)), # 32 + (KW.input , (SourceType.input , 128)), # 32 + (KW.constant , (SourceType.constant , 256)), # 256 + (KW.alt_temporary , (SourceType.alt_temporary, 20)), # 20 + ]) + if source.type_keyword.keyword not in source_type_keywords: + raise ValidatorError("invalid source type 
keyword", source.type_keyword) + + type, max_offset = source_type_keywords[source.type_keyword.keyword] + absolute = source.absolute + offset = validate_identifier_number(source.offset_identifier) + if offset >= max_offset: + raise ValidatorError("invalid offset value", source.offset_identifier) + try: + swizzle_selects, modifiers = parse_swizzle_lexeme(source.swizzle_identifier) + except ValueError: + raise ValidatorError("invalid source swizzle", source.swizzle_identifier) + + return Source( + type, + absolute, + offset, + swizzle_selects, + modifiers + ) + +def addresses_by_type(sources, source_type): + return set(int(source.offset) + for source in sources + if source.type == source_type) + +def source_ix_with_type_reversed(sources, source_type): + for i, source in reversed(list(enumerate(sources))): + if source.type == source_type: + return i + assert False, (sources, source_type) + +def validate_source_address_counts(sources_ast, sources, opcode): + temporary_address_count = len(addresses_by_type(sources, SourceType.temporary)) + assert temporary_address_count >= 0 and temporary_address_count <= 3 + assert type(opcode) in {opcodes.VE, opcodes.ME} + if temporary_address_count == 3: + if opcode == opcodes.VE_MULTIPLY_ADD: + opcode = opcodes.MACRO_OP_2CLK_MADD + elif opcode == opcodes.VE_MULTIPLYX2_ADD: + opcode = opcodes.MACRO_OP_2CLK_M2X_ADD + else: + raise ValidatorError("too many temporary addresses in non-macro operation(s)", + sources_ast[-1].offset_identifier) + + constant_count = len(addresses_by_type(sources, SourceType.constant)) + if constant_count > 1: + source_ix = source_ix_with_type_reversed(sources, SourceType.constant) + raise ValidatorError(f"too many constant addresses in operation(s); expected 1, have {constant_count}", + sources_ast[source_ix].offset_identifier) + + input_count = len(addresses_by_type(sources, SourceType.input)) + if input_count > 1: + source_ix = source_ix_with_type_reversed(sources, SourceType.input) + raise 
ValidatorError(f"too many input addresses in operation(s); expected 1, have {input_count}", + sources_ast[source_ix].offset_identifier) + + alt_temporary_count = len(addresses_by_type(sources, SourceType.alt_temporary)) + if alt_temporary_count > 1: + source_ix = source_ix_with_type_reversed(sources, SourceType.alt_temporary) + raise ValidatorError(f"too many alt temporary addresses in operation(s); expected 1, have {alt_temporary_count}", + sources_ast[source_ix].offset_identifier) + + return opcode + +def validate_instruction_inner(operation, opcode): + destination = validate_destination(operation.destination) + saturation = False + if operation.opcode_suffix_keyword is not None: + if operation.opcode_suffix_keyword.keyword is not KW.saturation: + raise ValidatorError("invalid opcode saturation suffix", operation.opcode_suffix_keyword) + saturation = True + if len(operation.sources) > 3: + raise ValidatorError("too many sources in operation", operation.sources[0].type_keyword) + if len(operation.sources) != opcode.operand_count: + raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", operation.sources[0].type_keyword) + + sources = [] + for source in operation.sources: + sources.append(validate_source(source)) + opcode = validate_source_address_counts(operation.sources, sources, opcode) + + return Instruction( + destination, + saturation, + sources, + opcode + ) def validate_instruction(ins): - temp_addresses = len(set( - source.offset - for source in [ins.source0, ins.source1, ins.source2] - if (source is not None and source.type == KW.temporary) - )) - if temp_addresses > 2: - if type(ins.destination_op.opcode) is not VE: - raise ValidatorError("too many addresses for non-VE instruction", ins) - if ins.destination_op.opcode.name not in {b"VE_MULTIPLYX2_ADD", b"VE_MULTIPLY_ADD"}: - raise ValidatorError("too many addresses for VE non-multiply-add instruction", ins) - assert ins.destination_op.macro == False, ins - 
ins.destination_op.macro = True - if ins.destination_op.opcode.name == b"VE_MULTIPLY_ADD": - ins.destination_op.opcode = macro_vector_operations[0] - elif ins.destination_op.opcode.name == b"VE_MULTIPLYX2_ADD": - ins.destination_op.opcode = macro_vector_operations[1] - else: - assert False - return ins + if len(ins.operations) > 2: + raise ValidatorError("too many operations in instruction", ins.operations[0].destination.type_keyword) + + opcodes = [validate_opcode(operation.opcode_keyword) for operation in ins.operations] + opcode_types = set(type(opcode) for opcode in opcodes) + if len(opcode_types) != len(opcodes): + opcode_type, = opcode_types + raise ValidatorError(f"invalid dual math operation: too many opcodes of type {opcode_type}", ins.operations[0].opcode_keyword) + + if len(opcodes) == 2: + assert False, "not implemented" + #return validate_dual_math_instruction(ins, opcodes) + else: + assert len(opcodes) == 1 + return validate_instruction_inner(ins.operations[0], opcodes[0]) + +if __name__ == "__main__": + from assembler.lexer import Lexer + from assembler.parser import ParserError + from assembler.vs.parser import Parser + from assembler.vs.keywords import find_keyword + from assembler.error import print_error + buf = b""" + out[0].xz = VE_MAD.SAT |temp[1].-y-_0-_| const[2].x_0_ const[2].x_0_ ; +""" +#atemp[0].xz = ME_SIN input[0].-y-_-0-_ ; + + lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False) + tokens = list(lexer.lex_tokens()) + parser = Parser(tokens) + from pprint import pprint + try: + ins = parser.instruction() + pprint(validate_instruction(ins)) + except ValidatorError as e: + print_error(None, buf, e) + raise + except ParserError as e: + print_error(None, buf, e) + raise