# r500/regs/assembler/vs/parser.py
from itertools import pairwise
from dataclasses import dataclass
from typing import Union
from assembler.parser import BaseParser, ParserError
from assembler.lexer import TT
from assembler.vs.keywords import KW, ME, VE, find_keyword
"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
@dataclass
class DestinationOp:
    # Parsed left-hand side of an instruction plus its opcode,
    # e.g. "temp[0].xyzw = VE_ADD" -> type=temp, offset=0, we={0,1,2,3}, opcode=VE_ADD.
    type: KW              # destination register file keyword (temp / a0 / out / out_repl_x / alt_temp / input)
    offset: int           # register index within that file, from "dest[<offset>]"
    write_enable: set[int]  # component indices written, 0..3 for x..w
    opcode: Union[VE, ME]   # vector-engine or math-engine opcode keyword
    sat: bool             # True when the ".SAT" (saturation) suffix followed the opcode
    macro: bool           # macro-instruction flag; the parser currently always sets False
@dataclass
class SourceSwizzle:
    # Four-component source swizzle parsed from e.g. "x_0_" or "-y-_-0-_".
    # select codes (see parse_source_swizzle): 0-3 from 'x'/'y'/'z'/'w',
    # 4 from '0', 5 from '1', 6 from 'h', 7 from '_' or 'u' (unused slot).
    select: tuple[int, int, int, int]
    # True where the component had a '-' prefix (negation modifier).
    modifier: tuple[bool, bool, bool, bool]
@dataclass
class Source:
    # One source operand, e.g. "input[0].-y-_-0-_".
    type: KW      # source register file keyword (temp / input / const / alt_temp)
    offset: int   # register index within that file
    swizzle: SourceSwizzle  # per-component select + negate
@dataclass
class Instruction:
    # One fully parsed instruction: destination/opcode plus up to three sources.
    destination_op: DestinationOp
    source0: Source
    # source1/source2 are None when the statement ends early
    # (Parser.instruction fills in None for omitted operands).
    source1: Union[Source, None]
    source2: Union[Source, None]
def identifier_to_number(token):
    """Interpret an identifier token's lexeme as a base-10 number.

    Raises ParserError when the lexeme is empty or contains any non-digit
    character.  (The previous all()-over-digits check passed vacuously for
    an empty lexeme and then crashed inside int() with ValueError;
    bytes.isdigit() is False for b"", so that case now raises ParserError.)
    """
    assert token.type is TT.identifier
    lexeme = bytes(token.lexeme)
    # bytes.isdigit() is True only for a non-empty, all-ASCII-digit string.
    if not lexeme.isdigit():
        raise ParserError("expected number", token)
    return int(lexeme, 10)
def we_ord(c):
    """Map a write-enable character code to its component index (x=0, y=1, z=2, w=3)."""
    # 'x', 'y', 'z' are consecutive code points; 'w' sorts before them
    # lexically, so it is special-cased to component 3.
    return 3 if c == ord("w") else c - ord("x")
def parse_dest_write_enable(token):
    """Parse a destination write-enable identifier such as "xyzw" or "xz".

    Returns the set of component indices (0..3) named by the token.
    Raises ParserError for characters outside x/y/z/w or for components
    not listed in strictly ascending x,y,z,w order.
    """
    assert token.type is TT.identifier
    we = bytes(token.lexeme).lower()
    if any(c not in b"xyzw" for c in we):
        raise ParserError("expected destination write enable", token)
    components = [we_ord(c) for c in we]
    # Insist on strictly ascending order ("xz" ok, "zx"/"xx" rejected) so a
    # write enable always has exactly one textual spelling.
    if any(a >= b for a, b in pairwise(components)) or len(set(components)) != len(components):
        raise ParserError("misleading non-sequential write enable", token)
    return set(components)
def parse_source_swizzle(token):
    """Parse a four-component source swizzle such as "x_0_" or "-y-_-0-_".

    Each component is an optional '-' (negate) prefix followed by one select
    character; exactly four components must be present with no trailing
    characters.  Returns a SourceSwizzle whose fields are 4-tuples (the
    previous code returned lists, contradicting SourceSwizzle's declared
    tuple[...] field types).

    Raises ParserError on any malformed swizzle.
    """
    select_mapping = {
        ord('x'): 0,
        ord('y'): 1,
        ord('z'): 2,
        ord('w'): 3,
        ord('0'): 4,  # constant-zero select
        ord('1'): 5,  # constant-one select
        ord('h'): 6,  # 'h' select -- presumably a half constant; confirm against HW docs
        ord('_'): 7,  # unused slot
        ord('u'): 7,  # alias for unused
    }
    selects = [None] * 4
    modifiers = [None] * 4
    lexeme = bytes(token.lexeme).lower()
    component = 0  # which of the four output components is being filled
    ix = 0
    while component < 4:
        if ix >= len(lexeme):
            raise ParserError("invalid source swizzle", token)
        c = lexeme[ix]
        if c == ord('-'):
            # At most one '-' per component, and only before its select char.
            if (modifiers[component] is not None) or (selects[component] is not None):
                raise ParserError("invalid source swizzle modifier", token)
            modifiers[component] = True
        elif c in select_mapping:
            if selects[component] is not None:
                raise ParserError("invalid source swizzle select", token)
            selects[component] = select_mapping[c]
            if modifiers[component] is None:
                modifiers[component] = False
            component += 1
        else:
            raise ParserError("invalid source swizzle", token)
        ix += 1
    # Reject anything left over after the fourth component.
    if ix != len(lexeme):
        raise ParserError("invalid source swizzle", token)
    # Freeze into tuples so the result matches SourceSwizzle's annotations.
    return SourceSwizzle(tuple(selects), tuple(modifiers))
