assembler: add initial fragment shader parser

This commit is contained in:
Zack Buhman 2025-10-20 12:54:41 -05:00
parent 59b6b2a0d2
commit adca6a1c66
11 changed files with 549 additions and 257 deletions

View File

@ -1,3 +1,3 @@
; CONST[0] = { 1.3333 , _, _, _ } # CONST[0] = { 1.3333 , _, _, _ }
out[1].xy = VE_MUL input[0].xy__ const[0].x1__ out[1].xy = VE_MUL input[0].xy__ const[0].x1__
out[0].xyzw = VE_ADD input[0].xyz1 input[0].0000 out[0].xyzw = VE_ADD input[0].xyz1 input[0].0000

28
regs/assembler/error.py Normal file
View File

@ -0,0 +1,28 @@
import sys
def print_error(filename, buf, e):
    """Write a highlighted diagnostic for a lexer/parser error to stderr.

    Parameters:
        filename: name shown in the ``File:`` header (printed as-is).
        buf: the source buffer (bytes) the token was lexed from.
        e: an exception whose ``args`` are ``(message, token)``; the token
           must provide ``line`` (1-based), ``col`` and ``lexeme``.

    The output is the location header, the offending source line with the
    lexeme coloured red, a ``^`` pointer line under the lexeme, and finally
    the exception class name with the message.
    """
    assert len(e.args) == 2, e
    message, token = e.args

    lines = buf.splitlines()
    line_ix = token.line - 1
    # A token reported past the last line has no source text to echo; show
    # an empty line rather than failing inside the error reporter.
    line = lines[line_ix] if 0 <= line_ix < len(lines) else b''
    # errors='replace' keeps the reporter alive on undecodable input.
    text = line.decode('utf-8', errors='replace')

    error_name = type(e).__name__
    start = token.col
    end = start + len(token.lexeme)
    col_indent = ' ' * start
    col_pointer = '^' * len(token.lexeme)
    RED = "\033[0;31m"
    DEFAULT = "\033[0;0m"

    print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr)
    # Colour exactly [start, end): the reset goes *before* the first
    # character after the lexeme, so nothing beyond the lexeme is red.
    sys.stderr.write(f' {text[:start]}{RED}{text[start:end]}{DEFAULT}{text[end:]}\n')
    print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr)
    print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)

View File

@ -0,0 +1,73 @@
from enum import Enum, auto
class KW(Enum):
    """Keywords recognised in fragment shader assembly.

    Members are numbered with ``auto()`` in declaration order, so the
    order below determines each member's value.
    """

    # --- opcodes ---
    CMP = auto()
    CND = auto()
    COS = auto()
    D2A = auto()
    DP = auto()
    DP3 = auto()
    DP4 = auto()
    EX2 = auto()
    FRC = auto()
    LN2 = auto()
    MAD = auto()
    MAX = auto()
    MDH = auto()
    MDV = auto()
    MIN = auto()
    RCP = auto()
    RSQ = auto()
    SIN = auto()
    SOP = auto()

    # --- sources / destinations ---
    OUT = auto()
    TEMP = auto()
    FLOAT = auto()
    CONST = auto()
    SRC0 = auto()
    SRC1 = auto()
    SRC2 = auto()
    SRCP = auto()

    # --- modifiers ---
    TEX_SEM_WAIT = auto()
# Upper-case keyword spelling (as bytes) -> KW member.  Every member is
# spelled exactly like its enum name, so the table is derived from the enum
# itself; iteration follows declaration order.
_string_to_keyword = {member.name.encode("ascii"): member for member in KW}
def find_keyword(s):
    """Return the ``KW`` member spelled by *s*, or ``None`` if there is none.

    The comparison is case-insensitive: *s* is upper-cased before the
    lookup in ``_string_to_keyword``.

    Parameters:
        s: the candidate lexeme as ``bytes``.  A ``memoryview`` is accepted
           as well and copied to ``bytes`` first, because ``memoryview``
           has no ``upper()`` method.
    """
    if isinstance(s, memoryview):
        s = bytes(s)
    # Single upper-case + single lookup; .get() yields None on a miss.
    return _string_to_keyword.get(s.upper())

191
regs/assembler/fs/parser.py Normal file
View File

