From adca6a1c66ab734b0b318f3503b7ae37dfc94639 Mon Sep 17 00:00:00 2001 From: Zack Buhman Date: Mon, 20 Oct 2025 12:54:41 -0500 Subject: [PATCH] assembler: add initial fragment shader parser --- drm/shadertoy.vs.asm | 2 +- regs/assembler/error.py | 28 ++++ regs/assembler/fs/keywords.py | 73 +++++++++ regs/assembler/fs/parser.py | 191 ++++++++++++++++++++++++ regs/assembler/lexer.py | 49 ++++-- regs/assembler/parser.py | 213 +-------------------------- regs/assembler/{ => vs}/__main__.py | 36 +---- regs/assembler/{ => vs}/emitter.py | 4 +- regs/assembler/{ => vs}/keywords.py | 0 regs/assembler/vs/parser.py | 208 ++++++++++++++++++++++++++ regs/assembler/{ => vs}/validator.py | 2 +- 11 files changed, 549 insertions(+), 257 deletions(-) create mode 100644 regs/assembler/error.py create mode 100644 regs/assembler/fs/keywords.py create mode 100644 regs/assembler/fs/parser.py rename regs/assembler/{ => vs}/__main__.py (61%) rename regs/assembler/{ => vs}/emitter.py (97%) rename regs/assembler/{ => vs}/keywords.py (100%) create mode 100644 regs/assembler/vs/parser.py rename regs/assembler/{ => vs}/validator.py (93%) diff --git a/drm/shadertoy.vs.asm b/drm/shadertoy.vs.asm index 745be4e..89e4513 100644 --- a/drm/shadertoy.vs.asm +++ b/drm/shadertoy.vs.asm @@ -1,3 +1,3 @@ -; CONST[0] = { 1.3333 , _, _, _ } +# CONST[0] = { 1.3333 , _, _, _ } out[1].xy = VE_MUL input[0].xy__ const[0].x1__ out[0].xyzw = VE_ADD input[0].xyz1 input[0].0000 diff --git a/regs/assembler/error.py b/regs/assembler/error.py new file mode 100644 index 0000000..b8122f3 --- /dev/null +++ b/regs/assembler/error.py @@ -0,0 +1,28 @@ +import sys + +def print_error(filename, buf, e): + assert len(e.args) == 2, e + message, token = e.args + lines = buf.splitlines() + line = lines[token.line - 1] + + error_name = str(type(e).__name__) + col_indent = ' ' * token.col + col_pointer = '^' * len(token.lexeme) + RED = "\033[0;31m" + DEFAULT = "\033[0;0m" + print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr) + sys.stderr.write(' ') + wrote_default = False + for i, c in enumerate(line.decode('utf-8')): + if i == token.col: + sys.stderr.write(RED) + sys.stderr.write(c) + if i == token.col + len(token.lexeme): + wrote_default = True + sys.stderr.write(DEFAULT) + if not wrote_default: + sys.stderr.write(DEFAULT) + sys.stderr.write('\n') + print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr) + print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr) diff --git a/regs/assembler/fs/keywords.py b/regs/assembler/fs/keywords.py new file mode 100644 index 0000000..2f238a9 --- /dev/null +++ b/regs/assembler/fs/keywords.py @@ -0,0 +1,73 @@ +from enum import Enum, auto + +class KW(Enum): + # ops + CMP = auto() + CND = auto() + COS = auto() + D2A = auto() + DP = auto() + DP3 = auto() + DP4 = auto() + EX2 = auto() + FRC = auto() + LN2 = auto() + MAD = auto() + MAX = auto() + MDH = auto() + MDV = auto() + MIN = auto() + RCP = auto() + RSQ = auto() + SIN = auto() + SOP = auto() + + # source/dest + OUT = auto() + TEMP = auto() + FLOAT = auto() + CONST = auto() + SRC0 = auto() + SRC1 = auto() + SRC2 = auto() + SRCP = auto() + + # modifiers + TEX_SEM_WAIT = auto() + +_string_to_keyword = { + b"CMP": KW.CMP, + b"CND": KW.CND, + b"COS": KW.COS, + b"D2A": KW.D2A, + b"DP": KW.DP, + b"DP3": KW.DP3, + b"DP4": KW.DP4, + b"EX2": KW.EX2, + b"FRC": KW.FRC, + b"LN2": KW.LN2, + b"MAD": KW.MAD, + b"MAX": KW.MAX, + b"MDH": KW.MDH, + b"MDV": KW.MDV, + b"MIN": KW.MIN, + b"RCP": KW.RCP, + b"RSQ": KW.RSQ, + b"SIN": KW.SIN, + 
b"SOP": KW.SOP, + b"OUT": KW.OUT, + b"TEMP": KW.TEMP, + b"FLOAT": KW.FLOAT, + b"CONST": KW.CONST, + b"SRC0": KW.SRC0, + b"SRC1": KW.SRC1, + b"SRC2": KW.SRC2, + b"SRCP": KW.SRCP, + b"TEX_SEM_WAIT": KW.TEX_SEM_WAIT, +} + +def find_keyword(s): + if s.upper() in _string_to_keyword: + return _string_to_keyword[s.upper()] + else: + return None diff --git a/regs/assembler/fs/parser.py b/regs/assembler/fs/parser.py new file mode 100644 index 0000000..ad9afa4 --- /dev/null +++ b/regs/assembler/fs/parser.py @@ -0,0 +1,191 @@ +from enum import IntEnum +from typing import Literal, Union +from dataclasses import dataclass + +from assembler.parser import BaseParser, ParserError +from assembler.lexer import TT, Token +from assembler.fs.keywords import KW, find_keyword +from assembler.error import print_error + +class Mod(IntEnum): + nop = 0 + neg = 1 + abs = 2 + nab = 3 + +@dataclass +class LetExpression: + src_keyword: Token + src_swizzle_identifier: Token + addr_keyword: Token + addr_value_identifier: Token + +@dataclass +class DestAddrSwizzle: + dest_keyword: Token + addr_identifier: Token + swizzle_identifier: Token + +@dataclass +class SwizzleSel: + sel_keyword: Token + swizzle_identifier: Token + mod: Mod + +@dataclass +class Operation: + dest_addr_swizzles: list[DestAddrSwizzle] + opcode_keyword: Token + swizzle_sels: list[SwizzleSel] + +@dataclass +class Instruction: + let_expressions: list[LetExpression] + operations: list[Operation] + +class Parser(BaseParser): + def let_expression(self): + src_keyword = self.consume(TT.keyword, "expected src keyword") + self.consume(TT.dot, "expected dot") + src_swizzle_identifier = self.consume(TT.identifier, "expected src swizzle identifier") + self.consume(TT.equal, "expected equal") + addr_keyword = self.consume(TT.keyword, "expected addr keyword") + if addr_keyword.keyword is KW.FLOAT: + self.consume(TT.left_paren, "expected left paren") + else: + self.consume(TT.left_square, "expected left square") + + addr_value_identifier = self.consume(TT.identifier, "expected address identifier") + + if addr_keyword.keyword is KW.FLOAT: + self.consume(TT.right_paren, "expected right paren") + else: + self.consume(TT.right_square, "expected right square") + + return LetExpression( + src_keyword, + src_swizzle_identifier, + addr_keyword, + addr_value_identifier, + ) + + def dest_addr_swizzle(self): + dest_keyword = self.consume(TT.keyword, "expected dest keyword") + self.consume(TT.left_square, "expected left square") + addr_identifier = self.consume(TT.identifier, "expected dest addr identifier") + self.consume(TT.right_square, "expected left square") + self.consume(TT.dot, "expected dot") + swizzle_identifier = self.consume(TT.identifier, "expected dest swizzle identifier") + self.consume(TT.equal, "expected equal") + + def is_opcode(self): + opcode_keywords = { + KW.CMP, KW.CND, KW.COS, KW.D2A, + KW.DP , KW.DP3, KW.DP4, KW.EX2, + KW.FRC, KW.LN2, KW.MAD, KW.MAX, + KW.MDH, KW.MDV, KW.MIN, KW.RCP, + KW.RSQ, KW.SIN, KW.SOP, + } + if self.match(TT.keyword): + token = self.peek() + return token.keyword in opcode_keywords + return False + + def is_neg(self): + result = self.match(TT.identifier) and self.peek().lexeme == b'-' + if result: + self.advance() + return result + + def is_abs(self): + result = self.match(TT.bar) + if result: + self.advance() + return result + + def swizzle_sel(self): + neg = self.is_neg() + abs = self.is_abs() + + if neg: + self.consume(TT.left_paren, "expected left paren") + + sel_keyword = self.consume(TT.keyword, "expected sel keyword") + 
self.consume(TT.dot, "expected dot") + swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier") + + if abs: + self.consume(TT.bar, "expected bar") + if neg: + self.consume(TT.right_paren, "expected right paren") + + mod_table = { + # (neg, abs) + (False, False): Mod.nop, + (False, True): Mod.abs, + (True, False): Mod.neg, + (True, True): Mod.nab, + } + mod = mod_table[(neg, abs)] + return SwizzleSel( + sel_keyword, + swizzle_identifier, + mod, + ) + + def operation(self): + dest_addr_swizzles = [] + while not self.is_opcode(): + dest_addr_swizzles.append(self.dest_addr_swizzle()) + + opcode_keyword = self.consume(TT.keyword, "expected opcode keyword") + + swizzle_sels = [] + while not (self.match(TT.comma) or self.match(TT.semicolon)): + swizzle_sels.append(self.swizzle_sel()) + + return Operation( + dest_addr_swizzles, + opcode_keyword, + swizzle_sels + ) + + def instruction(self): + let_expressions = [] + while not self.match(TT.colon): + let_expressions.append(self.let_expression()) + if not self.match(TT.colon): + self.consume(TT.comma, "expected comma") + + self.consume(TT.colon, "expected colon") + + operations = [] + while not self.match(TT.semicolon): + operations.append(self.operation()) + if not self.match(TT.semicolon): + self.consume(TT.comma, "expected comma") + + self.consume(TT.semicolon, "expected semicolon") + + return Instruction( + let_expressions, + operations, + ) + +if __name__ == "__main__": + from assembler.lexer import Lexer + buf = b""" +src0.a = float(0), src0.rgb = temp[0] : + out[0].none = temp[0].none = MAD src0.r src0.r src0.r , + out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ; + """ + lexer = Lexer(buf, find_keyword, emit_newlines=False) + tokens = list(lexer.lex_tokens()) + parser = Parser(tokens) + from pprint import pprint + try: + pprint(parser.instruction()) + except ParserError as e: + print_error(None, buf, e) + raise + print(parser.peek()) diff --git a/regs/assembler/lexer.py b/regs/assembler/lexer.py index 122b6a0..64a64fa 100644 --- a/regs/assembler/lexer.py +++ b/regs/assembler/lexer.py @@ -1,9 +1,7 @@ from dataclasses import dataclass from enum import Enum, auto from itertools import chain -from typing import Union - -from assembler import keywords +from typing import Union, Any DEBUG = True @@ -18,6 +16,10 @@ class TT(Enum): dot = auto() identifier = auto() keyword = auto() + colon = auto() + semicolon = auto() + bar = auto() + comma = auto() @dataclass class Token: @@ -26,7 +28,7 @@ class Token: col: int type: TT lexeme: memoryview - keyword: Union[keywords.VE, keywords.ME, keywords.KW] = None + keyword: Any = None identifier_characters = set(chain( b'abcdefghijklmnopqrstuvwxyz' @@ -39,12 +41,14 @@ class LexerError(Exception): pass class Lexer: - def __init__(self, buf: memoryview): + def __init__(self, buf: memoryview, find_keyword, emit_newlines=True): self.start_ix = 0 self.current_ix = 0 self.buf = memoryview(buf) self.line = 1 self.col = 0 + self.find_keyword = find_keyword + self.emit_newlines = emit_newlines def at_end_p(self): return self.current_ix >= len(self.buf) @@ -70,7 +74,7 @@ class Lexer: def identifier(self): while not self.at_end_p() and self.peek() in identifier_characters: self.advance() - keyword = keywords.find_keyword(self.lexeme()) + keyword = self.find_keyword(self.lexeme()) if keyword is not None: return Token(*self.pos(), TT.keyword, self.lexeme(), keyword) else: @@ -96,7 +100,15 @@ class Lexer: return Token(*self.pos(), TT.equal, self.lexeme()) elif c == ord('.'): return Token(*self.pos(), 
TT.dot, self.lexeme())
+        elif c == ord('|'):
+            return Token(*self.pos(), TT.bar, self.lexeme())
+        elif c == ord(':'):
+            return Token(*self.pos(), TT.colon, self.lexeme())
         elif c == ord(';'):
+            return Token(*self.pos(), TT.semicolon, self.lexeme())
+        elif c == ord(','):
+            return Token(*self.pos(), TT.comma, self.lexeme())
+        elif c == ord('#'):
             while not self.at_end_p() and self.peek() != ord('\n'):
                 self.advance()
         elif c == ord(' ') or c == ord('\r') or c == ord('\t'):
@@ -105,11 +117,15 @@
             pos = self.pos()
             self.line += 1
             self.col = 0
-            return Token(*pos, TT.eol, self.lexeme())
+            if self.emit_newlines:
+                return Token(*pos, TT.eol, self.lexeme())
+            else:
+                continue
         elif c in identifier_characters:
             return self.identifier()
         else:
-            raise LexerError(f"unexpected character at line:{self.line} col:{self.col}")
+            token = Token(*self.pos(), None, self.lexeme())
+            raise LexerError(f"unexpected character at line:{self.line} col:{self.col}", token)
 
     def lex_tokens(self):
         while True:
@@ -119,7 +135,16 @@
             break
 
 if __name__ == "__main__":
-    test = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
-    lexer = Lexer(test)
-    for token in lexer.lex_tokens():
-        print(token)
+    def vs_test():
+        from assembler.vs.keywords import find_keyword
+        test = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
+        lexer = Lexer(test, find_keyword)
+        for token in lexer.lex_tokens():
+            print(token)
+    def fs_test():
+        from assembler.fs.keywords import find_keyword
+        test = b"src0.rgb = temp[0] : temp[0].a = RSQ |src0.r| ;"
+        lexer = Lexer(test, find_keyword)
+        for token in lexer.lex_tokens():
+            print(token)
+    fs_test()
diff --git a/regs/assembler/parser.py b/regs/assembler/parser.py
index 1c7a651..ea6f3cb 100644
--- a/regs/assembler/parser.py
+++ b/regs/assembler/parser.py
@@ -1,125 +1,15 @@
-from itertools import pairwise
-from dataclasses import dataclass
-from typing import Union
-
-from assembler import lexer
-from assembler.lexer import TT, Token
-from assembler.keywords import KW, ME, VE
-
-"""
-temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
-temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
-temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
-temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
-temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
-temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
-temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
-temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
-out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
-out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
-"""
-
-@dataclass
-class DestinationOp:
-    type: KW
-    offset: int
-    write_enable: set[int]
-    opcode: Union[VE, ME]
-    sat: bool
-    macro: bool
-
-@dataclass
-class SourceSwizzle:
-    select: tuple[int, int, int, int]
-    modifier: tuple[bool, bool, bool, bool]
-
-@dataclass
-class Source:
-    type: KW
-    offset: int
-    swizzle: SourceSwizzle
-
-@dataclass
-class Instruction:
-    destination_op: DestinationOp
-    source0: Source
-    source1: Source
-    source2: Source
+from typing import Any
 
 class ParserError(Exception):
     pass
 
-def identifier_to_number(token):
-    digits = set(b"0123456789")
-
-    assert token.type is TT.identifier
-    if not all(d in digits for d in token.lexeme):
-        raise ParserError("expected number", token)
-    return int(bytes(token.lexeme), 10)
-
-def we_ord(c):
-    if c == ord("w"):
-        return 3
-    else:
-        return c - ord("x")
-
-def parse_dest_write_enable(token):
-    we_chars = 
set(b"xyzw") - assert token.type is TT.identifier - we = bytes(token.lexeme).lower() - if not all(c in we_chars for c in we): - raise ParserError("expected destination write enable", token) - if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we): - raise ParserError("misleading non-sequential write enable", token) - return set(we_ord(c) for c in we) - -def parse_source_swizzle(token): - select_mapping = { - ord('x'): 0, - ord('y'): 1, - ord('z'): 2, - ord('w'): 3, - ord('0'): 4, - ord('1'): 5, - ord('h'): 6, - ord('_'): 7, - ord('u'): 7, - } - state = 0 - ix = 0 - swizzle_selects = [None] * 4 - swizzle_modifiers = [None] * 4 - lexeme = bytes(token.lexeme).lower() - while state < 4: - if ix >= len(token.lexeme): - raise ParserError("invalid source swizzle", token) - c = lexeme[ix] - if c == ord('-'): - if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None): - raise ParserError("invalid source swizzle modifier", token) - swizzle_modifiers[state] = True - elif c in select_mapping: - if swizzle_selects[state] is not None: - raise ParserError("invalid source swizzle select", token) - swizzle_selects[state] = select_mapping[c] - if swizzle_modifiers[state] is None: - swizzle_modifiers[state] = False - state += 1 - else: - raise ParserError("invalid source swizzle", token) - ix += 1 - if ix != len(lexeme): - raise ParserError("invalid source swizzle", token) - return SourceSwizzle(swizzle_selects, swizzle_modifiers) - -class Parser: - def __init__(self, tokens: list[lexer.Token]): +class BaseParser: + def __init__(self, tokens: list[Any]): self.current_ix = 0 self.tokens = tokens def peek(self, offset=0): token = self.tokens[self.current_ix + offset] - #print(token) return token def at_end_p(self): @@ -145,100 +35,3 @@ class Parser: if token.type != token_type1 and token.type != token_type2: raise ParserError(message, token) return token - - def destination_type(self): - token = self.consume(TT.keyword, "expected destination type") - destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input} - if token.keyword not in destination_keywords: - raise ParserError("expected destination type", token) - return token.keyword - - def offset(self): - self.consume(TT.left_square, "expected offset") - identifier_token = self.consume(TT.identifier, "expected offset") - value = identifier_to_number(identifier_token) - self.consume(TT.right_square, "expected offset") - return value - - def opcode(self): - token = self.consume(TT.keyword, "expected opcode") - if type(token.keyword) != VE and type(token.keyword) != ME: - raise ParserError("expected opcode", token) - return token.keyword - - def destination_op(self): - destination_type = self.destination_type() - offset_value = self.offset() - self.consume(TT.dot, "expected write enable") - write_enable_token = self.consume(TT.identifier, "expected write enable token") - write_enable = parse_dest_write_enable(write_enable_token) - self.consume(TT.equal, "expected equals") - opcode = self.opcode() - sat = False - if self.match(TT.dot): - self.advance() - suffix = self.consume(TT.keyword, "expected saturation suffix") - if suffix.keyword is not KW.saturation: - raise ParserError("expected saturation suffix", token) - sat = True - - macro = False - return DestinationOp(type=destination_type, - offset=offset_value, - write_enable=write_enable, - opcode=opcode, - sat=sat, - macro=macro) - - def source_type(self): - token = self.consume(TT.keyword, "expected source type") - 
source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
-        if token.keyword not in source_keywords:
-            raise ParserError("expected source type", token)
-        return token.keyword
-
-    def source_swizzle(self):
-        token = self.consume(TT.identifier, "expected source swizzle")
-        return parse_source_swizzle(token)
-
-    def source(self):
-        "input[0].-y-_-0-_"
-        source_type = self.source_type()
-        offset = self.offset()
-        self.consume(TT.dot, "expected source swizzle")
-        source_swizzle = self.source_swizzle()
-        return Source(source_type, offset, source_swizzle)
-
-    def instruction(self):
-        while self.match(TT.eol):
-            self.advance()
-        first_token = self.peek()
-        destination_op = self.destination_op()
-        source0 = self.source()
-        if self.match(TT.eol) or self.match(TT.eof):
-            source1 = None
-        else:
-            source1 = self.source()
-        if self.match(TT.eol) or self.match(TT.eof):
-            source2 = None
-        else:
-            source2 = self.source()
-        last_token = self.peek(-1)
-        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
-        return (
-            Instruction(destination_op, source0, source1, source2),
-            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
-        )
-
-    def instructions(self):
-        while not self.match(TT.eof):
-            yield self.instruction()
-
-if __name__ == "__main__":
-    from assembler.lexer import Lexer
-    buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
-    lexer = Lexer(buf)
-    tokens = list(lexer.lex_tokens())
-    parser = Parser(tokens)
-    from pprint import pprint
-    pprint(parser.instruction())
diff --git a/regs/assembler/__main__.py b/regs/assembler/vs/__main__.py
similarity index 61%
rename from regs/assembler/__main__.py
rename to regs/assembler/vs/__main__.py
index 9dd3337..fb0bb99 100644
--- a/regs/assembler/__main__.py
+++ b/regs/assembler/vs/__main__.py
@@ -1,9 +1,11 @@
 import sys
 
 from assembler.lexer import Lexer, LexerError
-from assembler.parser import Parser, ParserError
-from assembler.emitter import emit_instruction
-from assembler.validator import validate_instruction
+from assembler.error import print_error
+from assembler.vs.keywords import find_keyword
+from assembler.vs.parser import Parser, ParserError
+from assembler.vs.emitter import emit_instruction
+from assembler.vs.validator import validate_instruction
 
 sample = b"""
 temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
@@ -19,40 +21,13 @@
 out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
 """
 
 def frontend_inner(buf):
-    lexer = Lexer(buf)
+    lexer = Lexer(buf, find_keyword)
     tokens = list(lexer.lex_tokens())
     parser = Parser(tokens)
     for ins, start_end in parser.instructions():
         ins = validate_instruction(ins)
         yield list(emit_instruction(ins)), start_end
 
-def print_error(filename, buf, e):
-    assert len(e.args) == 2, e
-    message, token = e.args
-    lines = buf.splitlines()
-    line = lines[token.line - 1]
-
-    error_name = str(type(e).__name__)
-    col_indent = ' ' * token.col
-    col_pointer = '^' * len(token.lexeme)
-    RED = "\033[0;31m"
-    DEFAULT = "\033[0;0m"
-    print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr)
-    sys.stderr.write(' ')
-    wrote_default = False
-    for i, c in enumerate(line.decode('utf-8')):
-        if i == token.col:
-            sys.stderr.write(RED)
-        sys.stderr.write(c)
-        if i == token.col + len(token.lexeme):
-            wrote_default = True
-            sys.stderr.write(DEFAULT)
-    if not wrote_default:
-        sys.stderr.write(DEFAULT)
-    sys.stderr.write('\n')
-    print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr)
-    print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)
-
 def frontend(filename, 
buf): try: yield from frontend_inner(buf) diff --git a/regs/assembler/emitter.py b/regs/assembler/vs/emitter.py similarity index 97% rename from regs/assembler/emitter.py rename to regs/assembler/vs/emitter.py index d45538f..b9598d5 100644 --- a/regs/assembler/emitter.py +++ b/regs/assembler/vs/emitter.py @@ -1,5 +1,5 @@ -from assembler.keywords import ME, VE, MVE, KW -from assembler.parser import Instruction, DestinationOp, Source +from assembler.vs.keywords import ME, VE, MVE, KW +from assembler.vs.parser import Instruction, DestinationOp, Source import pvs_dst import pvs_src import pvs_dst_bits diff --git a/regs/assembler/keywords.py b/regs/assembler/vs/keywords.py similarity index 100% rename from regs/assembler/keywords.py rename to regs/assembler/vs/keywords.py diff --git a/regs/assembler/vs/parser.py b/regs/assembler/vs/parser.py new file mode 100644 index 0000000..0183338 --- /dev/null +++ b/regs/assembler/vs/parser.py @@ -0,0 +1,208 @@ +from itertools import pairwise +from dataclasses import dataclass +from typing import Union + +from assembler.parser import BaseParser, ParserError +from assembler.lexer import TT +from assembler.vs.keywords import KW, ME, VE, find_keyword + +""" +temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 +temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 +temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___ +temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000 +temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___ +temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000 +temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000 +temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000 +out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_ +out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1 +""" + +@dataclass +class DestinationOp: + type: KW + offset: int + write_enable: set[int] + opcode: Union[VE, ME] + sat: bool + macro: bool + +@dataclass +class SourceSwizzle: + select: tuple[int, int, int, int] + modifier: tuple[bool, bool, bool, bool] + +@dataclass +class Source: + type: KW + offset: int + swizzle: SourceSwizzle + +@dataclass +class Instruction: + destination_op: DestinationOp + source0: Source + source1: Source + source2: Source + +def identifier_to_number(token): + digits = set(b"0123456789") + + assert token.type is TT.identifier + if not all(d in digits for d in token.lexeme): + raise ParserError("expected number", token) + return int(bytes(token.lexeme), 10) + +def we_ord(c): + if c == ord("w"): + return 3 + else: + return c - ord("x") + +def parse_dest_write_enable(token): + we_chars = set(b"xyzw") + assert token.type is TT.identifier + we = bytes(token.lexeme).lower() + if not all(c in we_chars for c in we): + raise ParserError("expected destination write enable", token) + if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we): + raise ParserError("misleading non-sequential write enable", token) + return set(we_ord(c) for c in we) + +def parse_source_swizzle(token): + select_mapping = { + ord('x'): 0, + ord('y'): 1, + ord('z'): 2, + ord('w'): 3, + ord('0'): 4, + ord('1'): 5, + ord('h'): 6, + ord('_'): 7, + ord('u'): 7, + } + state = 0 + ix = 0 + swizzle_selects = [None] * 4 + swizzle_modifiers = [None] * 4 + lexeme = bytes(token.lexeme).lower() + while state < 4: + if ix >= len(token.lexeme): + raise ParserError("invalid source swizzle", token) + c = lexeme[ix] + if c == ord('-'): + if (swizzle_modifiers[state] is not None) or 
(swizzle_selects[state] is not None):
+                raise ParserError("invalid source swizzle modifier", token)
+            swizzle_modifiers[state] = True
+        elif c in select_mapping:
+            if swizzle_selects[state] is not None:
+                raise ParserError("invalid source swizzle select", token)
+            swizzle_selects[state] = select_mapping[c]
+            if swizzle_modifiers[state] is None:
+                swizzle_modifiers[state] = False
+            state += 1
+        else:
+            raise ParserError("invalid source swizzle", token)
+        ix += 1
+    if ix != len(lexeme):
+        raise ParserError("invalid source swizzle", token)
+    return SourceSwizzle(swizzle_selects, swizzle_modifiers)
+
+class Parser(BaseParser):
+    def destination_type(self):
+        token = self.consume(TT.keyword, "expected destination type")
+        destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
+        if token.keyword not in destination_keywords:
+            raise ParserError("expected destination type", token)
+        return token.keyword
+
+    def offset(self):
+        self.consume(TT.left_square, "expected offset")
+        identifier_token = self.consume(TT.identifier, "expected offset")
+        value = identifier_to_number(identifier_token)
+        self.consume(TT.right_square, "expected offset")
+        return value
+
+    def opcode(self):
+        token = self.consume(TT.keyword, "expected opcode")
+        if type(token.keyword) != VE and type(token.keyword) != ME:
+            raise ParserError("expected opcode", token)
+        return token.keyword
+
+    def destination_op(self):
+        destination_type = self.destination_type()
+        offset_value = self.offset()
+        self.consume(TT.dot, "expected write enable")
+        write_enable_token = self.consume(TT.identifier, "expected write enable token")
+        write_enable = parse_dest_write_enable(write_enable_token)
+        self.consume(TT.equal, "expected equals")
+        opcode = self.opcode()
+        sat = False
+        if self.match(TT.dot):
+            self.advance()
+            suffix = self.consume(TT.keyword, "expected saturation suffix")
+            if suffix.keyword is not KW.saturation:
+                raise ParserError("expected saturation suffix", suffix)
+            sat = True
+
+        macro = False
+        return DestinationOp(type=destination_type,
+                             offset=offset_value,
+                             write_enable=write_enable,
+                             opcode=opcode,
+                             sat=sat,
+                             macro=macro)
+
+    def source_type(self):
+        token = self.consume(TT.keyword, "expected source type")
+        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
+        if token.keyword not in source_keywords:
+            raise ParserError("expected source type", token)
+        return token.keyword
+
+    def source_swizzle(self):
+        token = self.consume(TT.identifier, "expected source swizzle")
+        return parse_source_swizzle(token)
+
+    def source(self):
+        "input[0].-y-_-0-_"
+        source_type = self.source_type()
+        offset = self.offset()
+        self.consume(TT.dot, "expected source swizzle")
+        source_swizzle = self.source_swizzle()
+        return Source(source_type, offset, source_swizzle)
+
+    def instruction(self):
+        while self.match(TT.eol):
+            self.advance()
+        first_token = self.peek()
+        destination_op = self.destination_op()
+        source0 = self.source()
+        if self.match(TT.eol) or self.match(TT.eof):
+            source1 = None
+        else:
+            source1 = self.source()
+        if self.match(TT.eol) or self.match(TT.eof):
+            source2 = None
+        else:
+            source2 = self.source()
+        last_token = self.peek(-1)
+        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
+        return (
+            Instruction(destination_op, source0, source1, source2),
+            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
+        )
+
+    def instructions(self):
+        while not self.match(TT.eof):
+            yield self.instruction()
+
+if __name__ == "__main__":
+    from 
assembler.lexer import Lexer + buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_" + lexer = Lexer(buf, find_keyword) + tokens = list(lexer.lex_tokens()) + parser = Parser(tokens) + from pprint import pprint + pprint(parser.instruction()) diff --git a/regs/assembler/validator.py b/regs/assembler/vs/validator.py similarity index 93% rename from regs/assembler/validator.py rename to regs/assembler/vs/validator.py index 971ba2c..3d7f66c 100644 --- a/regs/assembler/validator.py +++ b/regs/assembler/vs/validator.py @@ -1,4 +1,4 @@ -from assembler.keywords import ME, VE, macro_vector_operations +from assembler.vs.keywords import ME, VE, macro_vector_operations class ValidatorError(Exception): pass
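
---

Usage sketch (illustrative, not part of the patch): both front ends are driven
the same way; only the keyword table and the newline policy differ. The module
paths and call signatures below are taken from the diff itself, the two source
strings are adapted from the in-tree samples, and the "<stdin>" filename is an
arbitrary placeholder.

    from assembler.lexer import Lexer
    from assembler.error import print_error
    from assembler.parser import ParserError

    # Vertex programs are newline-delimited, so EOL tokens are kept (the default).
    from assembler.vs.keywords import find_keyword as vs_keywords
    from assembler.vs.parser import Parser as VsParser

    vs_buf = b"out[0].xyzw = VE_ADD input[0].xyz1 input[0].0000\n"
    vs_parser = VsParser(list(Lexer(vs_buf, vs_keywords).lex_tokens()))
    print(vs_parser.instruction())  # -> (Instruction, (start_ix, end_ix))

    # Fragment programs are ','/';'-delimited, so EOL tokens are suppressed.
    from assembler.fs.keywords import find_keyword as fs_keywords
    from assembler.fs.parser import Parser as FsParser

    fs_buf = b"src0.rgb = temp[0] : out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ;"
    try:
        fs_parser = FsParser(list(Lexer(fs_buf, fs_keywords, emit_newlines=False).lex_tokens()))
        print(fs_parser.instruction())  # -> Instruction(let_expressions, operations)
    except ParserError as e:
        print_error("<stdin>", fs_buf, e)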