"""Parser for the textual instruction syntax illustrated below.

Each instruction is a destination, an opcode and three sources:

temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRAC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""

from dataclasses import dataclass
from itertools import pairwise
from typing import Union

import lexer
from keywords import KW, ME, VE
from lexer import TT


# Destination component letter (as a byte value) -> component index.
# An explicit table is required: ord('w') < ord('x'), so byte arithmetic
# would place `w` before `x` instead of after `z`.
_WRITE_ENABLE_INDEX = {
    ord('x'): 0,
    ord('y'): 1,
    ord('z'): 2,
    ord('w'): 3,
}

# Source swizzle select character (as a byte value) -> select code.
_SWIZZLE_SELECT = {
    ord('x'): 0,
    ord('y'): 1,
    ord('z'): 2,
    ord('w'): 3,
    ord('0'): 4,
    ord('1'): 5,
    ord('h'): 6,
    ord('_'): 7,
    ord('u'): 7,
}


@dataclass
class DestinationOp:
    """Left-hand side of an instruction plus its opcode."""

    type: KW
    offset: int
    # Component indices that are written (x=0, y=1, z=2, w=3).
    write_enable: set[int]
    opcode: Union[VE, ME]


@dataclass
class SourceSwizzle:
    """Per-component select codes and negate flags of one source operand."""

    select: tuple[int, int, int, int]
    # True where the component was prefixed with '-'.
    modifier: tuple[bool, bool, bool, bool]


@dataclass
class Source:
    """One source operand: register type, offset and swizzle."""

    type: KW
    offset: int
    swizzle: SourceSwizzle


@dataclass
class Instruction:
    """A fully parsed instruction: destination/opcode and three sources."""

    destination_op: DestinationOp
    source0: Source
    source1: Source
    source2: Source


class ParseError(Exception):
    """Raised as ``ParseError(message, token)`` on malformed input."""


def identifier_to_number(token):
    """Convert an identifier token made only of ASCII digits to an int.

    Raises:
        ParseError: if the lexeme is empty or contains a non-digit.
    """
    digits = set(b"0123456789")
    assert token.type is TT.identifier
    lexeme = bytes(token.lexeme)
    # An empty lexeme would pass all() vacuously and then make int() raise
    # ValueError, so reject it explicitly.
    if not lexeme or not all(d in digits for d in lexeme):
        raise ParseError("expected number", token)
    return int(lexeme, 10)


def parse_dest_write_enable(token):
    """Parse a destination write mask such as ``xz`` or ``xyzw``.

    Matching is case-insensitive. Components must appear in x, y, z, w
    order without repetition.

    Returns:
        The set of enabled component indices (x=0, y=1, z=2, w=3).

    Raises:
        ParseError: on a character outside ``xyzw`` or on out-of-order /
            repeated components.
    """
    assert token.type is TT.identifier
    we = bytes(token.lexeme).lower()
    try:
        indices = [_WRITE_ENABLE_INDEX[c] for c in we]
    except KeyError:
        raise ParseError("expected destination write enable", token) from None
    # Strictly increasing indices also rules out duplicates.
    if not all(a < b for a, b in pairwise(indices)):
        raise ParseError("misleading non-sequential write enable", token)
    return set(indices)


def parse_source_swizzle(token):
    """Parse a four-component source swizzle such as ``-y-_-0-_``.

    Each component is one select character (``x y z w 0 1 h _ u``,
    case-insensitive), optionally preceded by a single ``-``.

    Returns:
        A SourceSwizzle holding four select codes and four negate flags.

    Raises:
        ParseError: if the lexeme has too few or too many components, a
            repeated ``-``, or an unknown character.
    """
    lexeme = bytes(token.lexeme).lower()
    selects = []
    modifiers = []
    negate = False  # a '-' has been seen for the component being parsed
    ix = 0
    while len(selects) < 4:
        # `>=`: running out of characters is a parse error, not IndexError.
        if ix >= len(lexeme):
            raise ParseError("invalid source swizzle", token)
        c = lexeme[ix]
        ix += 1
        if c == ord('-'):
            if negate:
                raise ParseError("invalid source swizzle modifier", token)
            negate = True
        elif c in _SWIZZLE_SELECT:
            selects.append(_SWIZZLE_SELECT[c])
            modifiers.append(negate)
            negate = False
        else:
            raise ParseError("invalid source swizzle", token)
    # Anything left after the fourth component is trailing garbage.
    if ix != len(lexeme):
        raise ParseError("invalid source swizzle", token)
    # Tuples, as declared by the SourceSwizzle field annotations.
    return SourceSwizzle(tuple(selects), tuple(modifiers))


class Parser:
    """Recursive-descent parser over a list of lexer tokens."""

    def __init__(self, tokens: list[lexer.Token]):
        self.current_ix = 0
        self.tokens = tokens

    def peek(self):
        """Return the current token without consuming it."""
        return self.tokens[self.current_ix]

    def at_end_p(self):
        """Return True when the current token is EOF."""
        return self.peek().type == TT.eof

    def advance(self):
        """Consume and return the current token."""
        token = self.peek()
        self.current_ix += 1
        return token

    def match(self, token_type, message):
        """Consume one token and report whether it had ``token_type``.

        NOTE(review): the token is consumed even when it does not match,
        and ``message`` is unused. Behaviour is kept as-is because no
        callers are visible here; confirm this is intended.
        """
        token = self.advance()
        return token.type == token_type

    def consume(self, token_type, message):
        """Consume one token, raising ParseError(message) on a type mismatch."""
        token = self.advance()
        if token.type != token_type:
            raise ParseError(message, token)
        return token

    def consume_either(self, token_type1, token_type2, message):
        """Consume one token that must have one of the two given types."""
        token = self.advance()
        if token.type != token_type1 and token.type != token_type2:
            raise ParseError(message, token)
        return token

    # Disabled helper, kept for reference:
    #
    # def consume_keyword(self, keyword, message):
    #     token = self.consume(TT.keyword, message)
    #     assert token.keyword is not None
    #     if token.keyword != keyword:
    #         raise ParseError(message, token)

    def destination_type(self):
        """Parse the keyword naming the destination register type."""
        token = self.consume(TT.keyword, "expected destination type")
        destination_keywords = {
            KW.temporary,
            KW.a0,
            KW.out,
            KW.out_repl_x,
            KW.alt_temporary,
            KW.input,
        }
        if token.keyword not in destination_keywords:
            raise ParseError("expected destination type", token)
        return token.keyword

    def offset(self):
        """Parse ``[<number>]`` and return the number."""
        self.consume(TT.left_square, "expected offset")
        identifier_token = self.consume(TT.identifier, "expected offset")
        value = identifier_to_number(identifier_token)
        self.consume(TT.right_square, "expected offset")
        return value

    def opcode(self):
        """Parse an opcode keyword (a VE or ME value)."""
        token = self.consume(TT.keyword, "expected opcode")
        if not isinstance(token.keyword, (VE, ME)):
            raise ParseError("expected opcode", token)
        return token.keyword

    def destination_op(self):
        """Parse ``<type>[<offset>].<write enable> = <opcode>``."""
        destination_type = self.destination_type()
        offset_value = self.offset()
        self.consume(TT.dot, "expected write enable")
        write_enable_token = self.consume(
            TT.identifier, "expected write enable token"
        )
        write_enable = parse_dest_write_enable(write_enable_token)
        self.consume(TT.equal, "expected equals")
        opcode = self.opcode()
        return DestinationOp(destination_type, offset_value, write_enable, opcode)

    def source_type(self):
        """Parse the keyword naming a source register type."""
        token = self.consume(TT.keyword, "expected source type")
        source_keywords = {
            KW.temporary,
            KW.input,
            KW.constant,
            KW.alt_temporary,
        }
        if token.keyword not in source_keywords:
            raise ParseError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
        """Parse the identifier token holding a source swizzle."""
        token = self.consume(TT.identifier, "expected source swizzle")
        return parse_source_swizzle(token)

    def source(self):
        """Parse one source operand, e.g. ``input[0].-y-_-0-_``."""
        source_type = self.source_type()
        offset = self.offset()
        self.consume(TT.dot, "expected source swizzle")
        source_swizzle = self.source_swizzle()
        return Source(source_type, offset, source_swizzle)

    def instruction(self):
        """Parse one instruction terminated by end-of-line or EOF."""
        destination_op = self.destination_op()
        source0 = self.source()
        source1 = self.source()
        source2 = self.source()
        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
        return Instruction(destination_op, source0, source1, source2)


if __name__ == "__main__":
    from pprint import pprint

    from lexer import Lexer

    buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    # Avoid naming the instance `lexer`, which would shadow the module.
    tokens = list(Lexer(buf).lex_tokens())
    parser = Parser(tokens)
    pprint(parser.instruction())