@ -0,0 +1,191 @@
from enum import IntEnum
from typing import Literal, Union
from dataclasses import dataclass
from assembler.parser import BaseParser, ParserError
from assembler.lexer import TT, Token
from assembler.fs.keywords import KW, find_keyword
from assembler.error import print_error
class Mod(IntEnum):
    """Modifier applied to a source selection (see ``Parser.swizzle_sel``)."""

    nop = 0  # neither negated nor absolute
    neg = 1  # negated
    abs = 2  # absolute value
    nab = 3  # negated and absolute
@dataclass
class LetExpression:
    """One source binding, e.g. ``src0.a = float(0)`` or ``src0.rgb = temp[0]``."""

    # keyword on the left of the dot
    src_keyword: Token
    # identifier following the dot on the left-hand side
    src_swizzle_identifier: Token
    # keyword on the right of '='
    addr_keyword: Token
    # identifier enclosed in '(...)' for FLOAT, otherwise in '[...]'
    addr_value_identifier: Token
@dataclass
class DestAddrSwizzle:
    """One destination of an operation, e.g. the ``out[0].none =`` prefix."""

    # keyword naming the destination
    dest_keyword: Token
    # identifier inside the square brackets
    addr_identifier: Token
    # identifier following the dot
    swizzle_identifier: Token
@dataclass
class SwizzleSel:
    """One operand of an operation, e.g. ``src0.rg0``, plus its modifier."""

    # keyword before the dot
    sel_keyword: Token
    # identifier after the dot
    swizzle_identifier: Token
    # negate / absolute-value combination wrapped around the operand
    mod: Mod
@dataclass
class Operation:
    """A single operation: destinations, an opcode keyword and its operands."""

    # every ``dest[addr].swizzle =`` prefix preceding the opcode
    dest_addr_swizzles: list[DestAddrSwizzle]
    # the opcode keyword token
    opcode_keyword: Token
    # operands following the opcode
    swizzle_sels: list[SwizzleSel]
@dataclass
class Instruction:
    """A whole instruction: bindings before ':' and operations before ';'."""

    # source bindings parsed up to the colon
    let_expressions: list[LetExpression]
    # operations parsed between the colon and the semicolon
    operations: list[Operation]
