regs: incomplete pvs assembler

Zack Buhman 2025-10-16 00:33:00 -05:00
parent a5cb458c1f
commit e24b3ada5e
3 changed files with 454 additions and 0 deletions

regs/assembler/keywords.py (new file, 119 lines)

@@ -0,0 +1,119 @@
from dataclasses import dataclass
from typing import Optional
from enum import Enum, auto

# frozen so instances are hashable: the parser tests token keywords for
# membership in sets, which requires hashing
@dataclass(frozen=True)
class VE:
    name: bytes
    synonym: Optional[bytes]
    value: int

@dataclass(frozen=True)
class ME:
    name: bytes
    synonym: Optional[bytes]
    value: int
vector_operations = [
    # name                            synonym     value
    VE(b"VECTOR_NO_OP"              , b"VE_NOP" ,  0),
    VE(b"VE_DOT_PRODUCT"            , b"VE_DOT" ,  1),
    VE(b"VE_MULTIPLY"               , b"VE_MUL" ,  2),
    VE(b"VE_ADD"                    , None      ,  3),
    VE(b"VE_MULTIPLY_ADD"           , b"VE_MAD" ,  4),
    VE(b"VE_DISTANCE_VECTOR"        , None      ,  5),
    VE(b"VE_FRACTION"               , b"VE_FRC" ,  6),
    VE(b"VE_MAXIMUM"                , b"VE_MAX" ,  7),
    VE(b"VE_MINIMUM"                , b"VE_MIN" ,  8),
    VE(b"VE_SET_GREATER_THAN_EQUAL" , b"VE_SGE" ,  9),
    VE(b"VE_SET_LESS_THAN"          , b"VE_SLT" , 10),
    VE(b"VE_MULTIPLYX2_ADD"         , None      , 11),
    VE(b"VE_MULTIPLY_CLAMP"         , None      , 12),
    VE(b"VE_FLT2FIX_DX"             , None      , 13),
    VE(b"VE_FLT2FIX_DX_RND"         , None      , 14),
    VE(b"VE_PRED_SET_EQ_PUSH"       , None      , 15),
    VE(b"VE_PRED_SET_GT_PUSH"       , None      , 16),
    VE(b"VE_PRED_SET_GTE_PUSH"      , None      , 17),
    VE(b"VE_PRED_SET_NEQ_PUSH"      , None      , 18),
    VE(b"VE_COND_WRITE_EQ"          , None      , 19),
    VE(b"VE_COND_WRITE_GT"          , None      , 20),
    VE(b"VE_COND_WRITE_GTE"         , None      , 21),
    VE(b"VE_COND_WRITE_NEQ"         , None      , 22),
    VE(b"VE_COND_MUX_EQ"            , None      , 23),
    VE(b"VE_COND_MUX_GT"            , None      , 24),
    VE(b"VE_COND_MUX_GTE"           , None      , 25),
    VE(b"VE_SET_GREATER_THAN"       , b"VE_SGT" , 26),
    VE(b"VE_SET_EQUAL"              , b"VE_SEQ" , 27),
    VE(b"VE_SET_NOT_EQUAL"          , b"VE_SNE" , 28),
]
math_operations = [
    # name                           synonym      value
    ME(b"MATH_NO_OP"               , b"ME_NOP" ,  0),
    ME(b"ME_EXP_BASE2_DX"          , b"ME_EXP" ,  1),
    ME(b"ME_LOG_BASE2_DX"          , b"ME_LOG2",  2),
    ME(b"ME_EXP_BASEE_FF"          , b"ME_EXPE",  3),
    ME(b"ME_LIGHT_COEFF_DX"        , None      ,  4),
    ME(b"ME_POWER_FUNC_FF"         , b"ME_POW" ,  5),
    ME(b"ME_RECIP_DX"              , b"ME_RCP" ,  6),
    ME(b"ME_RECIP_FF"              , None      ,  7),
    ME(b"ME_RECIP_SQRT_DX"         , b"ME_RSQ" ,  8),
    ME(b"ME_RECIP_SQRT_FF"         , None      ,  9),
    ME(b"ME_MULTIPLY"              , b"ME_MUL" , 10),
    ME(b"ME_EXP_BASE2_FULL_DX"     , None      , 11),
    ME(b"ME_LOG_BASE2_FULL_DX"     , None      , 12),
    ME(b"ME_POWER_FUNC_FF_CLAMP_B" , None      , 13),
    ME(b"ME_POWER_FUNC_FF_CLAMP_B1", None      , 14),
    ME(b"ME_POWER_FUNC_FF_CLAMP_01", None      , 15),
    ME(b"ME_SIN"                   , None      , 16),
    ME(b"ME_COS"                   , None      , 17),
    ME(b"ME_LOG_BASE2_IEEE"        , None      , 18),
    ME(b"ME_RECIP_IEEE"            , None      , 19),
    ME(b"ME_RECIP_SQRT_IEEE"       , None      , 20),
    ME(b"ME_PRED_SET_EQ"           , None      , 21),
    ME(b"ME_PRED_SET_GT"           , None      , 22),
    ME(b"ME_PRED_SET_GTE"          , None      , 23),
    ME(b"ME_PRED_SET_NEQ"          , None      , 24),
    ME(b"ME_PRED_SET_CLR"          , None      , 25),
    ME(b"ME_PRED_SET_INV"          , None      , 26),
    ME(b"ME_PRED_SET_POP"          , None      , 27),
    ME(b"ME_PRED_SET_RESTORE"      , None      , 28),
]
class KW(Enum):
    temporary = auto()
    a0 = auto()
    out = auto()
    out_repl_x = auto()
    alt_temporary = auto()
    input = auto()
    absolute = auto()
    relative_a0 = auto()
    relative_i0 = auto()
    constant = auto()

keywords = [
    (KW.temporary     , b"temporary"     , b"temp"),
    (KW.a0            , b"a0"            , None),
    (KW.out           , b"out"           , None),
    (KW.out_repl_x    , b"out_repl_x"    , None),
    (KW.alt_temporary , b"alt_temporary" , b"alt_temp"),
    (KW.input         , b"input"         , None),
    (KW.absolute      , b"absolute"      , None),
    (KW.relative_a0   , b"relative_a0"   , None),
    (KW.relative_i0   , b"relative_i0"   , None),
    (KW.constant      , b"constant"      , b"const"),
]
def find_keyword(b: memoryview):
    upper = bytes(b).upper()
    lower = bytes(b).lower()
    for vector_op in vector_operations:
        if upper in (vector_op.name, vector_op.synonym):
            return vector_op
    for math_op in math_operations:
        if upper in (math_op.name, math_op.synonym):
            return math_op
    for keyword, name, synonym in keywords:
        if lower in (name, synonym):
            return keyword
    return None
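
A quick check of the lookup behavior (a usage sketch for illustration, not part of the commit): find_keyword is case-insensitive and resolves synonyms to the same table entry.

# usage sketch (illustration only, not part of the commit)
import keywords

assert keywords.find_keyword(b"ve_mad").value == 4           # synonym, case-insensitive
assert keywords.find_keyword(b"VE_MULTIPLY_ADD").value == 4  # canonical name
assert keywords.find_keyword(b"const") is keywords.KW.constant
assert keywords.find_keyword(b"bogus") is None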

regs/assembler/lexer.py (new file, 124 lines)

@@ -0,0 +1,124 @@
from dataclasses import dataclass
from enum import Enum, auto
from itertools import chain
from typing import Optional, Union

import keywords

DEBUG = True
class TT(Enum):
    eof = auto()
    eol = auto()
    left_square = auto()
    right_square = auto()
    left_paren = auto()
    right_paren = auto()
    equal = auto()
    dot = auto()
    identifier = auto()
    keyword = auto()
@dataclass
class Token:
    line: int
    col: int
    type: TT
    lexeme: memoryview
    keyword: Optional[Union[keywords.VE, keywords.ME, keywords.KW]] = None
identifier_characters = set(chain(
    b'abcdefghijklmnopqrstuvwxyz'
    b'ABCDEFGHIJKLMNOPQRSTUVWXYZ',
    b'0123456789',
    b'_-'
))

class LexerError(Exception):
    pass
class Lexer:
    def __init__(self, buf: memoryview):
        self.start_ix = 0
        self.current_ix = 0
        self.buf = memoryview(buf)
        self.line = 1
        self.col = 0

    def at_end_p(self):
        return self.current_ix >= len(self.buf)

    def lexeme(self):
        if DEBUG:
            return bytes(self.buf[self.start_ix:self.current_ix])
        else:
            return memoryview(self.buf[self.start_ix:self.current_ix])

    def advance(self):
        c = self.buf[self.current_ix]
        self.col += 1
        self.current_ix += 1
        return c

    def peek(self):
        return self.buf[self.current_ix]

    def pos(self):
        return self.line, self.col - (self.current_ix - self.start_ix)

    def identifier(self):
        while not self.at_end_p() and self.peek() in identifier_characters:
            self.advance()
        keyword = keywords.find_keyword(self.lexeme())
        if keyword is not None:
            return Token(*self.pos(), TT.keyword, self.lexeme(), keyword)
        else:
            return Token(*self.pos(), TT.identifier, self.lexeme(), None)
    def lex_token(self):
        while True:
            self.start_ix = self.current_ix
            if self.at_end_p():
                return Token(*self.pos(), TT.eof, self.lexeme())
            c = self.advance()
            if c == ord('('):
                return Token(*self.pos(), TT.left_paren, self.lexeme())
            elif c == ord(')'):
                return Token(*self.pos(), TT.right_paren, self.lexeme())
            elif c == ord('['):
                return Token(*self.pos(), TT.left_square, self.lexeme())
            elif c == ord(']'):
                return Token(*self.pos(), TT.right_square, self.lexeme())
            elif c == ord('='):
                return Token(*self.pos(), TT.equal, self.lexeme())
            elif c == ord('.'):
                return Token(*self.pos(), TT.dot, self.lexeme())
            elif c == ord(';'):
                # ';' starts a comment; skip to end of line
                while not self.at_end_p() and self.peek() != ord('\n'):
                    self.advance()
            elif c == ord(' ') or c == ord('\r') or c == ord('\t'):
                pass
            elif c == ord('\n'):
                pos = self.pos()
                self.line += 1
                self.col = 0
                return Token(*pos, TT.eol, self.lexeme())
            elif c in identifier_characters:
                return self.identifier()
            else:
                raise LexerError(f"unexpected character at line:{self.line} col:{self.col}")
    def lex_tokens(self):
        while True:
            token = self.lex_token()
            yield token
            if token.type is TT.eof:
                break

if __name__ == "__main__":
    test = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    lexer = Lexer(test)
    for token in lexer.lex_tokens():
        print(token)
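
One detail worth noting: because '-' and '_' are members of identifier_characters, an entire source swizzle such as -y-_-0-_ reaches the parser as a single TT.identifier token. A small sketch of the resulting token stream (illustration only, not part of the commit):

# sketch (illustration only): token types for one source operand
from lexer import Lexer, TT

tokens = list(Lexer(b"input[0].-y-_-0-_").lex_tokens())
assert [t.type for t in tokens] == [
    TT.keyword,       # input
    TT.left_square,   # [
    TT.identifier,    # 0
    TT.right_square,  # ]
    TT.dot,           # .
    TT.identifier,    # -y-_-0-_  (one token, not punctuation)
    TT.eof,
]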

regs/assembler/parser.py (new file, 211 lines)

@@ -0,0 +1,211 @@
import lexer
from lexer import TT
from keywords import KW, ME, VE
from itertools import pairwise
from dataclasses import dataclass
from typing import Union

"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
@dataclass
class DestinationOp:
    type: KW
    offset: int
    write_enable: set[int]
    opcode: Union[VE, ME]

@dataclass
class SourceSwizzle:
    select: tuple[int, int, int, int]
    modifier: tuple[bool, bool, bool, bool]

@dataclass
class Source:
    type: KW
    offset: int
    swizzle: SourceSwizzle

@dataclass
class Instruction:
    destination_op: DestinationOp
    source0: Source
    source1: Source
    source2: Source

class ParseError(Exception):
    pass
def identifier_to_number(token):
    digits = set(b"0123456789")
    assert token.type is TT.identifier
    if not all(d in digits for d in token.lexeme):
        raise ParseError("expected number", token)
    return int(bytes(token.lexeme), 10)
def parse_dest_write_enable(token):
    # map x/y/z/w to component indices 0..3; ASCII order cannot be used
    # directly because ord('w') < ord('x')
    component_mapping = {ord('x'): 0, ord('y'): 1, ord('z'): 2, ord('w'): 3}
    assert token.type is TT.identifier
    we = bytes(token.lexeme).lower()
    if not all(c in component_mapping for c in we):
        raise ParseError("expected destination write enable", token)
    indices = [component_mapping[c] for c in we]
    # strictly increasing also implies no duplicate components
    if not all(a < b for a, b in pairwise(indices)):
        raise ParseError("misleading non-sequential write enable", token)
    return set(indices)
def parse_source_swizzle(token):
    select_mapping = {
        ord('x'): 0,
        ord('y'): 1,
        ord('z'): 2,
        ord('w'): 3,
        ord('0'): 4,
        ord('1'): 5,
        ord('h'): 6,
        ord('_'): 7,
        ord('u'): 7,
    }
    state = 0
    ix = 0
    swizzle_selects = [None] * 4
    swizzle_modifiers = [None] * 4
    lexeme = bytes(token.lexeme).lower()
    while state < 4:
        if ix >= len(lexeme):
            raise ParseError("invalid source swizzle", token)
        c = lexeme[ix]
        if c == ord('-'):
            if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
                raise ParseError("invalid source swizzle modifier", token)
            swizzle_modifiers[state] = True
        elif c in select_mapping:
            if swizzle_selects[state] is not None:
                raise ParseError("invalid source swizzle select", token)
            swizzle_selects[state] = select_mapping[c]
            if swizzle_modifiers[state] is None:
                swizzle_modifiers[state] = False
            state += 1
        else:
            raise ParseError("invalid source swizzle", token)
        ix += 1
    if ix != len(lexeme):
        raise ParseError("invalid source swizzle", token)
    return SourceSwizzle(tuple(swizzle_selects), tuple(swizzle_modifiers))
class Parser:
    def __init__(self, tokens: list[lexer.Token]):
        self.current_ix = 0
        self.tokens = tokens

    def peek(self):
        return self.tokens[self.current_ix]

    def at_end_p(self):
        return self.peek().type == TT.eof

    def advance(self):
        token = self.peek()
        self.current_ix += 1
        return token

    def match(self, token_type):
        # consume the next token only if it has the expected type
        if self.peek().type == token_type:
            self.advance()
            return True
        return False

    def consume(self, token_type, message):
        token = self.advance()
        if token.type != token_type:
            raise ParseError(message, token)
        return token

    def consume_either(self, token_type1, token_type2, message):
        token = self.advance()
        if token.type != token_type1 and token.type != token_type2:
            raise ParseError(message, token)
        return token

    """
    def consume_keyword(self, keyword, message):
        token = self.consume(TT.keyword, message)
        assert token.keyword is not None
        if token.keyword != keyword:
            raise ParseError(message, token)
    """
    def destination_type(self):
        token = self.consume(TT.keyword, "expected destination type")
        destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
        if token.keyword not in destination_keywords:
            raise ParseError("expected destination type", token)
        return token.keyword

    def offset(self):
        self.consume(TT.left_square, "expected offset")
        identifier_token = self.consume(TT.identifier, "expected offset")
        value = identifier_to_number(identifier_token)
        self.consume(TT.right_square, "expected offset")
        return value

    def opcode(self):
        token = self.consume(TT.keyword, "expected opcode")
        if not isinstance(token.keyword, (VE, ME)):
            raise ParseError("expected opcode", token)
        return token.keyword

    def destination_op(self):
        destination_type = self.destination_type()
        offset_value = self.offset()
        self.consume(TT.dot, "expected write enable")
        write_enable_token = self.consume(TT.identifier, "expected write enable token")
        write_enable = parse_dest_write_enable(write_enable_token)
        self.consume(TT.equal, "expected equals")
        opcode = self.opcode()
        return DestinationOp(destination_type, offset_value, write_enable, opcode)

    def source_type(self):
        token = self.consume(TT.keyword, "expected source type")
        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
        if token.keyword not in source_keywords:
            raise ParseError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
        token = self.consume(TT.identifier, "expected source swizzle")
        return parse_source_swizzle(token)

    def source(self):
        "input[0].-y-_-0-_"
        source_type = self.source_type()
        offset = self.offset()
        self.consume(TT.dot, "expected source swizzle")
        source_swizzle = self.source_swizzle()
        return Source(source_type, offset, source_swizzle)

    def instruction(self):
        destination_op = self.destination_op()
        source0 = self.source()
        source1 = self.source()
        source2 = self.source()
        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
        return Instruction(destination_op, source0, source1, source2)
if __name__ == "__main__":
    from lexer import Lexer
    from pprint import pprint

    buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    lexer = Lexer(buf)
    tokens = list(lexer.lex_tokens())
    parser = Parser(tokens)
    pprint(parser.instruction())
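
The parser currently handles one instruction at a time; a whole-program driver might look like the following sketch. parse_program is a hypothetical helper, not part of this commit, and it assumes one instruction per line with no blank lines, since the lexer emits an eol token for every newline.

# hypothetical helper (not in this commit): parse a whole program
from lexer import Lexer

def parse_program(buf: bytes):
    tokens = list(Lexer(buf).lex_tokens())
    parser = Parser(tokens)
    instructions = []
    # instruction() consumes its trailing eol; the last one may consume eof
    while parser.current_ix < len(tokens) and not parser.at_end_p():
        instructions.append(parser.instruction())
    return instructions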