class Parser(BaseParser):
    """Recursive-descent parser for the R500 vertex-shader assembly dialect.

    One instruction per ';'-terminated statement:

        dest[offset].we = OPCODE[.SAT] source [source [source]]

    where each source is  file[offset].swizzle .
    """

    def destination_type(self):
        """Parse a destination register-file keyword; returns the KW member."""
        token = self.consume(TT.keyword, "expected destination type")
        destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
        if token.keyword not in destination_keywords:
            raise ParserError("expected destination type", token)
        return token.keyword

    def offset(self):
        """Parse a bracketed register index "[<number>]"; returns the int."""
        self.consume(TT.left_square, "expected offset")
        identifier_token = self.consume(TT.identifier, "expected offset")
        value = identifier_to_number(identifier_token)
        self.consume(TT.right_square, "expected offset")
        return value

    def opcode(self):
        """Parse a VE_* or ME_* opcode keyword; returns the VE/ME member."""
        token = self.consume(TT.keyword, "expected opcode")
        if not isinstance(token.keyword, (VE, ME)):
            raise ParserError("expected opcode", token)
        return token.keyword

    def destination_op(self):
        """Parse the destination side through the opcode (and optional .SAT)."""
        destination_type = self.destination_type()
        offset_value = self.offset()
        self.consume(TT.dot, "expected write enable")
        write_enable_token = self.consume(TT.identifier, "expected write enable token")
        write_enable = parse_dest_write_enable(write_enable_token)
        self.consume(TT.equal, "expected equals")
        opcode = self.opcode()
        sat = False
        if self.match(TT.dot):
            self.advance()
            suffix = self.consume(TT.keyword, "expected saturation suffix")
            if suffix.keyword is not KW.saturation:
                # BUG FIX: this previously raised with the undefined name
                # `token`, turning the parse error into a NameError.
                raise ParserError("expected saturation suffix", suffix)
            sat = True
        # Macro instructions have no source syntax yet; always False.
        macro = False
        return DestinationOp(type=destination_type,
                             offset=offset_value,
                             write_enable=write_enable,
                             opcode=opcode,
                             sat=sat,
                             macro=macro)

    def source_type(self):
        """Parse a source register-file keyword; returns the KW member."""
        token = self.consume(TT.keyword, "expected source type")
        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
        if token.keyword not in source_keywords:
            raise ParserError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
        """Parse the identifier after '.' as a source swizzle."""
        token = self.consume(TT.identifier, "expected source swizzle")
        return parse_source_swizzle(token)

    def source(self):
        """Parse one source operand, e.g. "input[0].-y-_-0-_"."""
        source_type = self.source_type()
        offset = self.offset()
        self.consume(TT.dot, "expected source swizzle")
        source_swizzle = self.source_swizzle()
        return Source(source_type, offset, source_swizzle)

    def instruction(self):
        """Parse one full instruction.

        Returns (Instruction, (start_ix, end_ix)) where the index pair spans
        the instruction's text in the input buffer, excluding the semicolon.
        Omitted trailing sources become None.
        """
        first_token = self.peek()
        destination_op = self.destination_op()
        source0 = self.source()
        if self.match(TT.semicolon) or self.match(TT.eof):
            source1 = None
        else:
            source1 = self.source()
        if self.match(TT.semicolon) or self.match(TT.eof):
            source2 = None
        else:
            source2 = self.source()
        last_token = self.peek(-1)
        self.consume(TT.semicolon, "expected semicolon")
        return (
            Instruction(destination_op, source0, source1, source2),
            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
        )

    def instructions(self):
        """Yield (Instruction, span) pairs until end of input."""
        while not self.match(TT.eof):
            yield self.instruction()
if __name__ == "__main__":
    # Smoke test: lex and parse a single representative instruction.
    from assembler.lexer import Lexer
    from pprint import pprint

    source_text = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    tokens = list(Lexer(source_text, find_keyword).lex_tokens())
    pprint(Parser(tokens).instruction())