class Parser(BaseParser):
    """Parser for fragment shader assembly instructions.

    Shape of the input, as implemented by the methods below::

        instruction := let {"," let} ":" operation {"," operation} ";"
        let         := KEYWORD "." IDENT "=" FLOAT "(" IDENT ")"
                     | KEYWORD "." IDENT "=" KEYWORD "[" IDENT "]"
        operation   := {dest} OPCODE {operand}
        dest        := KEYWORD "[" IDENT "]" "." IDENT "="
        operand     := ["-" "("] ["|"] KEYWORD "." IDENT ["|"] [")"]

    Token access (``consume``, ``match``, ``peek``, ``advance``) comes from
    ``BaseParser``; ``match`` is used here as a non-consuming type test.
    Every method raises ``ParserError`` (through ``consume``) when the next
    token is not the expected one.
    """

    # Keywords that end the destination list and name the operation.
    _OPCODE_KEYWORDS = frozenset({
        KW.CMP, KW.CND, KW.COS, KW.D2A,
        KW.DP, KW.DP3, KW.DP4, KW.EX2,
        KW.FRC, KW.LN2, KW.MAD, KW.MAX,
        KW.MDH, KW.MDV, KW.MIN, KW.RCP,
        KW.RSQ, KW.SIN, KW.SOP,
    })

    # (negated, absolute) -> modifier
    _MOD_TABLE = {
        (False, False): Mod.nop,
        (False, True): Mod.abs,
        (True, False): Mod.neg,
        (True, True): Mod.nab,
    }

    def let_expression(self):
        """Parse one binding such as ``src0.a = float(0)`` or ``src0.rgb = temp[0]``."""
        src_keyword = self.consume(TT.keyword, "expected src keyword")
        self.consume(TT.dot, "expected dot")
        src_swizzle_identifier = self.consume(TT.identifier, "expected src swizzle identifier")
        self.consume(TT.equal, "expected equal")
        addr_keyword = self.consume(TT.keyword, "expected addr keyword")

        # FLOAT takes its value in parentheses, everything else in brackets.
        is_float = addr_keyword.keyword is KW.FLOAT
        if is_float:
            self.consume(TT.left_paren, "expected left paren")
        else:
            self.consume(TT.left_square, "expected left square")
        addr_value_identifier = self.consume(TT.identifier, "expected address identifier")
        if is_float:
            self.consume(TT.right_paren, "expected right paren")
        else:
            self.consume(TT.right_square, "expected right square")

        return LetExpression(
            src_keyword,
            src_swizzle_identifier,
            addr_keyword,
            addr_value_identifier,
        )

    def dest_addr_swizzle(self):
        """Parse one ``dest[addr].swizzle =`` prefix and return it."""
        dest_keyword = self.consume(TT.keyword, "expected dest keyword")
        self.consume(TT.left_square, "expected left square")
        addr_identifier = self.consume(TT.identifier, "expected dest addr identifier")
        self.consume(TT.right_square, "expected right square")
        self.consume(TT.dot, "expected dot")
        swizzle_identifier = self.consume(TT.identifier, "expected dest swizzle identifier")
        self.consume(TT.equal, "expected equal")
        # The parsed pieces must be returned: operation() stores the result.
        return DestAddrSwizzle(
            dest_keyword,
            addr_identifier,
            swizzle_identifier,
        )

    def is_opcode(self):
        """Return True when the current token is an opcode keyword (not consumed)."""
        if self.match(TT.keyword):
            return self.peek().keyword in self._OPCODE_KEYWORDS
        return False

    def is_neg(self):
        """Consume a leading ``-`` identifier if present; return whether it was."""
        result = bool(self.match(TT.identifier) and self.peek().lexeme == b'-')
        if result:
            self.advance()
        return result

    def is_abs(self):
        """Consume a leading ``|`` if present; return whether it was."""
        result = bool(self.match(TT.bar))
        if result:
            self.advance()
        return result

    def swizzle_sel(self):
        """Parse one operand: ``x.s``, ``|x.s|``, ``-(x.s)`` or ``-(|x.s|)``."""
        neg = self.is_neg()
        if neg:
            self.consume(TT.left_paren, "expected left paren")
        # The opening bar is looked for *inside* the parenthesis so that the
        # opening order "-", "(", "|" mirrors the closing order "|", ")".
        has_abs = self.is_abs()

        sel_keyword = self.consume(TT.keyword, "expected sel keyword")
        self.consume(TT.dot, "expected dot")
        swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier")

        if has_abs:
            self.consume(TT.bar, "expected bar")
        if neg:
            self.consume(TT.right_paren, "expected right paren")

        return SwizzleSel(
            sel_keyword,
            swizzle_identifier,
            self._MOD_TABLE[(neg, has_abs)],
        )

    def operation(self):
        """Parse destinations, the opcode, then operands up to ``,`` or ``;``."""
        dest_addr_swizzles = []
        while not self.is_opcode():
            dest_addr_swizzles.append(self.dest_addr_swizzle())

        opcode_keyword = self.consume(TT.keyword, "expected opcode keyword")

        swizzle_sels = []
        # The terminating comma/semicolon is left for instruction() to consume.
        while not (self.match(TT.comma) or self.match(TT.semicolon)):
            swizzle_sels.append(self.swizzle_sel())

        return Operation(
            dest_addr_swizzles,
            opcode_keyword,
            swizzle_sels
        )

    def instruction(self):
        """Parse bindings up to ``:`` and operations up to ``;``."""
        let_expressions = []
        while not self.match(TT.colon):
            let_expressions.append(self.let_expression())
            if not self.match(TT.colon):
                self.consume(TT.comma, "expected comma")
        self.consume(TT.colon, "expected colon")

        operations = []
        while not self.match(TT.semicolon):
            operations.append(self.operation())
            if not self.match(TT.semicolon):
                self.consume(TT.comma, "expected comma")
        self.consume(TT.semicolon, "expected semicolon")

        return Instruction(
            let_expressions,
            operations,
        )
if __name__ == "__main__":
    # Manual smoke test: lex and parse one sample instruction, pretty-print
    # the result and then show the token the parser stopped on.
    from pprint import pprint
    from assembler.lexer import Lexer

    buf = b"""
src0.a = float(0), src0.rgb = temp[0] :
out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ;
"""
    token_stream = Lexer(buf, find_keyword, emit_newlines=False).lex_tokens()
    parser = Parser(list(token_stream))
    try:
        pprint(parser.instruction())
    except ParserError as e:
        # Show the highlighted diagnostic, then let the traceback through.
        print_error(None, buf, e)
        raise
    print(parser.peek())

View File

