assembler/vs: rewrite parser and validator

Author: Zack Buhman
Date:   2025-10-23 19:51:19 -05:00
parent 50244c7c95
commit 8594bc4a38
10 changed files with 581 additions and 378 deletions

View File

@@ -1,9 +1,11 @@
 import sys
 from assembler.lexer import Lexer, LexerError
-from assembler.fs.parser import Parser, ParserError
+from assembler.parser import ParserError
+from assembler.validator import ValidatorError
+from assembler.fs.parser import Parser
 from assembler.fs.keywords import find_keyword
-from assembler.fs.validator import validate_instruction, ValidatorError
+from assembler.fs.validator import validate_instruction
 from assembler.fs.emitter import emit_instruction
 from assembler.error import print_error

View File

@@ -124,7 +124,7 @@ class Parser(BaseParser):
         swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier")
         if abs:
-            self.consume(TT.bar, "expected bar")
+            self.consume(TT.bar, "expected vertical bar")
         mod_table = {
             # (neg, abs)

View File

@@ -6,9 +6,7 @@ from collections import OrderedDict
 from assembler.fs.parser import Mod
 from assembler.fs.keywords import _keyword_to_string, KW
 from assembler.error import print_error
+from assembler.validator import ValidatorError
 
-class ValidatorError(Exception):
-    pass
-
 class SrcAddrType(Enum):
     temp = auto()

View File

@@ -0,0 +1,2 @@
class ValidatorError(Exception):
    pass

View File

@@ -24,9 +24,9 @@ def frontend_inner(buf):
     lexer = Lexer(buf, find_keyword)
     tokens = list(lexer.lex_tokens())
     parser = Parser(tokens)
-    for ins, start_end in parser.instructions():
+    for ins in parser.instructions():
         ins = validate_instruction(ins)
-        yield list(emit_instruction(ins)), start_end
+        yield list(emit_instruction(ins))
 
 def frontend(filename, buf):
     try:
@@ -43,10 +43,5 @@ if __name__ == "__main__":
     #output_filename = sys.argv[2]
     with open(input_filename, 'rb') as f:
         buf = f.read()
-    output = list(frontend(input_filename, buf))
-    for cw, (start_ix, end_ix) in output:
-        if True:
-            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},")
-        else:
-            source = buf[start_ix:end_ix]
-            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}")
+    for cw in frontend(input_filename, buf):
+        print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},")

View File

@@ -1,92 +1,59 @@
-from assembler.vs.keywords import ME, VE, MVE, KW
-from assembler.vs.parser import Instruction, DestinationOp, Source
+from typing import Union
+
+from assembler.vs.opcodes import ME, VE, MVE
+from assembler.vs.validator import Destination, Source, Instruction
 import pvs_dst
 import pvs_src
-import pvs_dst_bits
-import pvs_src_bits
 
-def we_x(s):
-    return int(0 in s)
-
-def we_y(s):
-    return int(1 in s)
-
-def we_z(s):
-    return int(2 in s)
-
-def we_w(s):
-    return int(3 in s)
-
-def dst_reg_type(kw):
-    if kw == KW.temporary:
-        return pvs_dst_bits.PVS_DST_REG_gen["TEMPORARY"]
-    elif kw == KW.a0:
-        return pvs_dst_bits.PVS_DST_REG_gen["A0"]
-    elif kw == KW.out:
-        return pvs_dst_bits.PVS_DST_REG_gen["OUT"]
-    elif kw == KW.out_repl_x:
-        return pvs_dst_bits.PVS_DST_REG_gen["OUT_REPL_X"]
-    elif kw == KW.alt_temporary:
-        return pvs_dst_bits.PVS_DST_REG_gen["ALT_TEMPORARY"]
-    elif kw == KW.input:
-        return pvs_dst_bits.PVS_DST_REG_gen["INPUT"]
-    else:
-        assert not "Invalid PVS_DST_REG", kw
-
-def emit_destination_op(dst_op: DestinationOp):
-    assert type(dst_op.opcode) in {ME, VE, MVE}
-    math_inst = int(type(dst_op.opcode) is ME)
-    if dst_op.macro:
-        assert dst_op.opcode.value in {0, 1}
-    ve_sat = int((not math_inst) and dst_op.sat)
-    me_sat = int(math_inst and dst_op.sat)
+def emit_destination_opcode(destination: Destination,
+                            opcode: Union[ME, VE, MVE],
+                            saturation: bool):
+    assert type(opcode) in {ME, VE, MVE}
+    math_inst = int(type(opcode) is ME)
+    macro_inst = int(type(opcode) is MVE)
+    ve_sat = int((not math_inst) and saturation)
+    me_sat = int(math_inst and saturation)
     value = (
-        pvs_dst.OPCODE_gen(dst_op.opcode.value)
+        pvs_dst.OPCODE_gen(opcode.value)
         | pvs_dst.MATH_INST_gen(math_inst)
-        | pvs_dst.MACRO_INST_gen(int(dst_op.macro))
-        | pvs_dst.REG_TYPE_gen(dst_reg_type(dst_op.type))
-        | pvs_dst.OFFSET_gen(dst_op.offset)
-        | pvs_dst.WE_X_gen(we_x(dst_op.write_enable))
-        | pvs_dst.WE_Y_gen(we_y(dst_op.write_enable))
-        | pvs_dst.WE_Z_gen(we_z(dst_op.write_enable))
-        | pvs_dst.WE_W_gen(we_w(dst_op.write_enable))
+        | pvs_dst.MACRO_INST_gen(int(macro_inst))
+        | pvs_dst.REG_TYPE_gen(destination.type.value)
+        | pvs_dst.OFFSET_gen(destination.offset)
+        | pvs_dst.WE_X_gen(int(destination.write_enable[0]))
+        | pvs_dst.WE_Y_gen(int(destination.write_enable[1]))
+        | pvs_dst.WE_Z_gen(int(destination.write_enable[2]))
+        | pvs_dst.WE_W_gen(int(destination.write_enable[3]))
         | pvs_dst.VE_SAT_gen(ve_sat)
         | pvs_dst.ME_SAT_gen(me_sat)
     )
     yield value
 
-def src_reg_type(kw):
-    if kw == KW.temporary:
-        return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_TEMPORARY"]
-    elif kw == KW.input:
-        return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_INPUT"]
-    elif kw == KW.constant:
-        return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_CONSTANT"]
-    elif kw == KW.alt_temporary:
-        return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_ALT_TEMPORARY"]
-    else:
-        assert not "Invalid PVS_SRC_REG", kw
-
 def emit_source(src: Source, prev: Source):
     if src is not None:
+        assert src.offset <= 255
         value = (
-            pvs_src.REG_TYPE_gen(src_reg_type(src.type))
+            pvs_src.REG_TYPE_gen(src.type.value)
             | pvs_src.OFFSET_gen(src.offset)
-            | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0])
-            | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1])
-            | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2])
-            | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3])
-            | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0]))
-            | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1]))
-            | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2]))
-            | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3]))
+            | pvs_src.ABS_XYZW_gen(int(src.absolute))
+            | pvs_src.SWIZZLE_X_gen(src.swizzle_selects[0].value)
+            | pvs_src.SWIZZLE_Y_gen(src.swizzle_selects[1].value)
+            | pvs_src.SWIZZLE_Z_gen(src.swizzle_selects[2].value)
+            | pvs_src.SWIZZLE_W_gen(src.swizzle_selects[3].value)
+            | pvs_src.MODIFIER_X_gen(int(src.modifiers[0]))
+            | pvs_src.MODIFIER_Y_gen(int(src.modifiers[1]))
+            | pvs_src.MODIFIER_Z_gen(int(src.modifiers[2]))
+            | pvs_src.MODIFIER_W_gen(int(src.modifiers[3]))
         )
     else:
         assert prev is not None
+        assert prev.offset <= 255
        value = (
-            pvs_src.REG_TYPE_gen(src_reg_type(prev.type))
+            pvs_src.REG_TYPE_gen(prev.type.value)
             | pvs_src.OFFSET_gen(prev.offset)
+            | pvs_src.ABS_XYZW_gen(0)
             | pvs_src.SWIZZLE_X_gen(7)
             | pvs_src.SWIZZLE_Y_gen(7)
             | pvs_src.SWIZZLE_Z_gen(7)
@@ -99,22 +66,32 @@ def emit_source(src: Source, prev: Source):
     yield value
 
 def prev_source(ins, ix):
+    assert ins.sources[0] is not None
     if ix == 0:
-        assert ins.source0 is not None
-        return ins.source0
+        return ins.sources[0]
     elif ix == 1:
-        return ins.source0
+        return ins.sources[0]
     elif ix == 2:
-        if ins.source1 is not None:
-            return ins.source1
+        if ins.sources[1] is not None:
+            return ins.sources[1]
         else:
-            return ins.source0
+            return ins.sources[0]
     else:
         assert False, ix
 
 def emit_instruction(ins: Instruction):
-    yield from emit_destination_op(ins.destination_op)
-    yield from emit_source(ins.source0, prev_source(ins, 0))
-    yield from emit_source(ins.source1, prev_source(ins, 1))
-    yield from emit_source(ins.source2, prev_source(ins, 2))
+    yield from emit_destination_opcode(ins.destination,
+                                       ins.opcode,
+                                       ins.saturation)
+    if len(ins.sources) >= 1:
+        yield from emit_source(ins.sources[0], prev_source(ins, 0))
+        source1 = ins.sources[1] if len(ins.sources) >= 2 else None
+        source2 = ins.sources[2] if len(ins.sources) >= 3 else None
+        yield from emit_source(source1, prev_source(ins, 1))
+        yield from emit_source(source2, prev_source(ins, 2))
+    else:
+        yield 0
+        yield 0
+        yield 0

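For orientation (not part of the commit): a minimal end-to-end sketch of how the rewritten emitter is meant to be driven. The module paths for the vs parser and validator are taken from the imports above; assembler.vs.emitter is an assumed location for this file (the page does not name it), and the shader line is illustrative only. The lexer flags mirror the __main__ examples shown further down.

# sketch only -- assembler.vs.emitter is an assumed module path; shader text is made up
from assembler.lexer import Lexer
from assembler.vs.keywords import find_keyword
from assembler.vs.parser import Parser
from assembler.vs.validator import validate_instruction
from assembler.vs.emitter import emit_instruction

buf = b"out[0].xyzw = VE_ADD input[0].xyzw const[0].xyzw ;"
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
ins = validate_instruction(Parser(list(lexer.lex_tokens())).instruction())
words = list(emit_instruction(ins))
assert len(words) == 4   # one destination/opcode word followed by three source words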
View File

@@ -2,93 +2,68 @@ from dataclasses import dataclass
 from typing import Optional
 from enum import Enum, auto
+from assembler.vs import opcodes
 
-@dataclass
-class MVE:
-    name: str
-    synonym: Optional[str]
-    value: int
-
-@dataclass
-class VE:
-    name: str
-    synonym: Optional[str]
-    value: int
-
-@dataclass
-class ME:
-    name: str
-    synonym: Optional[str]
-    value: int
-
-macro_vector_operations = [
-    MVE(b"MACRO_OP_2CLK_MADD" , None , 0),
-    MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1),
-]
-
-vector_operations = [
-    # name synonym value
-    VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0),
-    VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1),
-    VE(b"VE_MULTIPLY" , b"VE_MUL" , 2),
-    VE(b"VE_ADD" , None , 3),
-    VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4),
-    VE(b"VE_DISTANCE_VECTOR" , None , 5),
-    VE(b"VE_FRACTION" , b"VE_FRC" , 6),
-    VE(b"VE_MAXIMUM" , b"VE_MAX" , 7),
-    VE(b"VE_MINIMUM" , b"VE_MIN" , 8),
-    VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9),
-    VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10),
-    VE(b"VE_MULTIPLYX2_ADD" , None , 11),
-    VE(b"VE_MULTIPLY_CLAMP" , None , 12),
-    VE(b"VE_FLT2FIX_DX" , None , 13),
-    VE(b"VE_FLT2FIX_DX_RND" , None , 14),
-    VE(b"VE_PRED_SET_EQ_PUSH" , None , 15),
-    VE(b"VE_PRED_SET_GT_PUSH" , None , 16),
-    VE(b"VE_PRED_SET_GTE_PUSH" , None , 17),
-    VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18),
-    VE(b"VE_COND_WRITE_EQ" , None , 19),
-    VE(b"VE_COND_WRITE_GT" , None , 20),
-    VE(b"VE_COND_WRITE_GTE" , None , 21),
-    VE(b"VE_COND_WRITE_NEQ" , None , 22),
-    VE(b"VE_COND_MUX_EQ" , None , 23),
-    VE(b"VE_COND_MUX_GT" , None , 24),
-    VE(b"VE_COND_MUX_GTE" , None , 25),
-    VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26),
-    VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27),
-    VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28),
-]
-
-math_operations = [
-    # name synonym value
-    ME(b"MATH_NO_OP" , b"ME_NOP" , 0),
-    ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1),
-    ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2),
-    ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3),
-    ME(b"ME_LIGHT_COEFF_DX" , None , 4),
-    ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5),
-    ME(b"ME_RECIP_DX" , b"ME_RCP" , 6),
-    ME(b"ME_RECIP_FF" , None , 7),
-    ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8),
-    ME(b"ME_RECIP_SQRT_FF" , None , 9),
-    ME(b"ME_MULTIPLY" , b"ME_MUL" , 10),
-    ME(b"ME_EXP_BASE2_FULL_DX" , None , 11),
-    ME(b"ME_LOG_BASE2_FULL_DX" , None , 12),
-    ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13),
-    ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14),
-    ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15),
-    ME(b"ME_SIN" , None , 16),
-    ME(b"ME_COS" , None , 17),
-    ME(b"ME_LOG_BASE2_IEEE" , None , 18),
-    ME(b"ME_RECIP_IEEE" , None , 19),
-    ME(b"ME_RECIP_SQRT_IEEE" , None , 20),
-    ME(b"ME_PRED_SET_EQ" , None , 21),
-    ME(b"ME_PRED_SET_GT" , None , 22),
-    ME(b"ME_PRED_SET_GTE" , None , 23),
-    ME(b"ME_PRED_SET_NEQ" , None , 24),
-    ME(b"ME_PRED_SET_CLR" , None , 25),
-    ME(b"ME_PRED_SET_INV" , None , 26),
-    ME(b"ME_PRED_SET_POP" , None , 27),
-    ME(b"ME_PRED_SET_RESTORE" , None , 28),
+operations = [
+    opcodes.VECTOR_NO_OP,
+    opcodes.VE_DOT_PRODUCT,
+    opcodes.VE_MULTIPLY,
+    opcodes.VE_ADD,
+    opcodes.VE_MULTIPLY_ADD,
+    opcodes.VE_DISTANCE_VECTOR,
+    opcodes.VE_FRACTION,
+    opcodes.VE_MAXIMUM,
+    opcodes.VE_MINIMUM,
+    opcodes.VE_SET_GREATER_THAN_EQUAL,
+    opcodes.VE_SET_LESS_THAN,
+    opcodes.VE_MULTIPLYX2_ADD,
+    opcodes.VE_MULTIPLY_CLAMP,
+    opcodes.VE_FLT2FIX_DX,
+    opcodes.VE_FLT2FIX_DX_RND,
+    opcodes.VE_PRED_SET_EQ_PUSH,
+    opcodes.VE_PRED_SET_GT_PUSH,
+    opcodes.VE_PRED_SET_GTE_PUSH,
+    opcodes.VE_PRED_SET_NEQ_PUSH,
+    opcodes.VE_COND_WRITE_EQ,
+    opcodes.VE_COND_WRITE_GT,
+    opcodes.VE_COND_WRITE_GTE,
+    opcodes.VE_COND_WRITE_NEQ,
+    opcodes.VE_COND_MUX_EQ,
+    opcodes.VE_COND_MUX_GT,
+    opcodes.VE_COND_MUX_GTE,
+    opcodes.VE_SET_GREATER_THAN,
+    opcodes.VE_SET_EQUAL,
+    opcodes.VE_SET_NOT_EQUAL,
+
+    opcodes.MATH_NO_OP,
+    opcodes.ME_EXP_BASE2_DX,
+    opcodes.ME_LOG_BASE2_DX,
+    opcodes.ME_EXP_BASEE_FF,
+    opcodes.ME_LIGHT_COEFF_DX,
+    opcodes.ME_POWER_FUNC_FF,
+    opcodes.ME_RECIP_DX,
+    opcodes.ME_RECIP_FF,
+    opcodes.ME_RECIP_SQRT_DX,
+    opcodes.ME_RECIP_SQRT_FF,
+    opcodes.ME_MULTIPLY,
+    opcodes.ME_EXP_BASE2_FULL_DX,
+    opcodes.ME_LOG_BASE2_FULL_DX,
+    opcodes.ME_POWER_FUNC_FF_CLAMP_B,
+    opcodes.ME_POWER_FUNC_FF_CLAMP_B1,
+    opcodes.ME_POWER_FUNC_FF_CLAMP_01,
+    opcodes.ME_SIN,
+    opcodes.ME_COS,
+    opcodes.ME_LOG_BASE2_IEEE,
+    opcodes.ME_RECIP_IEEE,
+    opcodes.ME_RECIP_SQRT_IEEE,
+    opcodes.ME_PRED_SET_EQ,
+    opcodes.ME_PRED_SET_GT,
+    opcodes.ME_PRED_SET_GTE,
+    opcodes.ME_PRED_SET_NEQ,
+    opcodes.ME_PRED_SET_CLR,
+    opcodes.ME_PRED_SET_INV,
+    opcodes.ME_PRED_SET_POP,
+    opcodes.ME_PRED_SET_RESTORE,
 ]
 
 class KW(Enum):
@@ -120,12 +95,9 @@ keywords = [
 def find_keyword(b: memoryview):
     b = bytes(b)
-    for vector_op in vector_operations:
-        if vector_op.name == b.upper() or (vector_op.synonym is not None and vector_op.synonym == b.upper()):
-            return vector_op
-    for math_op in math_operations:
-        if math_op.name == b.upper() or (math_op.synonym is not None and math_op.synonym == b.upper()):
-            return math_op
+    for op in operations:
+        if op.name == b.upper() or (op.synonym is not None and op.synonym == b.upper()):
+            return op
     for keyword, name, synonym in keywords:
         if name == b.lower() or (synonym is not None and synonym == b.lower()):
             return keyword

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Optional

@dataclass
class MVE:
    name: str
    synonym: Optional[str]
    value: int
    operand_count: int

@dataclass
class VE:
    name: str
    synonym: Optional[str]
    value: int
    operand_count: int

@dataclass
class ME:
    name: str
    synonym: Optional[str]
    value: int
    operand_count: int

MACRO_OP_2CLK_MADD = MVE(b"MACRO_OP_2CLK_MADD" , None , 0, 3)
MACRO_OP_2CLK_M2X_ADD = MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1, 3)
VECTOR_NO_OP = VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0, 0)
VE_DOT_PRODUCT = VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1, 2)
VE_MULTIPLY = VE(b"VE_MULTIPLY" , b"VE_MUL" , 2, 2)
VE_ADD = VE(b"VE_ADD" , None , 3, 2)
VE_MULTIPLY_ADD = VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4, 3)
VE_DISTANCE_VECTOR = VE(b"VE_DISTANCE_VECTOR" , None , 5, 2)
VE_FRACTION = VE(b"VE_FRACTION" , b"VE_FRC" , 6, 1)
VE_MAXIMUM = VE(b"VE_MAXIMUM" , b"VE_MAX" , 7, 2)
VE_MINIMUM = VE(b"VE_MINIMUM" , b"VE_MIN" , 8, 2)
VE_SET_GREATER_THAN_EQUAL = VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9, 2)
VE_SET_LESS_THAN = VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10, 2)
VE_MULTIPLYX2_ADD = VE(b"VE_MULTIPLYX2_ADD" , None , 11, 3)
VE_MULTIPLY_CLAMP = VE(b"VE_MULTIPLY_CLAMP" , None , 12, 3)
VE_FLT2FIX_DX = VE(b"VE_FLT2FIX_DX" , None , 13, 1)
VE_FLT2FIX_DX_RND = VE(b"VE_FLT2FIX_DX_RND" , None , 14, 1)
VE_PRED_SET_EQ_PUSH = VE(b"VE_PRED_SET_EQ_PUSH" , None , 15, 2)
VE_PRED_SET_GT_PUSH = VE(b"VE_PRED_SET_GT_PUSH" , None , 16, 2)
VE_PRED_SET_GTE_PUSH = VE(b"VE_PRED_SET_GTE_PUSH" , None , 17, 2)
VE_PRED_SET_NEQ_PUSH = VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18, 2)
VE_COND_WRITE_EQ = VE(b"VE_COND_WRITE_EQ" , None , 19, 2)
VE_COND_WRITE_GT = VE(b"VE_COND_WRITE_GT" , None , 20, 2)
VE_COND_WRITE_GTE = VE(b"VE_COND_WRITE_GTE" , None , 21, 2)
VE_COND_WRITE_NEQ = VE(b"VE_COND_WRITE_NEQ" , None , 22, 2)
VE_COND_MUX_EQ = VE(b"VE_COND_MUX_EQ" , None , 23, 3)
VE_COND_MUX_GT = VE(b"VE_COND_MUX_GT" , None , 24, 3)
VE_COND_MUX_GTE = VE(b"VE_COND_MUX_GTE" , None , 25, 3)
VE_SET_GREATER_THAN = VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26, 2)
VE_SET_EQUAL = VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27, 2)
VE_SET_NOT_EQUAL = VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28, 2)
MATH_NO_OP = ME(b"MATH_NO_OP" , b"ME_NOP" , 0, 0)
ME_EXP_BASE2_DX = ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1, 1)
ME_LOG_BASE2_DX = ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2, 1)
ME_EXP_BASEE_FF = ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3, 1)
ME_LIGHT_COEFF_DX = ME(b"ME_LIGHT_COEFF_DX" , None , 4, 3)
ME_POWER_FUNC_FF = ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5, 2)
ME_RECIP_DX = ME(b"ME_RECIP_DX" , b"ME_RCP" , 6, 1)
ME_RECIP_FF = ME(b"ME_RECIP_FF" , None , 7, 1)
ME_RECIP_SQRT_DX = ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8, 1)
ME_RECIP_SQRT_FF = ME(b"ME_RECIP_SQRT_FF" , None , 9, 1)
ME_MULTIPLY = ME(b"ME_MULTIPLY" , b"ME_MUL" , 10, 2)
ME_EXP_BASE2_FULL_DX = ME(b"ME_EXP_BASE2_FULL_DX" , None , 11, 1)
ME_LOG_BASE2_FULL_DX = ME(b"ME_LOG_BASE2_FULL_DX" , None , 12, 1)
ME_POWER_FUNC_FF_CLAMP_B = ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13, 3)
ME_POWER_FUNC_FF_CLAMP_B1 = ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14, 3)
ME_POWER_FUNC_FF_CLAMP_01 = ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15, 2)
ME_SIN = ME(b"ME_SIN" , None , 16, 1)
ME_COS = ME(b"ME_COS" , None , 17, 1)
ME_LOG_BASE2_IEEE = ME(b"ME_LOG_BASE2_IEEE" , None , 18, 1)
ME_RECIP_IEEE = ME(b"ME_RECIP_IEEE" , None , 19, 1)
ME_RECIP_SQRT_IEEE = ME(b"ME_RECIP_SQRT_IEEE" , None , 20, 1)
ME_PRED_SET_EQ = ME(b"ME_PRED_SET_EQ" , None , 21, 1)
ME_PRED_SET_GT = ME(b"ME_PRED_SET_GT" , None , 22, 1)
ME_PRED_SET_GTE = ME(b"ME_PRED_SET_GTE" , None , 23, 1)
ME_PRED_SET_NEQ = ME(b"ME_PRED_SET_NEQ" , None , 24, 1)
ME_PRED_SET_CLR = ME(b"ME_PRED_SET_CLR" , None , 25, 0)
ME_PRED_SET_INV = ME(b"ME_PRED_SET_INV" , None , 26, 1)
ME_PRED_SET_POP = ME(b"ME_PRED_SET_POP" , None , 27, 1)
ME_PRED_SET_RESTORE = ME(b"ME_PRED_SET_RESTORE" , None , 28, 1)
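
A small usage sketch (not part of the commit) of how these opcode records are resolved by the rewritten find_keyword in assembler.vs.keywords shown above: matching is case-insensitive against the name or the synonym, and the returned record carries the operand_count that the new validator checks.

# sketch -- relies only on find_keyword and the opcode records introduced in this commit
from assembler.vs import opcodes
from assembler.vs.keywords import find_keyword

assert find_keyword(memoryview(b"ve_mad")) is opcodes.VE_MULTIPLY_ADD        # matched via synonym
assert find_keyword(memoryview(b"ME_RECIP_IEEE")) is opcodes.ME_RECIP_IEEE   # matched via full name
assert opcodes.VE_MULTIPLY_ADD.operand_count == 3                            # consumed by the validator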

View File

@@ -3,193 +3,106 @@ from dataclasses import dataclass
 from typing import Union
 from assembler.parser import BaseParser, ParserError
-from assembler.lexer import TT
-from assembler.vs.keywords import KW, ME, VE, find_keyword
-
-"""
-temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
-temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
-temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
-temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
-temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
-temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
-temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
-temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
-out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
-out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
-"""
+from assembler.lexer import TT, Token
+from assembler.error import print_error
 
 @dataclass
-class DestinationOp:
-    type: KW
-    offset: int
-    write_enable: set[int]
-    opcode: Union[VE, ME]
-    sat: bool
-    macro: bool
-
-@dataclass
-class SourceSwizzle:
-    select: tuple[int, int, int, int]
-    modifier: tuple[bool, bool, bool, bool]
+class Destination:
+    type_keyword: Token
+    offset_identifier: Token
+    write_enable_identifier: Token
 
 @dataclass
 class Source:
-    type: KW
-    offset: int
-    swizzle: SourceSwizzle
+    absolute: bool
+    type_keyword: Token
+    offset_identifier: Token
+    swizzle_identifier: Token
+
+@dataclass
+class Operation:
+    destination: Destination
+    opcode_keyword: Token
+    opcode_suffix_keyword: Token
+    sources: list[Source]
 
 @dataclass
 class Instruction:
-    destination_op: DestinationOp
-    source0: Source
-    source1: Source
-    source2: Source
-
-def identifier_to_number(token):
-    digits = set(b"0123456789")
-    assert token.type is TT.identifier
-    if not all(d in digits for d in token.lexeme):
-        raise ParserError("expected number", token)
-    return int(bytes(token.lexeme), 10)
-
-def we_ord(c):
-    if c == ord("w"):
-        return 3
-    else:
-        return c - ord("x")
-
-def parse_dest_write_enable(token):
-    we_chars = set(b"xyzw")
-    assert token.type is TT.identifier
-    we = bytes(token.lexeme).lower()
-    if not all(c in we_chars for c in we):
-        raise ParserError("expected destination write enable", token)
-    if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
-        raise ParserError("misleading non-sequential write enable", token)
-    return set(we_ord(c) for c in we)
-
-def parse_source_swizzle(token):
-    select_mapping = {
-        ord('x'): 0,
-        ord('y'): 1,
-        ord('z'): 2,
-        ord('w'): 3,
-        ord('0'): 4,
-        ord('1'): 5,
-        ord('h'): 6,
-        ord('_'): 7,
-        ord('u'): 7,
-    }
-    state = 0
-    ix = 0
-    swizzle_selects = [None] * 4
-    swizzle_modifiers = [None] * 4
-    lexeme = bytes(token.lexeme).lower()
-    while state < 4:
-        if ix >= len(token.lexeme):
-            raise ParserError("invalid source swizzle", token)
-        c = lexeme[ix]
-        if c == ord('-'):
-            if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
-                raise ParserError("invalid source swizzle modifier", token)
-            swizzle_modifiers[state] = True
-        elif c in select_mapping:
-            if swizzle_selects[state] is not None:
-                raise ParserError("invalid source swizzle select", token)
-            swizzle_selects[state] = select_mapping[c]
-            if swizzle_modifiers[state] is None:
-                swizzle_modifiers[state] = False
-            state += 1
-        else:
-            raise ParserError("invalid source swizzle", token)
-        ix += 1
-    if ix != len(lexeme):
-        raise ParserError("invalid source swizzle", token)
-    return SourceSwizzle(swizzle_selects, swizzle_modifiers)
+    operations: list[Operation]
 
 class Parser(BaseParser):
-    def destination_type(self):
-        token = self.consume(TT.keyword, "expected destination type")
-        destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
-        if token.keyword not in destination_keywords:
-            raise ParserError("expected destination type", token)
-        return token.keyword
-
-    def offset(self):
-        self.consume(TT.left_square, "expected offset")
-        identifier_token = self.consume(TT.identifier, "expected offset")
-        value = identifier_to_number(identifier_token)
-        self.consume(TT.right_square, "expected offset")
-        return value
-
-    def opcode(self):
-        token = self.consume(TT.keyword, "expected opcode")
-        if type(token.keyword) != VE and type(token.keyword) != ME:
-            raise ParserError("expected opcode", token)
-        return token.keyword
-
-    def destination_op(self):
-        destination_type = self.destination_type()
-        offset_value = self.offset()
-        self.consume(TT.dot, "expected write enable")
-        write_enable_token = self.consume(TT.identifier, "expected write enable token")
-        write_enable = parse_dest_write_enable(write_enable_token)
-        self.consume(TT.equal, "expected equals")
-        opcode = self.opcode()
-        sat = False
-        if self.match(TT.dot):
-            self.advance()
-            suffix = self.consume(TT.keyword, "expected saturation suffix")
-            if suffix.keyword is not KW.saturation:
-                raise ParserError("expected saturation suffix", token)
-            sat = True
-        macro = False
-        return DestinationOp(type=destination_type,
-                             offset=offset_value,
-                             write_enable=write_enable,
-                             opcode=opcode,
-                             sat=sat,
-                             macro=macro)
-
-    def source_type(self):
-        token = self.consume(TT.keyword, "expected source type")
-        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
-        if token.keyword not in source_keywords:
-            raise ParserError("expected source type", token)
-        return token.keyword
-
-    def source_swizzle(self):
-        token = self.consume(TT.identifier, "expected source swizzle")
-        return parse_source_swizzle(token)
+    def destination(self):
+        type_keyword = self.consume(TT.keyword, "expected destination type keyword")
+        self.consume(TT.left_square, "expected left square")
+        offset_identifier = self.consume(TT.identifier, "expected destination offset identifier")
+        self.consume(TT.right_square, "expected right square")
+        self.consume(TT.dot, "expected dot")
+        write_enable_identifier = self.consume(TT.identifier, "expected destination write enable identifier")
+        return Destination(
+            type_keyword,
+            offset_identifier,
+            write_enable_identifier,
+        )
+
+    def is_absolute(self):
+        result = self.match(TT.bar)
+        if result:
+            self.advance()
+        return result
 
     def source(self):
-        "input[0].-y-_-0-_"
-        source_type = self.source_type()
-        offset = self.offset()
-        self.consume(TT.dot, "expected source swizzle")
-        source_swizzle = self.source_swizzle()
-        return Source(source_type, offset, source_swizzle)
+        absolute = self.is_absolute()
+        type_keyword = self.consume(TT.keyword, "expected source type keyword")
+        self.consume(TT.left_square, "expected left square")
+        offset_identifier = self.consume(TT.identifier, "expected source offset identifier")
+        self.consume(TT.right_square, "expected right square")
+        self.consume(TT.dot, "expected dot")
+        swizzle_identifier = self.consume(TT.identifier, "expected source swizzle identifier")
+        if absolute:
+            self.consume(TT.bar, "expected vertical bar")
+        return Source(
+            absolute,
+            type_keyword,
+            offset_identifier,
+            swizzle_identifier,
+        )
+
+    def operation(self):
+        destination = self.destination()
+        self.consume(TT.equal, "expected equal")
+        opcode_keyword = self.consume(TT.keyword, "expected opcode keyword")
+        opcode_suffix_keyword = None
+        if self.match(TT.dot):
+            self.advance()
+            opcode_suffix_keyword = self.consume(TT.keyword, "expected opcode suffix keyword")
+        sources = []
+        while not (self.match(TT.comma) or self.match(TT.semicolon)):
+            sources.append(self.source())
+        return Operation(
+            destination,
+            opcode_keyword,
+            opcode_suffix_keyword,
+            sources,
+        )
 
     def instruction(self):
-        first_token = self.peek()
-        destination_op = self.destination_op()
-        source0 = self.source()
-        if self.match(TT.semicolon) or self.match(TT.eof):
-            source1 = None
-        else:
-            source1 = self.source()
-        if self.match(TT.semicolon) or self.match(TT.eof):
-            source2 = None
-        else:
-            source2 = self.source()
-        last_token = self.peek(-1)
+        operations = []
+        while not self.match(TT.semicolon):
+            operations.append(self.operation())
+            if not self.match(TT.semicolon):
+                self.consume(TT.comma, "expected comma")
         self.consume(TT.semicolon, "expected semicolon")
-        return (
-            Instruction(destination_op, source0, source1, source2),
-            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
+        return Instruction(
+            operations,
         )
 
     def instructions(self):
@@ -198,9 +111,18 @@ class Parser(BaseParser):
 if __name__ == "__main__":
     from assembler.lexer import Lexer
-    buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
-    lexer = Lexer(buf, find_keyword)
+    from assembler.vs.keywords import find_keyword
+    buf = b"""
+out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ ,
+atemp[0].xz = ME_SIN input[0].-y-_-0-_ ;
+"""
+    lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
     tokens = list(lexer.lex_tokens())
     parser = Parser(tokens)
 
     from pprint import pprint
-    pprint(parser.instruction())
+    try:
+        pprint(parser.instruction())
+    except ParserError as e:
+        print_error(None, buf, e)
+        raise
+    print(parser.peek())

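To make the new grammar concrete, a small parsing sketch (not part of the commit) that follows the __main__ driver above: an instruction is one or two comma-separated operations terminated by a semicolon, and each operation is destination = OPCODE[.SUFFIX] followed by zero or more sources. The shader text here is made up for illustration.

# sketch -- mirrors the __main__ block above; the instruction text is illustrative only
from assembler.lexer import Lexer
from assembler.vs.keywords import find_keyword
from assembler.vs.parser import Parser

buf = b"temp[0].x = VE_MUL temp[0].x_0_ const[1].x_0_ , temp[0].y = ME_COS temp[0].xxxx ;"
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
ins = Parser(list(lexer.lex_tokens())).instruction()
assert len(ins.operations) == 2                                   # one VE and one ME operation
assert ins.operations[1].opcode_keyword.keyword.name == b"ME_COS" # Token.keyword holds the opcode record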
View File

@@ -1,25 +1,274 @@
-from assembler.vs.keywords import KW, ME, VE, macro_vector_operations
+from dataclasses import dataclass
+from enum import IntEnum
+from itertools import pairwise
+from typing import Union
 
-class ValidatorError(Exception):
-    pass
+from assembler.lexer import Token
+from assembler.validator import ValidatorError
+from assembler.vs.keywords import KW
+from assembler.vs import opcodes
+from collections import OrderedDict
 
+class SourceType(IntEnum):
+    temporary = 0
+    input = 1
+    constant = 2
+    alt_temporary = 3
+
+class SwizzleSelect(IntEnum):
+    x = 0
+    y = 1
+    z = 2
+    w = 3
+    zero = 4
+    one = 5
+    half = 6
+    unused = 7
+
+@dataclass
+class Source:
+    type: SourceType
+    absolute: bool
+    offset: int
+    swizzle_selects: tuple[SwizzleSelect, SwizzleSelect, SwizzleSelect, SwizzleSelect]
+    modifiers: tuple[bool, bool, bool, bool]
+
+class DestinationType(IntEnum):
+    temporary = 0
+    a0 = 1
+    out = 2
+    out_repl_x = 3
+    alt_temporary = 4
+    input = 5
+
+@dataclass
+class Destination:
+    type: DestinationType
+    offset: int
+    write_enable: tuple[bool, bool, bool, bool]
+
+@dataclass
+class OpcodeDestination:
+    macro_inst: bool
+    reg_type: int
+
+@dataclass
+class Instruction:
+    destination: Destination
+    saturation: bool
+    sources: list[Source]
+    opcode: Union[opcodes.VE, opcodes.ME, opcodes.MVE]
+
+def validate_opcode(opcode_keyword: Token):
+    if type(opcode_keyword.keyword) is opcodes.ME:
+        return opcode_keyword.keyword
+    elif type(opcode_keyword.keyword) is opcodes.VE:
+        return opcode_keyword.keyword
+    else:
+        raise ValidatorError("invalid opcode keyword", opcode_keyword)
+
+def validate_identifier_number(token):
+    try:
+        return int(bytes(token.lexeme), 10)
+    except ValueError:
+        raise ValidatorError("invalid number", token)
+
+def validate_destination_write_enable(write_enable_identifier):
+    we_chars_s = b"xyzw"
+    we_chars = {c: i for i, c in enumerate(we_chars_s)}
+    we = bytes(write_enable_identifier.lexeme).lower()
+    if not all(c in we_chars for c in we):
+        raise ValidatorError("invalid character in destination write enable", write_enable_identifier)
+    if not all(we_chars[a] < we_chars[b] for a, b in pairwise(we)) or len(set(we)) != len(we):
+        raise ValidatorError("misleading non-sequential destination write enable", write_enable_identifier)
+    we = set(we)
+    return tuple(c in we for c in we_chars_s)
+
+def validate_destination(destination):
+    destination_type_keywords = OrderedDict([
+        (KW.temporary     , (DestinationType.temporary    , 128)), # 32
+        (KW.a0            , (DestinationType.a0           , 1  )), # ??
+        (KW.out           , (DestinationType.out          , 128)), # ?
+        (KW.out_repl_x    , (DestinationType.out_repl_x   , 128)), # ?
+        (KW.alt_temporary , (DestinationType.alt_temporary, 20 )), # 20
+        (KW.input         , (DestinationType.input        , 128)), # 32
+    ])
+    if destination.type_keyword.keyword not in destination_type_keywords:
+        raise ValidatorError("invalid destination type keyword", destination.type_keyword)
+    type, max_offset = destination_type_keywords[destination.type_keyword.keyword]
+    offset = validate_identifier_number(destination.offset_identifier)
+    if offset >= max_offset:
+        raise ValidatorError("invalid offset value", destination.offset_identifier)
+    write_enable = validate_destination_write_enable(destination.write_enable_identifier)
+    return Destination(
+        type,
+        offset,
+        write_enable
+    )
+
+def parse_swizzle_lexeme(token):
+    swizzle_select_characters = OrderedDict([
+        (ord(b"x"), SwizzleSelect.x),
+        (ord(b"y"), SwizzleSelect.y),
+        (ord(b"z"), SwizzleSelect.z),
+        (ord(b"w"), SwizzleSelect.w),
+        (ord(b"0"), SwizzleSelect.zero),
+        (ord(b"1"), SwizzleSelect.one),
+        (ord(b"h"), SwizzleSelect.half),
+        (ord(b"_"), SwizzleSelect.unused),
+    ])
+    lexeme = bytes(token.lexeme).lower()
+    swizzles = []
+    modifier = False
+    for c in lexeme:
+        if c == ord(b"-"):
+            modifier = True
+        else:
+            if c not in swizzle_select_characters:
+                raise ValueError(c)
+            swizzles.append((swizzle_select_characters[c], modifier))
+            modifier = False
+    return tuple(zip(*swizzles))
+
+def validate_source(source):
+    source_type_keywords = OrderedDict([
+        (KW.temporary     , (SourceType.temporary    , 128)), # 32
+        (KW.input         , (SourceType.input        , 128)), # 32
+        (KW.constant      , (SourceType.constant     , 256)), # 256
+        (KW.alt_temporary , (SourceType.alt_temporary, 20 )), # 20
+    ])
+    if source.type_keyword.keyword not in source_type_keywords:
+        raise ValidatorError("invalid source type keyword", source.type_keyword)
+    type, max_offset = source_type_keywords[source.type_keyword.keyword]
+    absolute = source.absolute
+    offset = validate_identifier_number(source.offset_identifier)
+    if offset >= max_offset:
+        raise ValidatorError("invalid offset value", source.offset_identifier)
+    try:
+        swizzle_selects, modifiers = parse_swizzle_lexeme(source.swizzle_identifier)
+    except ValueError:
+        raise ValidatorError("invalid source swizzle", source.swizzle_identifier)
+    return Source(
+        type,
+        absolute,
+        offset,
+        swizzle_selects,
+        modifiers
+    )
+
+def addresses_by_type(sources, source_type):
+    return set(int(source.offset)
+               for source in sources
+               if source.type == source_type)
+
+def source_ix_with_type_reversed(sources, source_type):
+    for i, source in reversed(list(enumerate(sources))):
+        if source.type == source_type:
+            return i
+    assert False, (sources, source_type)
+
+def validate_source_address_counts(sources_ast, sources, opcode):
+    temporary_address_count = len(addresses_by_type(sources, SourceType.temporary))
+    assert temporary_address_count >= 0 and temporary_address_count <= 3
+    assert type(opcode) in {opcodes.VE, opcodes.ME}
+    if temporary_address_count == 3:
+        if opcode == opcodes.VE_MULTIPLY_ADD:
+            opcode = opcodes.MACRO_OP_2CLK_MADD
+        elif opcode == opcodes.VE_MULTIPLYX2_ADD:
+            opcode = opcodes.MACRO_OP_2CLK_M2X_ADD
+        else:
+            raise ValidatorError("too many temporary addresses in non-macro operation(s)",
+                                 sources_ast[-1].offset_identifier)
+
+    constant_count = len(addresses_by_type(sources, SourceType.constant))
+    if constant_count > 1:
+        source_ix = source_ix_with_type_reversed(sources, SourceType.constant)
+        raise ValidatorError(f"too many constant addresses in operation(s); expected 1, have {constant_count}",
+                             sources_ast[source_ix].offset_identifier)
+
+    input_count = len(addresses_by_type(sources, SourceType.input))
+    if input_count > 1:
+        source_ix = source_ix_with_type_reversed(sources, SourceType.input)
+        raise ValidatorError(f"too many input addresses in operation(s); expected 1, have {input_count}",
+                             sources_ast[source_ix].offset_identifier)
+
+    alt_temporary_count = len(addresses_by_type(sources, SourceType.alt_temporary))
+    if alt_temporary_count > 1:
+        source_ix = source_ix_with_type_reversed(sources, SourceType.alt_temporary)
+        raise ValidatorError(f"too many alt temporary addresses in operation(s); expected 1, have {alt_temporary_count}",
+                             sources_ast[source_ix].offset_identifier)
+
+    return opcode
+
+def validate_instruction_inner(operation, opcode):
+    destination = validate_destination(operation.destination)
+
+    saturation = False
+    if operation.opcode_suffix_keyword is not None:
+        if operation.opcode_suffix_keyword.keyword is not KW.saturation:
+            raise ValidatorError("invalid opcode saturation suffix", operation.opcode_suffix_keyword)
+        saturation = True
+
+    if len(operation.sources) > 3:
+        raise ValidatorError("too many sources in operation", operation.sources[0].type_keyword)
+    if len(operation.sources) != opcode.operand_count:
+        raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", operation.sources[0].type_keyword)
+
+    sources = []
+    for source in operation.sources:
+        sources.append(validate_source(source))
+
+    opcode = validate_source_address_counts(operation.sources, sources, opcode)
+
+    return Instruction(
+        destination,
+        saturation,
+        sources,
+        opcode
+    )
+
 def validate_instruction(ins):
-    temp_addresses = len(set(
-        source.offset
-        for source in [ins.source0, ins.source1, ins.source2]
-        if (source is not None and source.type == KW.temporary)
-    ))
-    if temp_addresses > 2:
-        if type(ins.destination_op.opcode) is not VE:
-            raise ValidatorError("too many addresses for non-VE instruction", ins)
-        if ins.destination_op.opcode.name not in {b"VE_MULTIPLYX2_ADD", b"VE_MULTIPLY_ADD"}:
-            raise ValidatorError("too many addresses for VE non-multiply-add instruction", ins)
-        assert ins.destination_op.macro == False, ins
-        ins.destination_op.macro = True
-        if ins.destination_op.opcode.name == b"VE_MULTIPLY_ADD":
-            ins.destination_op.opcode = macro_vector_operations[0]
-        elif ins.destination_op.opcode.name == b"VE_MULTIPLYX2_ADD":
-            ins.destination_op.opcode = macro_vector_operations[1]
-        else:
-            assert False
-    return ins
+    if len(ins.operations) > 2:
+        raise ValidatorError("too many operations in instruction", ins.operations[0].destination.type_keyword)
+
+    opcodes = [validate_opcode(operation.opcode_keyword) for operation in ins.operations]
+    opcode_types = set(type(opcode) for opcode in opcodes)
+    if len(opcode_types) != len(opcodes):
+        opcode_type, = opcode_types
+        raise ValidatorError(f"invalid dual math operation: too many opcodes of type {opcode_type}", ins.operations[0].opcode_keyword)
+
+    if len(opcodes) == 2:
+        assert False, "not implemented"
+        #return validate_dual_math_instruction(ins, opcodes)
+    else:
+        assert len(opcodes) == 1
+        return validate_instruction_inner(ins.operations[0], opcodes[0])
+
+if __name__ == "__main__":
+    from assembler.lexer import Lexer
+    from assembler.parser import ParserError
+    from assembler.vs.parser import Parser
+    from assembler.vs.keywords import find_keyword
+    from assembler.error import print_error
+
+    buf = b"""
+out[0].xz = VE_MAD.SAT |temp[1].-y-_0-_| const[2].x_0_ const[2].x_0_ ;
+"""
+    #atemp[0].xz = ME_SIN input[0].-y-_-0-_ ;
+
+    lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
+    tokens = list(lexer.lex_tokens())
+    parser = Parser(tokens)
+
+    from pprint import pprint
+    try:
+        ins = parser.instruction()
+        pprint(validate_instruction(ins))
+    except ValidatorError as e:
+        print_error(None, buf, e)
+        raise
+    except ParserError as e:
+        print_error(None, buf, e)
+        raise
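
As a worked example of the swizzle decoding above (a sketch, not part of the commit): each swizzle character maps to a SwizzleSelect, and a leading '-' sets the per-component negate modifier. SimpleNamespace stands in for a lexer Token here, since parse_swizzle_lexeme only reads .lexeme.

# sketch -- decodes the swizzle identifier used in the __main__ buffer above
from types import SimpleNamespace
from assembler.vs.validator import parse_swizzle_lexeme, SwizzleSelect

tok = SimpleNamespace(lexeme=b"-y-_0-_")   # hypothetical stand-in for a Token
selects, modifiers = parse_swizzle_lexeme(tok)
assert selects == (SwizzleSelect.y, SwizzleSelect.unused, SwizzleSelect.zero, SwizzleSelect.unused)
assert modifiers == (True, True, False, True)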