assembler/vs: rewrite parser and validator

Zack Buhman 2025-10-23 19:51:19 -05:00
parent 50244c7c95
commit 8594bc4a38
10 changed files with 581 additions and 378 deletions

View File

@@ -1,9 +1,11 @@
import sys
from assembler.lexer import Lexer, LexerError
from assembler.fs.parser import Parser, ParserError
from assembler.parser import ParserError
from assembler.validator import ValidatorError
from assembler.fs.parser import Parser
from assembler.fs.keywords import find_keyword
from assembler.fs.validator import validate_instruction, ValidatorError
from assembler.fs.validator import validate_instruction
from assembler.fs.emitter import emit_instruction
from assembler.error import print_error

View File

@@ -124,7 +124,7 @@ class Parser(BaseParser):
swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier")
if abs:
self.consume(TT.bar, "expected bar")
self.consume(TT.bar, "expected vertical bar")
mod_table = {
# (neg, abs)

View File

@@ -6,9 +6,7 @@ from collections import OrderedDict
from assembler.fs.parser import Mod
from assembler.fs.keywords import _keyword_to_string, KW
from assembler.error import print_error
class ValidatorError(Exception):
pass
from assembler.validator import ValidatorError
class SrcAddrType(Enum):
temp = auto()

View File

@@ -0,0 +1,2 @@
class ValidatorError(Exception):
pass

View File

@@ -24,9 +24,9 @@ def frontend_inner(buf):
lexer = Lexer(buf, find_keyword)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
for ins, start_end in parser.instructions():
for ins in parser.instructions():
ins = validate_instruction(ins)
yield list(emit_instruction(ins)), start_end
yield list(emit_instruction(ins))
def frontend(filename, buf):
try:
@@ -43,10 +43,5 @@ if __name__ == "__main__":
#output_filename = sys.argv[2]
with open(input_filename, 'rb') as f:
buf = f.read()
output = list(frontend(input_filename, buf))
for cw, (start_ix, end_ix) in output:
if True:
for cw in frontend(input_filename, buf):
print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},")
else:
source = buf[start_ix:end_ix]
print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}")

View File

@@ -1,92 +1,59 @@
from assembler.vs.keywords import ME, VE, MVE, KW
from assembler.vs.parser import Instruction, DestinationOp, Source
from typing import Union
from assembler.vs.opcodes import ME, VE, MVE
from assembler.vs.validator import Destination, Source, Instruction
import pvs_dst
import pvs_src
import pvs_dst_bits
import pvs_src_bits
def we_x(s):
return int(0 in s)
def emit_destination_opcode(destination: Destination,
opcode: Union[ME, VE, MVE],
saturation: bool):
assert type(opcode) in {ME, VE, MVE}
math_inst = int(type(opcode) is ME)
macro_inst = int(type(opcode) is MVE)
def we_y(s):
return int(1 in s)
def we_z(s):
return int(2 in s)
def we_w(s):
return int(3 in s)
def dst_reg_type(kw):
if kw == KW.temporary:
return pvs_dst_bits.PVS_DST_REG_gen["TEMPORARY"]
elif kw == KW.a0:
return pvs_dst_bits.PVS_DST_REG_gen["A0"]
elif kw == KW.out:
return pvs_dst_bits.PVS_DST_REG_gen["OUT"]
elif kw == KW.out_repl_x:
return pvs_dst_bits.PVS_DST_REG_gen["OUT_REPL_X"]
elif kw == KW.alt_temporary:
return pvs_dst_bits.PVS_DST_REG_gen["ALT_TEMPORARY"]
elif kw == KW.input:
return pvs_dst_bits.PVS_DST_REG_gen["INPUT"]
else:
assert not "Invalid PVS_DST_REG", kw
def emit_destination_op(dst_op: DestinationOp):
assert type(dst_op.opcode) in {ME, VE, MVE}
math_inst = int(type(dst_op.opcode) is ME)
if dst_op.macro:
assert dst_op.opcode.value in {0, 1}
ve_sat = int((not math_inst) and dst_op.sat)
me_sat = int(math_inst and dst_op.sat)
ve_sat = int((not math_inst) and saturation)
me_sat = int(math_inst and saturation)
value = (
pvs_dst.OPCODE_gen(dst_op.opcode.value)
pvs_dst.OPCODE_gen(opcode.value)
| pvs_dst.MATH_INST_gen(math_inst)
| pvs_dst.MACRO_INST_gen(int(dst_op.macro))
| pvs_dst.REG_TYPE_gen(dst_reg_type(dst_op.type))
| pvs_dst.OFFSET_gen(dst_op.offset)
| pvs_dst.WE_X_gen(we_x(dst_op.write_enable))
| pvs_dst.WE_Y_gen(we_y(dst_op.write_enable))
| pvs_dst.WE_Z_gen(we_z(dst_op.write_enable))
| pvs_dst.WE_W_gen(we_w(dst_op.write_enable))
| pvs_dst.MACRO_INST_gen(int(macro_inst))
| pvs_dst.REG_TYPE_gen(destination.type.value)
| pvs_dst.OFFSET_gen(destination.offset)
| pvs_dst.WE_X_gen(int(destination.write_enable[0]))
| pvs_dst.WE_Y_gen(int(destination.write_enable[1]))
| pvs_dst.WE_Z_gen(int(destination.write_enable[2]))
| pvs_dst.WE_W_gen(int(destination.write_enable[3]))
| pvs_dst.VE_SAT_gen(ve_sat)
| pvs_dst.ME_SAT_gen(me_sat)
)
yield value
def src_reg_type(kw):
if kw == KW.temporary:
return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_TEMPORARY"]
elif kw == KW.input:
return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_INPUT"]
elif kw == KW.constant:
return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_CONSTANT"]
elif kw == KW.alt_temporary:
return pvs_src_bits.PVS_SRC_REG_TYPE_gen["PVS_SRC_REG_ALT_TEMPORARY"]
else:
assert not "Invalid PVS_SRC_REG", kw
def emit_source(src: Source, prev: Source):
if src is not None:
assert src.offset <= 255
value = (
pvs_src.REG_TYPE_gen(src_reg_type(src.type))
pvs_src.REG_TYPE_gen(src.type.value)
| pvs_src.OFFSET_gen(src.offset)
| pvs_src.SWIZZLE_X_gen(src.swizzle.select[0])
| pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1])
| pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2])
| pvs_src.SWIZZLE_W_gen(src.swizzle.select[3])
| pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0]))
| pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1]))
| pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2]))
| pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3]))
| pvs_src.ABS_XYZW_gen(int(src.absolute))
| pvs_src.SWIZZLE_X_gen(src.swizzle_selects[0].value)
| pvs_src.SWIZZLE_Y_gen(src.swizzle_selects[1].value)
| pvs_src.SWIZZLE_Z_gen(src.swizzle_selects[2].value)
| pvs_src.SWIZZLE_W_gen(src.swizzle_selects[3].value)
| pvs_src.MODIFIER_X_gen(int(src.modifiers[0]))
| pvs_src.MODIFIER_Y_gen(int(src.modifiers[1]))
| pvs_src.MODIFIER_Z_gen(int(src.modifiers[2]))
| pvs_src.MODIFIER_W_gen(int(src.modifiers[3]))
)
else:
assert prev is not None
assert prev.offset <= 255
value = (
pvs_src.REG_TYPE_gen(src_reg_type(prev.type))
pvs_src.REG_TYPE_gen(prev.type.value)
| pvs_src.OFFSET_gen(prev.offset)
| pvs_src.ABS_XYZW_gen(0)
| pvs_src.SWIZZLE_X_gen(7)
| pvs_src.SWIZZLE_Y_gen(7)
| pvs_src.SWIZZLE_Z_gen(7)
@@ -99,22 +66,32 @@ def emit_source(src: Source, prev: Source):
yield value
def prev_source(ins, ix):
assert ins.sources[0] is not None
if ix == 0:
assert ins.source0 is not None
return ins.source0
return ins.sources[0]
elif ix == 1:
return ins.source0
return ins.sources[0]
elif ix == 2:
if ins.source1 is not None:
return ins.source1
if len(ins.sources) > 1 and ins.sources[1] is not None:
return ins.sources[1]
else:
return ins.source0
return ins.sources[0]
else:
assert False, ix
def emit_instruction(ins: Instruction):
yield from emit_destination_op(ins.destination_op)
yield from emit_destination_opcode(ins.destination,
ins.opcode,
ins.saturation)
yield from emit_source(ins.source0, prev_source(ins, 0))
yield from emit_source(ins.source1, prev_source(ins, 1))
yield from emit_source(ins.source2, prev_source(ins, 2))
if len(ins.sources) >= 1:
yield from emit_source(ins.sources[0], prev_source(ins, 0))
source1 = ins.sources[1] if len(ins.sources) >= 2 else None
source2 = ins.sources[2] if len(ins.sources) >= 3 else None
yield from emit_source(source1, prev_source(ins, 1))
yield from emit_source(source2, prev_source(ins, 2))
else:
yield 0
yield 0
yield 0
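
A usage sketch, not part of this commit: driving the rewritten emitter end to end. The assembler.vs.emitter and assembler.vs.validator module paths mirror the imports elsewhere in this commit and are assumptions, as is the availability of the generated pvs_dst/pvs_src register modules.

from assembler.lexer import Lexer
from assembler.vs.parser import Parser
from assembler.vs.keywords import find_keyword
from assembler.vs.validator import validate_instruction
from assembler.vs.emitter import emit_instruction

# Sketch: every instruction emits exactly four 32-bit words -- one
# destination/opcode word followed by three source words.  Sources the
# instruction does not use repeat the previous source's register with
# every swizzle lane set to "unused" (select value 7).
buf = b"out[0].xyzw = VE_ADD input[0].xyzw const[0].xyzw ;"
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
ins = validate_instruction(Parser(list(lexer.lex_tokens())).instruction())
words = list(emit_instruction(ins))
assert len(words) == 4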

View File

@@ -2,93 +2,68 @@ from dataclasses import dataclass
from typing import Optional
from enum import Enum, auto
@dataclass
class MVE:
name: str
synonym: Optional[str]
value: int
from assembler.vs import opcodes
@dataclass
class VE:
name: str
synonym: Optional[str]
value: int
operations = [
opcodes.VECTOR_NO_OP,
opcodes.VE_DOT_PRODUCT,
opcodes.VE_MULTIPLY,
opcodes.VE_ADD,
opcodes.VE_MULTIPLY_ADD,
opcodes.VE_DISTANCE_VECTOR,
opcodes.VE_FRACTION,
opcodes.VE_MAXIMUM,
opcodes.VE_MINIMUM,
opcodes.VE_SET_GREATER_THAN_EQUAL,
opcodes.VE_SET_LESS_THAN,
opcodes.VE_MULTIPLYX2_ADD,
opcodes.VE_MULTIPLY_CLAMP,
opcodes.VE_FLT2FIX_DX,
opcodes.VE_FLT2FIX_DX_RND,
opcodes.VE_PRED_SET_EQ_PUSH,
opcodes.VE_PRED_SET_GT_PUSH,
opcodes.VE_PRED_SET_GTE_PUSH,
opcodes.VE_PRED_SET_NEQ_PUSH,
opcodes.VE_COND_WRITE_EQ,
opcodes.VE_COND_WRITE_GT,
opcodes.VE_COND_WRITE_GTE,
opcodes.VE_COND_WRITE_NEQ,
opcodes.VE_COND_MUX_EQ,
opcodes.VE_COND_MUX_GT,
opcodes.VE_COND_MUX_GTE,
opcodes.VE_SET_GREATER_THAN,
opcodes.VE_SET_EQUAL,
opcodes.VE_SET_NOT_EQUAL,
@dataclass
class ME:
name: str
synonym: Optional[str]
value: int
macro_vector_operations = [
MVE(b"MACRO_OP_2CLK_MADD" , None , 0),
MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1),
]
vector_operations = [
# name synonym value
VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0),
VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1),
VE(b"VE_MULTIPLY" , b"VE_MUL" , 2),
VE(b"VE_ADD" , None , 3),
VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4),
VE(b"VE_DISTANCE_VECTOR" , None , 5),
VE(b"VE_FRACTION" , b"VE_FRC" , 6),
VE(b"VE_MAXIMUM" , b"VE_MAX" , 7),
VE(b"VE_MINIMUM" , b"VE_MIN" , 8),
VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9),
VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10),
VE(b"VE_MULTIPLYX2_ADD" , None , 11),
VE(b"VE_MULTIPLY_CLAMP" , None , 12),
VE(b"VE_FLT2FIX_DX" , None , 13),
VE(b"VE_FLT2FIX_DX_RND" , None , 14),
VE(b"VE_PRED_SET_EQ_PUSH" , None , 15),
VE(b"VE_PRED_SET_GT_PUSH" , None , 16),
VE(b"VE_PRED_SET_GTE_PUSH" , None , 17),
VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18),
VE(b"VE_COND_WRITE_EQ" , None , 19),
VE(b"VE_COND_WRITE_GT" , None , 20),
VE(b"VE_COND_WRITE_GTE" , None , 21),
VE(b"VE_COND_WRITE_NEQ" , None , 22),
VE(b"VE_COND_MUX_EQ" , None , 23),
VE(b"VE_COND_MUX_GT" , None , 24),
VE(b"VE_COND_MUX_GTE" , None , 25),
VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26),
VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27),
VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28),
]
math_operations = [
# name synonym value
ME(b"MATH_NO_OP" , b"ME_NOP" , 0),
ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1),
ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2),
ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3),
ME(b"ME_LIGHT_COEFF_DX" , None , 4),
ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5),
ME(b"ME_RECIP_DX" , b"ME_RCP" , 6),
ME(b"ME_RECIP_FF" , None , 7),
ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8),
ME(b"ME_RECIP_SQRT_FF" , None , 9),
ME(b"ME_MULTIPLY" , b"ME_MUL" , 10),
ME(b"ME_EXP_BASE2_FULL_DX" , None , 11),
ME(b"ME_LOG_BASE2_FULL_DX" , None , 12),
ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13),
ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14),
ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15),
ME(b"ME_SIN" , None , 16),
ME(b"ME_COS" , None , 17),
ME(b"ME_LOG_BASE2_IEEE" , None , 18),
ME(b"ME_RECIP_IEEE" , None , 19),
ME(b"ME_RECIP_SQRT_IEEE" , None , 20),
ME(b"ME_PRED_SET_EQ" , None , 21),
ME(b"ME_PRED_SET_GT" , None , 22),
ME(b"ME_PRED_SET_GTE" , None , 23),
ME(b"ME_PRED_SET_NEQ" , None , 24),
ME(b"ME_PRED_SET_CLR" , None , 25),
ME(b"ME_PRED_SET_INV" , None , 26),
ME(b"ME_PRED_SET_POP" , None , 27),
ME(b"ME_PRED_SET_RESTORE" , None , 28),
opcodes.MATH_NO_OP,
opcodes.ME_EXP_BASE2_DX,
opcodes.ME_LOG_BASE2_DX,
opcodes.ME_EXP_BASEE_FF,
opcodes.ME_LIGHT_COEFF_DX,
opcodes.ME_POWER_FUNC_FF,
opcodes.ME_RECIP_DX,
opcodes.ME_RECIP_FF,
opcodes.ME_RECIP_SQRT_DX,
opcodes.ME_RECIP_SQRT_FF,
opcodes.ME_MULTIPLY,
opcodes.ME_EXP_BASE2_FULL_DX,
opcodes.ME_LOG_BASE2_FULL_DX,
opcodes.ME_POWER_FUNC_FF_CLAMP_B,
opcodes.ME_POWER_FUNC_FF_CLAMP_B1,
opcodes.ME_POWER_FUNC_FF_CLAMP_01,
opcodes.ME_SIN,
opcodes.ME_COS,
opcodes.ME_LOG_BASE2_IEEE,
opcodes.ME_RECIP_IEEE,
opcodes.ME_RECIP_SQRT_IEEE,
opcodes.ME_PRED_SET_EQ,
opcodes.ME_PRED_SET_GT,
opcodes.ME_PRED_SET_GTE,
opcodes.ME_PRED_SET_NEQ,
opcodes.ME_PRED_SET_CLR,
opcodes.ME_PRED_SET_INV,
opcodes.ME_PRED_SET_POP,
opcodes.ME_PRED_SET_RESTORE,
]
class KW(Enum):
@@ -120,12 +95,9 @@ keywords = [
def find_keyword(b: memoryview):
b = bytes(b)
for vector_op in vector_operations:
if vector_op.name == b.upper() or (vector_op.synonym is not None and vector_op.synonym == b.upper()):
return vector_op
for math_op in math_operations:
if math_op.name == b.upper() or (math_op.synonym is not None and math_op.synonym == b.upper()):
return math_op
for op in operations:
if op.name == b.upper() or (op.synonym is not None and op.synonym == b.upper()):
return op
for keyword, name, synonym in keywords:
if name == b.lower() or (synonym is not None and synonym == b.lower()):
return keyword
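
A usage sketch, not part of this commit, of the consolidated lookup above (module paths as imported elsewhere in this commit):

from assembler.vs.keywords import find_keyword
from assembler.vs import opcodes

# Sketch: matching is case-insensitive and accepts either the full opcode
# name or its synonym; identifiers that match nothing fall through to the
# plain keyword table and yield None.
op = find_keyword(memoryview(b"ve_mad"))
assert op is opcodes.VE_MULTIPLY_ADD      # matched via the b"VE_MAD" synonym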

View File

@@ -0,0 +1,86 @@
from dataclasses import dataclass
from typing import Optional
@dataclass
class MVE:
name: str
synonym: Optional[str]
value: int
operand_count: int
@dataclass
class VE:
name: str
synonym: Optional[str]
value: int
operand_count: int
@dataclass
class ME:
name: str
synonym: Optional[str]
value: int
operand_count: int
MACRO_OP_2CLK_MADD = MVE(b"MACRO_OP_2CLK_MADD" , None , 0, 3)
MACRO_OP_2CLK_M2X_ADD = MVE(b"MACRO_OP_2CLK_M2X_ADD" , None , 1, 3)
VECTOR_NO_OP = VE(b"VECTOR_NO_OP" , b"VE_NOP" , 0, 0)
VE_DOT_PRODUCT = VE(b"VE_DOT_PRODUCT" , b"VE_DOT" , 1, 2)
VE_MULTIPLY = VE(b"VE_MULTIPLY" , b"VE_MUL" , 2, 2)
VE_ADD = VE(b"VE_ADD" , None , 3, 2)
VE_MULTIPLY_ADD = VE(b"VE_MULTIPLY_ADD" , b"VE_MAD" , 4, 3)
VE_DISTANCE_VECTOR = VE(b"VE_DISTANCE_VECTOR" , None , 5, 2)
VE_FRACTION = VE(b"VE_FRACTION" , b"VE_FRC" , 6, 1)
VE_MAXIMUM = VE(b"VE_MAXIMUM" , b"VE_MAX" , 7, 2)
VE_MINIMUM = VE(b"VE_MINIMUM" , b"VE_MIN" , 8, 2)
VE_SET_GREATER_THAN_EQUAL = VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" , 9, 2)
VE_SET_LESS_THAN = VE(b"VE_SET_LESS_THAN" , b"VE_SLT" , 10, 2)
VE_MULTIPLYX2_ADD = VE(b"VE_MULTIPLYX2_ADD" , None , 11, 3)
VE_MULTIPLY_CLAMP = VE(b"VE_MULTIPLY_CLAMP" , None , 12, 3)
VE_FLT2FIX_DX = VE(b"VE_FLT2FIX_DX" , None , 13, 1)
VE_FLT2FIX_DX_RND = VE(b"VE_FLT2FIX_DX_RND" , None , 14, 1)
VE_PRED_SET_EQ_PUSH = VE(b"VE_PRED_SET_EQ_PUSH" , None , 15, 2)
VE_PRED_SET_GT_PUSH = VE(b"VE_PRED_SET_GT_PUSH" , None , 16, 2)
VE_PRED_SET_GTE_PUSH = VE(b"VE_PRED_SET_GTE_PUSH" , None , 17, 2)
VE_PRED_SET_NEQ_PUSH = VE(b"VE_PRED_SET_NEQ_PUSH" , None , 18, 2)
VE_COND_WRITE_EQ = VE(b"VE_COND_WRITE_EQ" , None , 19, 2)
VE_COND_WRITE_GT = VE(b"VE_COND_WRITE_GT" , None , 20, 2)
VE_COND_WRITE_GTE = VE(b"VE_COND_WRITE_GTE" , None , 21, 2)
VE_COND_WRITE_NEQ = VE(b"VE_COND_WRITE_NEQ" , None , 22, 2)
VE_COND_MUX_EQ = VE(b"VE_COND_MUX_EQ" , None , 23, 3)
VE_COND_MUX_GT = VE(b"VE_COND_MUX_GT" , None , 24, 3)
VE_COND_MUX_GTE = VE(b"VE_COND_MUX_GTE" , None , 25, 3)
VE_SET_GREATER_THAN = VE(b"VE_SET_GREATER_THAN" , b"VE_SGT" , 26, 2)
VE_SET_EQUAL = VE(b"VE_SET_EQUAL" , b"VE_SEQ" , 27, 2)
VE_SET_NOT_EQUAL = VE(b"VE_SET_NOT_EQUAL" , b"VE_SNE" , 28, 2)
MATH_NO_OP = ME(b"MATH_NO_OP" , b"ME_NOP" , 0, 0)
ME_EXP_BASE2_DX = ME(b"ME_EXP_BASE2_DX" , b"ME_EXP" , 1, 1)
ME_LOG_BASE2_DX = ME(b"ME_LOG_BASE2_DX" , b"ME_LOG2", 2, 1)
ME_EXP_BASEE_FF = ME(b"ME_EXP_BASEE_FF" , b"ME_EXPE", 3, 1)
ME_LIGHT_COEFF_DX = ME(b"ME_LIGHT_COEFF_DX" , None , 4, 3)
ME_POWER_FUNC_FF = ME(b"ME_POWER_FUNC_FF" , b"ME_POW" , 5, 2)
ME_RECIP_DX = ME(b"ME_RECIP_DX" , b"ME_RCP" , 6, 1)
ME_RECIP_FF = ME(b"ME_RECIP_FF" , None , 7, 1)
ME_RECIP_SQRT_DX = ME(b"ME_RECIP_SQRT_DX" , b"ME_RSQ" , 8, 1)
ME_RECIP_SQRT_FF = ME(b"ME_RECIP_SQRT_FF" , None , 9, 1)
ME_MULTIPLY = ME(b"ME_MULTIPLY" , b"ME_MUL" , 10, 2)
ME_EXP_BASE2_FULL_DX = ME(b"ME_EXP_BASE2_FULL_DX" , None , 11, 1)
ME_LOG_BASE2_FULL_DX = ME(b"ME_LOG_BASE2_FULL_DX" , None , 12, 1)
ME_POWER_FUNC_FF_CLAMP_B = ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None , 13, 3)
ME_POWER_FUNC_FF_CLAMP_B1 = ME(b"ME_POWER_FUNC_FF_CLAMP_B1" , None , 14, 3)
ME_POWER_FUNC_FF_CLAMP_01 = ME(b"ME_POWER_FUNC_FF_CLAMP_01" , None , 15, 2)
ME_SIN = ME(b"ME_SIN" , None , 16, 1)
ME_COS = ME(b"ME_COS" , None , 17, 1)
ME_LOG_BASE2_IEEE = ME(b"ME_LOG_BASE2_IEEE" , None , 18, 1)
ME_RECIP_IEEE = ME(b"ME_RECIP_IEEE" , None , 19, 1)
ME_RECIP_SQRT_IEEE = ME(b"ME_RECIP_SQRT_IEEE" , None , 20, 1)
ME_PRED_SET_EQ = ME(b"ME_PRED_SET_EQ" , None , 21, 1)
ME_PRED_SET_GT = ME(b"ME_PRED_SET_GT" , None , 22, 1)
ME_PRED_SET_GTE = ME(b"ME_PRED_SET_GTE" , None , 23, 1)
ME_PRED_SET_NEQ = ME(b"ME_PRED_SET_NEQ" , None , 24, 1)
ME_PRED_SET_CLR = ME(b"ME_PRED_SET_CLR" , None , 25, 0)
ME_PRED_SET_INV = ME(b"ME_PRED_SET_INV" , None , 26, 1)
ME_PRED_SET_POP = ME(b"ME_PRED_SET_POP" , None , 27, 1)
ME_PRED_SET_RESTORE = ME(b"ME_PRED_SET_RESTORE" , None , 28, 1)
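
A brief sketch, not part of this commit, of how the new operand_count field is consumed by the validator's source-count check:

from assembler.vs import opcodes

# Sketch: the validator compares len(operation.sources) against
# operand_count, so VE_MAD-style opcodes demand three sources while the
# math-engine transcendentals take one.
assert opcodes.VE_MULTIPLY_ADD.operand_count == 3
assert opcodes.ME_SIN.operand_count == 1
assert opcodes.MACRO_OP_2CLK_MADD.value == 0   # macro twin of VE_MULTIPLY_ADD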

View File

@@ -3,193 +3,106 @@ from dataclasses import dataclass
from typing import Union
from assembler.parser import BaseParser, ParserError
from assembler.lexer import TT
from assembler.vs.keywords import KW, ME, VE, find_keyword
"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
from assembler.lexer import TT, Token
from assembler.error import print_error
@dataclass
class DestinationOp:
type: KW
offset: int
write_enable: set[int]
opcode: Union[VE, ME]
sat: bool
macro: bool
@dataclass
class SourceSwizzle:
select: tuple[int, int, int, int]
modifier: tuple[bool, bool, bool, bool]
class Destination:
type_keyword: Token
offset_identifier: Token
write_enable_identifier: Token
@dataclass
class Source:
type: KW
offset: int
swizzle: SourceSwizzle
absolute: bool
type_keyword: Token
offset_identifier: Token
swizzle_identifier: Token
@dataclass
class Operation:
destination: Destination
opcode_keyword: Token
opcode_suffix_keyword: Token
sources: list[Source]
@dataclass
class Instruction:
destination_op: DestinationOp
source0: Source
source1: Source
source2: Source
def identifier_to_number(token):
digits = set(b"0123456789")
assert token.type is TT.identifier
if not all(d in digits for d in token.lexeme):
raise ParserError("expected number", token)
return int(bytes(token.lexeme), 10)
def we_ord(c):
if c == ord("w"):
return 3
else:
return c - ord("x")
def parse_dest_write_enable(token):
we_chars = set(b"xyzw")
assert token.type is TT.identifier
we = bytes(token.lexeme).lower()
if not all(c in we_chars for c in we):
raise ParserError("expected destination write enable", token)
if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ParserError("misleading non-sequential write enable", token)
return set(we_ord(c) for c in we)
def parse_source_swizzle(token):
select_mapping = {
ord('x'): 0,
ord('y'): 1,
ord('z'): 2,
ord('w'): 3,
ord('0'): 4,
ord('1'): 5,
ord('h'): 6,
ord('_'): 7,
ord('u'): 7,
}
state = 0
ix = 0
swizzle_selects = [None] * 4
swizzle_modifiers = [None] * 4
lexeme = bytes(token.lexeme).lower()
while state < 4:
if ix >= len(token.lexeme):
raise ParserError("invalid source swizzle", token)
c = lexeme[ix]
if c == ord('-'):
if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
raise ParserError("invalid source swizzle modifier", token)
swizzle_modifiers[state] = True
elif c in select_mapping:
if swizzle_selects[state] is not None:
raise ParserError("invalid source swizzle select", token)
swizzle_selects[state] = select_mapping[c]
if swizzle_modifiers[state] is None:
swizzle_modifiers[state] = False
state += 1
else:
raise ParserError("invalid source swizzle", token)
ix += 1
if ix != len(lexeme):
raise ParserError("invalid source swizzle", token)
return SourceSwizzle(swizzle_selects, swizzle_modifiers)
operations: list[Operation]
class Parser(BaseParser):
def destination_type(self):
token = self.consume(TT.keyword, "expected destination type")
destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
if token.keyword not in destination_keywords:
raise ParserError("expected destination type", token)
return token.keyword
def destination(self):
type_keyword = self.consume(TT.keyword, "expected destination type keyword")
self.consume(TT.left_square, "expected left square")
offset_identifier = self.consume(TT.identifier, "expected destination offset identifier")
self.consume(TT.right_square, "expected right square")
self.consume(TT.dot, "expected dot")
write_enable_identifier = self.consume(TT.identifier, "expected destination write enable identifier")
def offset(self):
self.consume(TT.left_square, "expected offset")
identifier_token = self.consume(TT.identifier, "expected offset")
value = identifier_to_number(identifier_token)
self.consume(TT.right_square, "expected offset")
return value
return Destination(
type_keyword,
offset_identifier,
write_enable_identifier,
)
def opcode(self):
token = self.consume(TT.keyword, "expected opcode")
if type(token.keyword) != VE and type(token.keyword) != ME:
raise ParserError("expected opcode", token)
return token.keyword
def destination_op(self):
destination_type = self.destination_type()
offset_value = self.offset()
self.consume(TT.dot, "expected write enable")
write_enable_token = self.consume(TT.identifier, "expected write enable token")
write_enable = parse_dest_write_enable(write_enable_token)
self.consume(TT.equal, "expected equals")
opcode = self.opcode()
sat = False
if self.match(TT.dot):
def is_absolute(self):
result = self.match(TT.bar)
if result:
self.advance()
suffix = self.consume(TT.keyword, "expected saturation suffix")
if suffix.keyword is not KW.saturation:
raise ParserError("expected saturation suffix", token)
sat = True
macro = False
return DestinationOp(type=destination_type,
offset=offset_value,
write_enable=write_enable,
opcode=opcode,
sat=sat,
macro=macro)
def source_type(self):
token = self.consume(TT.keyword, "expected source type")
source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
if token.keyword not in source_keywords:
raise ParserError("expected source type", token)
return token.keyword
def source_swizzle(self):
token = self.consume(TT.identifier, "expected source swizzle")
return parse_source_swizzle(token)
return result
def source(self):
"input[0].-y-_-0-_"
source_type = self.source_type()
offset = self.offset()
self.consume(TT.dot, "expected source swizzle")
source_swizzle = self.source_swizzle()
return Source(source_type, offset, source_swizzle)
absolute = self.is_absolute()
type_keyword = self.consume(TT.keyword, "expected source type keyword")
self.consume(TT.left_square, "expected left square")
offset_identifier = self.consume(TT.identifier, "expected source offset identifier")
self.consume(TT.right_square, "expected right square")
self.consume(TT.dot, "expected dot")
swizzle_identifier = self.consume(TT.identifier, "expected source swizzle identifier")
if absolute:
self.consume(TT.bar, "expected vertical bar")
return Source(
absolute,
type_keyword,
offset_identifier,
swizzle_identifier,
)
def operation(self):
destination = self.destination()
self.consume(TT.equal, "expected equal")
opcode_keyword = self.consume(TT.keyword, "expected opcode keyword")
opcode_suffix_keyword = None
if self.match(TT.dot):
self.advance()
opcode_suffix_keyword = self.consume(TT.keyword, "expected opcode suffix keyword")
sources = []
while not (self.match(TT.comma) or self.match(TT.semicolon)):
sources.append(self.source())
return Operation(
destination,
opcode_keyword,
opcode_suffix_keyword,
sources,
)
def instruction(self):
first_token = self.peek()
destination_op = self.destination_op()
source0 = self.source()
if self.match(TT.semicolon) or self.match(TT.eof):
source1 = None
else:
source1 = self.source()
if self.match(TT.semicolon) or self.match(TT.eof):
source2 = None
else:
source2 = self.source()
last_token = self.peek(-1)
operations = []
while not self.match(TT.semicolon):
operations.append(self.operation())
if not self.match(TT.semicolon):
self.consume(TT.comma, "expected comma")
self.consume(TT.semicolon, "expected semicolon")
return (
Instruction(destination_op, source0, source1, source2),
(first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
return Instruction(
operations,
)
def instructions(self):
@@ -198,9 +111,18 @@ class Parser(BaseParser):
if __name__ == "__main__":
from assembler.lexer import Lexer
buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
lexer = Lexer(buf, find_keyword)
from assembler.vs.keywords import find_keyword
buf = b"""
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ ,
atemp[0].xz = ME_SIN input[0].-y-_-0-_ ;
"""
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
from pprint import pprint
try:
pprint(parser.instruction())
except ParserError as e:
print_error(None, buf, e)
raise
print(parser.peek())

View File

@@ -1,25 +1,274 @@
from assembler.vs.keywords import KW, ME, VE, macro_vector_operations
from dataclasses import dataclass
from enum import IntEnum
from itertools import pairwise
from typing import Union
class ValidatorError(Exception):
pass
from assembler.lexer import Token
from assembler.validator import ValidatorError
from assembler.vs.keywords import KW
from assembler.vs import opcodes
from collections import OrderedDict
class SourceType(IntEnum):
temporary = 0
input = 1
constant = 2
alt_temporary = 3
class SwizzleSelect(IntEnum):
x = 0
y = 1
z = 2
w = 3
zero = 4
one = 5
half = 6
unused = 7
@dataclass
class Source:
type: SourceType
absolute: bool
offset: int
swizzle_selects: tuple[SwizzleSelect, SwizzleSelect, SwizzleSelect, SwizzleSelect]
modifiers: tuple[bool, bool, bool, bool]
class DestinationType(IntEnum):
temporary = 0
a0 = 1
out = 2
out_repl_x = 3
alt_temporary = 4
input = 5
@dataclass
class Destination:
type: DestinationType
offset: int
write_enable: tuple[bool, bool, bool, bool]
@dataclass
class OpcodeDestination:
macro_inst: bool
reg_type: int
@dataclass
class Instruction:
destination: Destination
saturation: bool
sources: list[Source]
opcode: Union[opcodes.VE, opcodes.ME, opcodes.MVE]
def validate_opcode(opcode_keyword: Token):
if type(opcode_keyword.keyword) is opcodes.ME:
return opcode_keyword.keyword
elif type(opcode_keyword.keyword) is opcodes.VE:
return opcode_keyword.keyword
else:
raise ValidatorError("invalid opcode keyword", opcode_keyword)
def validate_identifier_number(token):
try:
return int(token.lexeme, 10)
except ValueError:
raise ValidatorError("invalid number", token)
def validate_destination_write_enable(write_enable_identifier):
we_chars_s = b"xyzw"
we_chars = {c: i for i, c in enumerate(we_chars_s)}
we = bytes(write_enable_identifier.lexeme).lower()
if not all(c in we_chars for c in we):
raise ValidatorError("invalid character in destination write enable", write_enable_identifier)
if not all(we_chars[a] < we_chars[b] for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ValidatorError("misleading non-sequential destination write enable", write_enable_identifier)
we = set(we)
return tuple(c in we for c in we_chars_s)
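
A behavioural sketch, not part of this commit, of the write-enable rule above; SimpleNamespace stands in for a lexer Token (only the lexeme field is read), and the assembler.vs.validator module path is assumed.

from types import SimpleNamespace
from assembler.validator import ValidatorError
from assembler.vs.validator import validate_destination_write_enable

# Sketch: "xz" enables lanes x and z; out-of-order or repeated lanes such
# as "zx" are rejected as misleading non-sequential masks.
assert validate_destination_write_enable(SimpleNamespace(lexeme=b"xz")) == (True, False, True, False)
try:
    validate_destination_write_enable(SimpleNamespace(lexeme=b"zx"))
    raise AssertionError("expected rejection")
except ValidatorError:
    pass                                  # non-sequential mask rejected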
def validate_destination(destination):
destination_type_keywords = OrderedDict([
(KW.temporary , (DestinationType.temporary , 128)), # 32
(KW.a0 , (DestinationType.a0 , 1 )), # ??
(KW.out , (DestinationType.out , 128)), # ?
(KW.out_repl_x , (DestinationType.out_repl_x , 128)), # ?
(KW.alt_temporary , (DestinationType.alt_temporary, 20 )), # 20
(KW.input , (DestinationType.input , 128)), # 32
])
if destination.type_keyword.keyword not in destination_type_keywords:
raise ValidatorError("invalid destination type keyword", destination.type_keyword)
type, max_offset = destination_type_keywords[destination.type_keyword.keyword]
offset = validate_identifier_number(destination.offset_identifier)
if offset >= max_offset:
raise ValidatorError("invalid offset value", source.offset_identifier)
write_enable = validate_destination_write_enable(destination.write_enable_identifier)
return Destination(
type,
offset,
write_enable
)
def parse_swizzle_lexeme(token):
swizzle_select_characters = OrderedDict([
(ord(b"x"), SwizzleSelect.x),
(ord(b"y"), SwizzleSelect.y),
(ord(b"z"), SwizzleSelect.z),
(ord(b"w"), SwizzleSelect.w),
(ord(b"0"), SwizzleSelect.zero),
(ord(b"1"), SwizzleSelect.one),
(ord(b"h"), SwizzleSelect.half),
(ord(b"_"), SwizzleSelect.unused),
])
lexeme = bytes(token.lexeme).lower()
swizzles = []
modifier = False
for c in lexeme:
if c == ord(b"-"):
modifier = True
else:
if c not in swizzle_select_characters:
raise ValueError(c)
swizzles.append((swizzle_select_characters[c], modifier))
modifier = False
return tuple(zip(*swizzles))
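
A behavioural sketch, not part of this commit, of the swizzle lexeme format: one select character per lane, optionally prefixed by '-' to set that lane's negate modifier. SimpleNamespace again stands in for a lexer Token, and the module path is assumed.

from types import SimpleNamespace
from assembler.vs.validator import parse_swizzle_lexeme, SwizzleSelect

# Sketch: "-y_0_" selects (y, unused, zero, unused) and negates only the
# first lane; each '-' applies to the select character that follows it.
selects, modifiers = parse_swizzle_lexeme(SimpleNamespace(lexeme=b"-y_0_"))
assert selects == (SwizzleSelect.y, SwizzleSelect.unused,
                   SwizzleSelect.zero, SwizzleSelect.unused)
assert modifiers == (True, False, False, False)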
def validate_source(source):
source_type_keywords = OrderedDict([
(KW.temporary , (SourceType.temporary , 128)), # 32
(KW.input , (SourceType.input , 128)), # 32
(KW.constant , (SourceType.constant , 256)), # 256
(KW.alt_temporary , (SourceType.alt_temporary, 20)), # 20
])
if source.type_keyword.keyword not in source_type_keywords:
raise ValidatorError("invalid source type keyword", source.type_keyword)
type, max_offset = source_type_keywords[source.type_keyword.keyword]
absolute = source.absolute
offset = validate_identifier_number(source.offset_identifier)
if offset >= max_offset:
raise ValidatorError("invalid offset value", source.offset_identifier)
try:
swizzle_selects, modifiers = parse_swizzle_lexeme(source.swizzle_identifier)
except ValueError:
raise ValidatorError("invalid source swizzle", source.swizzle_identifier)
return Source(
type,
absolute,
offset,
swizzle_selects,
modifiers
)
def addresses_by_type(sources, source_type):
return set(int(source.offset)
for source in sources
if source.type == source_type)
def source_ix_with_type_reversed(sources, source_type):
for i, source in reversed(list(enumerate(sources))):
if source.type == source_type:
return i
assert False, (sources, source_type)
def validate_source_address_counts(sources_ast, sources, opcode):
temporary_address_count = len(addresses_by_type(sources, SourceType.temporary))
assert temporary_address_count >= 0 and temporary_address_count <= 3
assert type(opcode) in {opcodes.VE, opcodes.ME}
if temporary_address_count == 3:
if opcode == opcodes.VE_MULTIPLY_ADD:
opcode = opcodes.MACRO_OP_2CLK_MADD
elif opcode == opcodes.VE_MULTIPLYX2_ADD:
opcode = opcodes.MACRO_OP_2CLK_M2X_ADD
else:
raise ValidatorError("too many temporary addresses in non-macro operation(s)",
sources_ast[-1].offset_identifier)
constant_count = len(addresses_by_type(sources, SourceType.constant))
if constant_count > 1:
source_ix = source_ix_with_type_reversed(sources, SourceType.constant)
raise ValidatorError(f"too many constant addresses in operation(s); expected 1, have {constant_count}",
sources_ast[source_ix].offset_identifier)
input_count = len(addresses_by_type(sources, SourceType.input))
if input_count > 1:
source_ix = source_ix_with_type_reversed(sources, SourceType.input)
raise ValidatorError(f"too many input addresses in operation(s); expected 1, have {input_count}",
sources_ast[source_ix].offset_identifier)
alt_temporary_count = len(addresses_by_type(sources, SourceType.alt_temporary))
if alt_temporary_count > 1:
source_ix = source_ix_with_type_reversed(sources, SourceType.alt_temporary)
raise ValidatorError(f"too many alt temporary addresses in operation(s); expected 1, have {alt_temporary_count}",
sources_ast[source_ix].offset_identifier)
return opcode
def validate_instruction_inner(operation, opcode):
destination = validate_destination(operation.destination)
saturation = False
if operation.opcode_suffix_keyword is not None:
if operation.opcode_suffix_keyword.keyword is not KW.saturation:
raise ValidatorError("invalid opcode saturation suffix", operation.opcode_suffix_keyword)
saturation = True
if len(operation.sources) > 3:
raise ValidatorError("too many sources in operation", operation.sources[0].type_keyword)
if len(operation.sources) != opcode.operand_count:
raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", operation.sources[0].type_keyword)
sources = []
for source in operation.sources:
sources.append(validate_source(source))
opcode = validate_source_address_counts(operation.sources, sources, opcode)
return Instruction(
destination,
saturation,
sources,
opcode
)
def validate_instruction(ins):
temp_addresses = len(set(
source.offset
for source in [ins.source0, ins.source1, ins.source2]
if (source is not None and source.type == KW.temporary)
))
if temp_addresses > 2:
if type(ins.destination_op.opcode) is not VE:
raise ValidatorError("too many addresses for non-VE instruction", ins)
if ins.destination_op.opcode.name not in {b"VE_MULTIPLYX2_ADD", b"VE_MULTIPLY_ADD"}:
raise ValidatorError("too many addresses for VE non-multiply-add instruction", ins)
assert ins.destination_op.macro == False, ins
ins.destination_op.macro = True
if ins.destination_op.opcode.name == b"VE_MULTIPLY_ADD":
ins.destination_op.opcode = macro_vector_operations[0]
elif ins.destination_op.opcode.name == b"VE_MULTIPLYX2_ADD":
ins.destination_op.opcode = macro_vector_operations[1]
if len(ins.operations) > 2:
raise ValidatorError("too many operations in instruction", ins.operations[0].destination.type_keyword)
opcodes = [validate_opcode(operation.opcode_keyword) for operation in ins.operations]
opcode_types = set(type(opcode) for opcode in opcodes)
if len(opcode_types) != len(opcodes):
opcode_type, = opcode_types
raise ValidatorError(f"invalid dual math operation: too many opcodes of type {opcode_type}", ins.operations[0].opcode_keyword)
if len(opcodes) == 2:
assert False, "not implemented"
#return validate_dual_math_instruction(ins, opcodes)
else:
assert False
return ins
assert len(opcodes) == 1
return validate_instruction_inner(ins.operations[0], opcodes[0])
if __name__ == "__main__":
from assembler.lexer import Lexer
from assembler.parser import ParserError
from assembler.vs.parser import Parser
from assembler.vs.keywords import find_keyword
from assembler.error import print_error
buf = b"""
out[0].xz = VE_MAD.SAT |temp[1].-y-_0-_| const[2].x_0_ const[2].x_0_ ;
"""
#atemp[0].xz = ME_SIN input[0].-y-_-0-_ ;
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
from pprint import pprint
try:
ins = parser.instruction()
pprint(validate_instruction(ins))
except ValidatorError as e:
print_error(None, buf, e)
raise
except ParserError as e:
print_error(None, buf, e)
raise
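
A final sketch, not part of this commit, of the macro-op promotion path in validate_source_address_counts, assuming the keyword table resolves temp/out the same way as in the parser example above.

from assembler.lexer import Lexer
from assembler.vs.parser import Parser
from assembler.vs.keywords import find_keyword
from assembler.vs.validator import validate_instruction
from assembler.vs import opcodes

# Sketch: three distinct temporary source addresses force VE_MULTIPLY_ADD
# into its two-clock macro form; any other opcode with three temporaries
# is rejected by the validator.
buf = b"out[0].xyzw = VE_MAD temp[0].xyzw temp[1].xyzw temp[2].xyzw ;"
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
ins = validate_instruction(Parser(list(lexer.lex_tokens())).instruction())
assert ins.opcode is opcodes.MACRO_OP_2CLK_MADD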