@ -1,9 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum, auto from enum import Enum, auto
from itertools import chain from itertools import chain
from typing import Union from typing import Union, Any
from assembler import keywords
DEBUG = True DEBUG = True
@ -18,6 +16,10 @@ class TT(Enum):
dot = auto() dot = auto()
identifier = auto() identifier = auto()
keyword = auto() keyword = auto()
colon = auto()
semicolon = auto()
bar = auto()
comma = auto()
@dataclass @dataclass
class Token: class Token:
@ -26,7 +28,7 @@ class Token:
col: int col: int
type: TT type: TT
lexeme: memoryview lexeme: memoryview
keyword: Union[keywords.VE, keywords.ME, keywords.KW] = None keyword: Any = None
identifier_characters = set(chain( identifier_characters = set(chain(
b'abcdefghijklmnopqrstuvwxyz' b'abcdefghijklmnopqrstuvwxyz'
@ -39,12 +41,14 @@ class LexerError(Exception):
pass pass
class Lexer: class Lexer:
def __init__(self, buf: memoryview): def __init__(self, buf: memoryview, find_keyword, emit_newlines=True):
self.start_ix = 0 self.start_ix = 0
self.current_ix = 0 self.current_ix = 0
self.buf = memoryview(buf) self.buf = memoryview(buf)
self.line = 1 self.line = 1
self.col = 0 self.col = 0
self.find_keyword = find_keyword
self.emit_newlines = emit_newlines
def at_end_p(self): def at_end_p(self):
return self.current_ix >= len(self.buf) return self.current_ix >= len(self.buf)
@ -70,7 +74,7 @@ class Lexer:
def identifier(self): def identifier(self):
while not self.at_end_p() and self.peek() in identifier_characters: while not self.at_end_p() and self.peek() in identifier_characters:
self.advance() self.advance()
keyword = keywords.find_keyword(self.lexeme()) keyword = self.find_keyword(self.lexeme())
if keyword is not None: if keyword is not None:
return Token(*self.pos(), TT.keyword, self.lexeme(), keyword) return Token(*self.pos(), TT.keyword, self.lexeme(), keyword)
else: else:
@ -96,7 +100,15 @@ class Lexer:
return Token(*self.pos(), TT.equal, self.lexeme()) return Token(*self.pos(), TT.equal, self.lexeme())
elif c == ord('.'): elif c == ord('.'):
return Token(*self.pos(), TT.dot, self.lexeme()) return Token(*self.pos(), TT.dot, self.lexeme())
elif c == ord('|'):
return Token(*self.pos(), TT.bar, self.lexeme())
elif c == ord(':'):
return Token(*self.pos(), TT.colon, self.lexeme())
elif c == ord(';'): elif c == ord(';'):
return Token(*self.pos(), TT.semicolon, self.lexeme())
elif c == ord(','):
return Token(*self.pos(), TT.comma, self.lexeme())
elif c == ord('#'):
while not self.at_end_p() and self.peek() != ord('\n'): while not self.at_end_p() and self.peek() != ord('\n'):
self.advance() self.advance()
elif c == ord(' ') or c == ord('\r') or c == ord('\t'): elif c == ord(' ') or c == ord('\r') or c == ord('\t'):
@ -105,11 +117,15 @@ class Lexer:
pos = self.pos() pos = self.pos()
self.line += 1 self.line += 1
self.col = 0 self.col = 0
if self.emit_newlines:
return Token(*pos, TT.eol, self.lexeme()) return Token(*pos, TT.eol, self.lexeme())
else:
continue
elif c in identifier_characters: elif c in identifier_characters:
return self.identifier() return self.identifier()
else: else:
raise LexerError(f"unexpected character at line:{self.line} col:{self.col}") token = Token(*self.pos(), None, self.lexeme())
raise LexerError(f"unexpected character at line:{self.line} col:{self.col}", token)
def lex_tokens(self): def lex_tokens(self):
while True: while True:
@ -119,7 +135,16 @@ class Lexer:
break break
if __name__ == "__main__": if __name__ == "__main__":
def vs_test():
from assembler.vskeywords import find_keyword
test = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_" test = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
lexer = Lexer(test) lexer = Lexer(test, find_keyword)
for token in lexer.lex_tokens(): for token in lexer.lex_tokens():
print(token) print(token)
def fs_test():
from assembler.fs.keywords import find_keyword
test = b"src0.rgb = temp[0] : temp[0].a = OP_RSQ |src0.r| ;"
lexer = Lexer(test, find_keyword)
for token in lexer.lex_tokens():
print(token)
fs_test()

View File

@ -1,125 +1,15 @@
from itertools import pairwise from typing import Any
from dataclasses import dataclass
from typing import Union
from assembler import lexer
from assembler.lexer import TT, Token
from assembler.keywords import KW, ME, VE
"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
@dataclass
class DestinationOp:
type: KW
offset: int
write_enable: set[int]
opcode: Union[VE, ME]
sat: bool
macro: bool
@dataclass
class SourceSwizzle:
select: tuple[int, int, int, int]
modifier: tuple[bool, bool, bool, bool]
@dataclass
class Source:
type: KW
offset: int
swizzle: SourceSwizzle
@dataclass
class Instruction:
destination_op: DestinationOp
source0: Source
source1: Source
source2: Source
class ParserError(Exception): class ParserError(Exception):
pass pass
def identifier_to_number(token): class BaseParser:
digits = set(b"0123456789") def __init__(self, tokens: list[Any]):
assert token.type is TT.identifier
if not all(d in digits for d in token.lexeme):
raise ParserError("expected number", token)
return int(bytes(token.lexeme), 10)
def we_ord(c):
if c == ord("w"):
return 3
else:
return c - ord("x")
def parse_dest_write_enable(token):
we_chars = set(b"xyzw")
assert token.type is TT.identifier
we = bytes(token.lexeme).lower()
if not all(c in we_chars for c in we):
raise ParserError("expected destination write enable", token)
if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ParserError("misleading non-sequential write enable", token)
return set(we_ord(c) for c in we)
def parse_source_swizzle(token):
select_mapping = {
ord('x'): 0,
ord('y'): 1,
ord('z'): 2,
ord('w'): 3,
ord('0'): 4,
ord('1'): 5,
ord('h'): 6,
ord('_'): 7,
ord('u'): 7,
}
state = 0
ix = 0
swizzle_selects = [None] * 4
swizzle_modifiers = [None] * 4
lexeme = bytes(token.lexeme).lower()
while state < 4:
if ix >= len(token.lexeme):
raise ParserError("invalid source swizzle", token)
c = lexeme[ix]
if c == ord('-'):
if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
raise ParserError("invalid source swizzle modifier", token)
swizzle_modifiers[state] = True
elif c in select_mapping:
if swizzle_selects[state] is not None:
raise ParserError("invalid source swizzle select", token)
swizzle_selects[state] = select_mapping[c]
if swizzle_modifiers[state] is None:
swizzle_modifiers[state] = False
state += 1
else:
raise ParserError("invalid source swizzle", token)
ix += 1
if ix != len(lexeme):
raise ParserError("invalid source swizzle", token)
return SourceSwizzle(swizzle_selects, swizzle_modifiers)
class Parser:
def __init__(self, tokens: list[lexer.Token]):
self.current_ix = 0 self.current_ix = 0
self.tokens = tokens self.tokens = tokens
def peek(self, offset=0): def peek(self, offset=0):
token = self.tokens[self.current_ix + offset] token = self.tokens[self.current_ix + offset]
#print(token)
return token return token
def at_end_p(self): def at_end_p(self):
@ -145,100 +35,3 @@ class Parser:
if token.type != token_type1 and token.type != token_type2: if token.type != token_type1 and token.type != token_type2:
raise ParserError(message, token) raise ParserError(message, token)
return token return token
def destination_type(self):
token = self.consume(TT.keyword, "expected destination type")
destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
if token.keyword not in destination_keywords:
raise ParserError("expected destination type", token)
return token.keyword
def offset(self):
self.consume(TT.left_square, "expected offset")
identifier_token = self.consume(TT.identifier, "expected offset")
value = identifier_to_number(identifier_token)
self.consume(TT.right_square, "expected offset")
return value
def opcode(self):
token = self.consume(TT.keyword, "expected opcode")
if type(token.keyword) != VE and type(token.keyword) != ME:
raise ParserError("expected opcode", token)
return token.keyword
def destination_op(self):
destination_type = self.destination_type()
offset_value = self.offset()
self.consume(TT.dot, "expected write enable")
write_enable_token = self.consume(TT.identifier, "expected write enable token")
write_enable = parse_dest_write_enable(write_enable_token)
self.consume(TT.equal, "expected equals")
opcode = self.opcode()
sat = False
if self.match(TT.dot):
self.advance()
suffix = self.consume(TT.keyword, "expected saturation suffix")
if suffix.keyword is not KW.saturation:
raise ParserError("expected saturation suffix", token)
sat = True
macro = False
return DestinationOp(type=destination_type,
offset=offset_value,
write_enable=write_enable,
opcode=opcode,
sat=sat,
macro=macro)
def source_type(self):
token = self.consume(TT.keyword, "expected source type")
source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
if token.keyword not in source_keywords:
raise ParserError("expected source type", token)
return token.keyword
def source_swizzle(self):
token = self.consume(TT.identifier, "expected source swizzle")
return parse_source_swizzle(token)
def source(self):
"input[0].-y-_-0-_"
source_type = self.source_type()
offset = self.offset()
self.consume(TT.dot, "expected source swizzle")
source_swizzle = self.source_swizzle()
return Source(source_type, offset, source_swizzle)
def instruction(self):
while self.match(TT.eol):
self.advance()
first_token = self.peek()
destination_op = self.destination_op()
source0 = self.source()
if self.match(TT.eol) or self.match(TT.eof):
source1 = None
else:
source1 = self.source()
if self.match(TT.eol) or self.match(TT.eof):
source2 = None
else:
source2 = self.source()
last_token = self.peek(-1)
self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
return (
Instruction(destination_op, source0, source1, source2),
(first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
)
def instructions(self):
while not self.match(TT.eof):
yield self.instruction()
if __name__ == "__main__":
from assembler.lexer import Lexer
buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
lexer = Lexer(buf)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
from pprint import pprint
pprint(parser.instruction())

View File

@ -1,9 +1,10 @@
import sys import sys
from assembler.lexer import Lexer, LexerError from assembler.lexer import Lexer, LexerError
from assembler.parser import Parser, ParserError from assembler.vs.keywords import find_keyword
from assembler.emitter import emit_instruction from assembler.vs.parser import Parser, ParserError
from assembler.validator import validate_instruction from assembler.vs.emitter import emit_instruction
from assembler.vs.validator import validate_instruction
sample = b""" sample = b"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
@ -19,40 +20,13 @@ out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
""" """
def frontend_inner(buf): def frontend_inner(buf):
lexer = Lexer(buf) lexer = Lexer(buf, find_keyword)
tokens = list(lexer.lex_tokens()) tokens = list(lexer.lex_tokens())
parser = Parser(tokens) parser = Parser(tokens)
for ins, start_end in parser.instructions(): for ins, start_end in parser.instructions():
ins = validate_instruction(ins) ins = validate_instruction(ins)
yield list(emit_instruction(ins)), start_end yield list(emit_instruction(ins)), start_end
def print_error(filename, buf, e):
assert len(e.args) == 2, e
message, token = e.args
lines = buf.splitlines()
line = lines[token.line - 1]
error_name = str(type(e).__name__)
col_indent = ' ' * token.col
col_pointer = '^' * len(token.lexeme)
RED = "\033[0;31m"
DEFAULT = "\033[0;0m"
print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr)
sys.stderr.write(' ')
wrote_default = False
for i, c in enumerate(line.decode('utf-8')):
if i == token.col:
sys.stderr.write(RED)
sys.stderr.write(c)
if i == token.col + len(token.lexeme):
wrote_default = True
sys.stderr.write(DEFAULT)
if not wrote_default:
sys.stderr.write(DEFAULT)
sys.stderr.write('\n')
print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr)
print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)
def frontend(filename, buf): def frontend(filename, buf):
try: try:
yield from frontend_inner(buf) yield from frontend_inner(buf)

View File

@ -1,5 +1,5 @@
from assembler.keywords import ME, VE, MVE, KW from assembler.vs.keywords import ME, VE, MVE, KW
from assembler.parser import Instruction, DestinationOp, Source from assembler.vs.parser import Instruction, DestinationOp, Source
import pvs_dst import pvs_dst
import pvs_src import pvs_src
import pvs_dst_bits import pvs_dst_bits

208
regs/assembler/vs/parser.py Normal file
View File

@ -0,0 +1,208 @@
from itertools import pairwise
from dataclasses import dataclass
from typing import Union
from assembler.parser import BaseParser, ParserError
from assembler.lexer import TT
from assembler.vs.keywords import KW, ME, VE, find_keyword
"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
@dataclass
class DestinationOp:
    """Left-hand side and opcode of an instruction, e.g. ``out[0].xz = VE_MAD``."""

    # destination keyword
    type: KW
    # number written inside the square brackets
    offset: int
    # component indices enabled for writing (x=0, y=1, z=2, w=3)
    write_enable: set[int]
    # the vector or math opcode keyword
    opcode: Union[VE, ME]
    # True when the opcode carries the saturation suffix
    sat: bool
    # the parser in this file always stores False here
    macro: bool
@dataclass
class SourceSwizzle:
    """Per-component select codes and negate flags of one source operand."""

    # one select code per component (see parse_source_swizzle for the codes)
    select: tuple[int, int, int, int]
    # one flag per component, True where the component is prefixed by '-'
    modifier: tuple[bool, bool, bool, bool]
@dataclass
class Source:
    """One source operand, e.g. ``input[0].-y-_-0-_``."""

    # source keyword
    type: KW
    # number written inside the square brackets
    offset: int
    # parsed component selection following the dot
    swizzle: SourceSwizzle
@dataclass
class Instruction:
    """A parsed instruction: destination/opcode plus up to three sources."""

    destination_op: DestinationOp
    # first source operand (always parsed)
    source0: Source
    # Parser.instruction stores None here when the line ends after source0
    source1: Source
    # Parser.instruction stores None here when the line ends after source1
    source2: Source
def identifier_to_number(token):
    """Convert an identifier token made only of decimal digits to an ``int``.

    Raises:
        ParserError: if the lexeme is empty or contains anything other than
            the ASCII digits 0-9.
    """
    assert token.type is TT.identifier
    text = bytes(token.lexeme)
    # bytes.isdigit() is False for an empty sequence, so an empty lexeme is
    # reported as a ParserError instead of escaping as ValueError from int().
    if not text.isdigit():
        raise ParserError("expected number", token)
    return int(text, 10)
def we_ord(c):
    """Map a write-enable character code to its index: x=0, y=1, z=2, w=3.

    'w' is handled separately because it precedes 'x' in ASCII; any other
    code is returned as its distance from ``ord("x")``.
    """
    return 3 if c == ord("w") else c - ord("x")
def parse_dest_write_enable(token):
    """Parse a destination write-enable mask such as ``xz`` into ``{0, 2}``.

    The mask is lower-cased first.  It may contain only ``x``, ``y``, ``z``
    and ``w``, each at most once and in that order.

    Raises:
        ParserError: on a foreign character, or on repeated / out-of-order
            components.
    """
    assert token.type is TT.identifier
    mask = bytes(token.lexeme).lower()

    if any(ch not in b"xyzw" for ch in mask):
        raise ParserError("expected destination write enable", token)

    indices = [we_ord(ch) for ch in mask]
    strictly_ascending = all(lo < hi for lo, hi in pairwise(indices))
    has_duplicates = len(set(mask)) != len(mask)
    if not strictly_ascending or has_duplicates:
        raise ParserError("misleading non-sequential write enable", token)

    return set(indices)
def parse_source_swizzle(token):
    """Parse a four-component source swizzle such as ``-y-_-0-_``.

    Each component is an optional ``-`` followed by one select character
    (case-insensitive).  Exactly four components must use up the whole
    lexeme.

    Returns:
        SourceSwizzle built from two 4-element lists: the select code of
        each component and whether that component was prefixed by ``-``.

    Raises:
        ParserError: if the lexeme does not have that shape.
    """
    # select character -> select code
    select_mapping = {
        ord('x'): 0,
        ord('y'): 1,
        ord('z'): 2,
        ord('w'): 3,
        ord('0'): 4,
        ord('1'): 5,
        ord('h'): 6,
        ord('_'): 7,
        ord('u'): 7,
    }
    lexeme = bytes(token.lexeme).lower()
    selects = [None] * 4
    negates = [None] * 4
    component = 0  # index of the component currently being filled
    pos = 0        # read position inside the lexeme

    while component < 4:
        # Running out of characters before four components were read.
        if pos >= len(token.lexeme):
            raise ParserError("invalid source swizzle", token)
        ch = lexeme[pos]
        pos += 1

        if ch == ord('-'):
            # A second '-' for the same component is rejected here.
            already_started = (negates[component] is not None) or (selects[component] is not None)
            if already_started:
                raise ParserError("invalid source swizzle modifier", token)
            negates[component] = True
            continue

        if ch not in select_mapping:
            raise ParserError("invalid source swizzle", token)

        if selects[component] is not None:
            raise ParserError("invalid source swizzle select", token)
        selects[component] = select_mapping[ch]
        if negates[component] is None:
            negates[component] = False
        component += 1

    # Anything left over after the fourth component is an error.
    if pos != len(lexeme):
        raise ParserError("invalid source swizzle", token)
    return SourceSwizzle(selects, negates)
class Parser(BaseParser):
    """Parser for vertex shader assembly, one instruction per line.

    An instruction looks like::

        out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_

    Token access (``consume``, ``consume_either``, ``match``, ``peek``,
    ``advance``) comes from ``BaseParser``.  Malformed input is reported by
    raising ``ParserError`` with ``(message, token)``.
    """

    # Keywords accepted on the left of '='.
    _DESTINATION_KEYWORDS = frozenset({
        KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input,
    })

    # Keywords accepted as a source operand.
    _SOURCE_KEYWORDS = frozenset({
        KW.temporary, KW.input, KW.constant, KW.alt_temporary,
    })

    def destination_type(self):
        """Consume and return the destination keyword."""
        token = self.consume(TT.keyword, "expected destination type")
        if token.keyword not in self._DESTINATION_KEYWORDS:
            raise ParserError("expected destination type", token)
        return token.keyword

    def offset(self):
        """Consume ``[`` number ``]`` and return the number as an int."""
        self.consume(TT.left_square, "expected offset")
        identifier_token = self.consume(TT.identifier, "expected offset")
        value = identifier_to_number(identifier_token)
        self.consume(TT.right_square, "expected offset")
        return value

    def opcode(self):
        """Consume and return a ``VE`` or ``ME`` opcode keyword."""
        token = self.consume(TT.keyword, "expected opcode")
        if not isinstance(token.keyword, (VE, ME)):
            raise ParserError("expected opcode", token)
        return token.keyword

    def destination_op(self):
        """Parse ``dest[offset].write_enable = OPCODE`` with an optional ``.`` suffix."""
        destination_type = self.destination_type()
        offset_value = self.offset()
        self.consume(TT.dot, "expected write enable")
        write_enable_token = self.consume(TT.identifier, "expected write enable token")
        write_enable = parse_dest_write_enable(write_enable_token)
        self.consume(TT.equal, "expected equals")
        opcode = self.opcode()

        sat = False
        if self.match(TT.dot):
            self.advance()
            suffix = self.consume(TT.keyword, "expected saturation suffix")
            if suffix.keyword is not KW.saturation:
                # Report the offending suffix token itself.
                raise ParserError("expected saturation suffix", suffix)
            sat = True

        macro = False
        return DestinationOp(type=destination_type,
                             offset=offset_value,
                             write_enable=write_enable,
                             opcode=opcode,
                             sat=sat,
                             macro=macro)

    def source_type(self):
        """Consume and return the source keyword."""
        token = self.consume(TT.keyword, "expected source type")
        if token.keyword not in self._SOURCE_KEYWORDS:
            raise ParserError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
        """Consume the swizzle identifier and return it parsed."""
        token = self.consume(TT.identifier, "expected source swizzle")
        return parse_source_swizzle(token)

    def source(self):
        """Parse one source operand, e.g. ``input[0].-y-_-0-_``."""
        source_type = self.source_type()
        offset = self.offset()
        self.consume(TT.dot, "expected source swizzle")
        source_swizzle = self.source_swizzle()
        return Source(source_type, offset, source_swizzle)

    def instruction(self):
        """Parse one instruction, skipping any blank lines before it.

        Returns:
            ``(Instruction, (start, end))`` where ``start`` is the
            ``start_ix`` of the first token and ``end`` is the ``start_ix``
            of the last token plus the length of its lexeme.
        """
        while self.match(TT.eol):
            self.advance()
        first_token = self.peek()
        destination_op = self.destination_op()
        source0 = self.source()

        # Sources 1 and 2 are optional: they are absent when the line ends.
        if self.match(TT.eol) or self.match(TT.eof):
            source1 = None
        else:
            source1 = self.source()
        if self.match(TT.eol) or self.match(TT.eof):
            source2 = None
        else:
            source2 = self.source()

        # Last token of the instruction, i.e. the one before the terminator.
        last_token = self.peek(-1)
        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
        return (
            Instruction(destination_op, source0, source1, source2),
            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
        )

    def instructions(self):
        """Yield the result of ``instruction()`` until the EOF token is current."""
        while not self.match(TT.eof):
            yield self.instruction()
if __name__ == "__main__":
    # Manual smoke test: lex and parse one sample line and pretty-print it.
    from pprint import pprint
    from assembler.lexer import Lexer

    buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    token_stream = Lexer(buf, find_keyword).lex_tokens()
    parser = Parser(list(token_stream))
    pprint(parser.instruction())

View File

@ -1,4 +1,4 @@
from assembler.keywords import ME, VE, macro_vector_operations from assembler.vs.keywords import ME, VE, macro_vector_operations
class ValidatorError(Exception): class ValidatorError(Exception):
pass pass