209 lines
7.3 KiB
Python
209 lines
7.3 KiB
Python
from itertools import pairwise
|
|
from dataclasses import dataclass
|
|
from typing import Union
|
|
|
|
from assembler.parser import BaseParser, ParserError
|
|
from assembler.lexer import TT
|
|
from assembler.vs.keywords import KW, ME, VE, find_keyword
|
|
|
|
"""
|
|
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
|
|
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
|
|
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
|
|
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
|
|
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
|
|
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
|
|
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
|
|
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
|
|
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
|
|
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
|
|
"""
|
|
|
|
@dataclass
class DestinationOp:
    """Destination half of an instruction: ``type[offset].mask = OPCODE[.suffix]``.

    Instances are built by ``Parser.destination_op``.
    """

    # Destination register-type keyword (see Parser.destination_type for the
    # accepted set).
    type: KW
    # Register index taken from the ``[N]`` part.
    offset: int
    # Enabled component indices as returned by parse_dest_write_enable
    # (x=0, y=1, z=2, w=3).
    write_enable: set[int]
    # Opcode keyword, a member of either the VE or the ME enum.
    opcode: Union[VE, ME]
    # True when the opcode was followed by the KW.saturation suffix.
    sat: bool
    # Always set to False by Parser.destination_op in this file.
    # NOTE(review): meaning not shown here — confirm with the encoder.
    macro: bool
|
|
|
|
@dataclass
class SourceSwizzle:
    """Per-component swizzle of a source operand (one entry per x, y, z, w)."""

    # Select code per component, using the mapping in parse_source_swizzle:
    # x=0, y=1, z=2, w=3, '0'=4, '1'=5, 'h'=6, '_' or 'u'=7.
    # NOTE(review): parse_source_swizzle constructs this with 4-element lists,
    # not tuples.
    select: tuple[int, int, int, int]
    # True where the component was prefixed with '-' in the source text.
    modifier: tuple[bool, bool, bool, bool]
|
|
|
|
@dataclass
class Source:
    """A source operand written as ``type[offset].swizzle``, e.g. ``input[0].-y-_-0-_``.

    Instances are built by ``Parser.source``.
    """

    # Source register-type keyword (Parser.source_type accepts temporary,
    # input, constant and alt_temporary).
    type: KW
    # Register index taken from the ``[N]`` part.
    offset: int
    # Parsed per-component select codes and negate flags.
    swizzle: SourceSwizzle
|
|
|
|
@dataclass
class Instruction:
    """One parsed instruction line: a destination/opcode plus up to three sources.

    ``source1`` and ``source2`` are ``None`` when the line ends before that
    operand. ``Parser.instruction`` stores ``None`` in that case, so the
    annotations say so.
    """

    destination_op: DestinationOp
    source0: Source
    # Optional: None when the instruction line has only one source.
    source1: Union[Source, None]
    # Optional: None when the instruction line has fewer than three sources.
    source2: Union[Source, None]
|
|
|
|
def identifier_to_number(token):
    """Convert an identifier token made only of decimal digits to an ``int``.

    Args:
        token: a lexer token whose type must be ``TT.identifier``.

    Returns:
        The base-10 integer value of the token's lexeme.

    Raises:
        ParserError: if the lexeme is empty or contains anything other than
            the ASCII digits ``0``-``9``.
    """
    # Internal invariant: callers consume a TT.identifier before calling this.
    assert token.type is TT.identifier

    lexeme = bytes(token.lexeme)
    # bytes.isdigit() accepts only ASCII 0-9 and is False for b"", so an
    # empty lexeme is reported as a ParserError rather than letting int()
    # raise ValueError.
    if not lexeme.isdigit():
        raise ParserError("expected number", token)

    return int(lexeme, 10)
|
|
|
|
def we_ord(c):
    """Return the component index of write-enable byte value *c*.

    ``x``, ``y`` and ``z`` map to 0, 1 and 2 by offset from ``ord("x")``.
    ``w`` maps to 3. Any other value is returned as ``c - ord("x")`` unchecked;
    callers are expected to validate the byte first.
    """
    return 3 if c == ord("w") else c - ord("x")
|
|
|
|
def parse_dest_write_enable(token):
    """Parse a destination write-enable mask such as ``xz`` or ``xyzw``.

    The mask is case-insensitive. It may contain only ``x``, ``y``, ``z`` and
    ``w``, each at most once, in ``x, y, z, w`` order.

    Args:
        token: a lexer token whose type must be ``TT.identifier``.

    Returns:
        The set of enabled component indices (x=0, y=1, z=2, w=3).

    Raises:
        ParserError: "expected destination write enable" for any other
            character, or "misleading non-sequential write enable" when the
            components are out of order or repeated.
    """
    assert token.type is TT.identifier

    components = bytes(token.lexeme).lower()

    if any(ch not in b"xyzw" for ch in components):
        raise ParserError("expected destination write enable", token)

    indices = [we_ord(ch) for ch in components]

    # A strictly increasing sequence cannot contain repeats, so this single
    # test rejects both out-of-order and duplicated components.
    if not all(lo < hi for lo, hi in pairwise(indices)):
        raise ParserError("misleading non-sequential write enable", token)

    return set(indices)
|
|
|
|
def parse_source_swizzle(token):
    """Parse a four-component source swizzle such as ``-y-_-0-_`` or ``xyzw``.

    Each component is an optional ``-`` (negate) followed by exactly one
    select character. Select codes are x=0, y=1, z=2, w=3, '0'=4, '1'=5,
    'h'=6, and '_' or 'u'=7. Matching is case-insensitive.

    Args:
        token: the lexer token holding the swizzle text.

    Returns:
        A ``SourceSwizzle`` whose ``select`` and ``modifier`` fields are
        4-element lists.

    Raises:
        ParserError: "invalid source swizzle modifier" for a doubled ``-``.
            "invalid source swizzle" for an unknown character, too few
            components, or trailing characters.
    """
    select_codes = {
        ord("x"): 0, ord("y"): 1, ord("z"): 2, ord("w"): 3,
        ord("0"): 4, ord("1"): 5, ord("h"): 6,
        ord("_"): 7, ord("u"): 7,
    }

    text = bytes(token.lexeme).lower()
    selects = []
    negates = []
    # Set by a '-' and consumed by the next select character.
    pending_negate = False
    pos = 0

    while len(selects) < 4:
        if pos >= len(text):
            raise ParserError("invalid source swizzle", token)
        ch = text[pos]
        pos += 1

        if ch == ord("-"):
            if pending_negate:
                raise ParserError("invalid source swizzle modifier", token)
            pending_negate = True
        elif ch in select_codes:
            selects.append(select_codes[ch])
            negates.append(pending_negate)
            pending_negate = False
        else:
            raise ParserError("invalid source swizzle", token)

    # All four components must consume the whole lexeme.
    if pos != len(text):
        raise ParserError("invalid source swizzle", token)

    return SourceSwizzle(selects, negates)
|
|
|
|
class Parser(BaseParser):
    """Parser for instruction lines of the form
    ``dst[N].mask = OPCODE[.suffix] src0 [src1 [src2]]``, one per line.
    """

    def destination_type(self):
        """Consume and return the destination register-type keyword.

        Raises:
            ParserError: if the keyword is not a valid destination type.
        """
        token = self.consume(TT.keyword, "expected destination type")
        destination_keywords = {
            KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input,
        }
        if token.keyword not in destination_keywords:
            raise ParserError("expected destination type", token)
        return token.keyword

    def offset(self):
        """Consume ``[N]`` and return ``N`` as an int."""
        self.consume(TT.left_square, "expected offset")
        identifier_token = self.consume(TT.identifier, "expected offset")
        value = identifier_to_number(identifier_token)
        self.consume(TT.right_square, "expected offset")
        return value

    def opcode(self):
        """Consume and return an opcode keyword (a VE or ME enum member).

        Raises:
            ParserError: if the keyword is neither a VE nor an ME member.
        """
        token = self.consume(TT.keyword, "expected opcode")
        if not isinstance(token.keyword, (VE, ME)):
            raise ParserError("expected opcode", token)
        return token.keyword

    def destination_op(self):
        """Parse ``type[N].mask = OPCODE[.sat]`` into a ``DestinationOp``."""
        destination_type = self.destination_type()
        offset_value = self.offset()
        self.consume(TT.dot, "expected write enable")
        write_enable_token = self.consume(TT.identifier, "expected write enable token")
        write_enable = parse_dest_write_enable(write_enable_token)
        self.consume(TT.equal, "expected equals")
        opcode = self.opcode()

        sat = False
        if self.match(TT.dot):
            self.advance()
            suffix = self.consume(TT.keyword, "expected saturation suffix")
            if suffix.keyword is not KW.saturation:
                # Point the error at the offending suffix token.
                raise ParserError("expected saturation suffix", suffix)
            sat = True

        macro = False
        return DestinationOp(type=destination_type,
                             offset=offset_value,
                             write_enable=write_enable,
                             opcode=opcode,
                             sat=sat,
                             macro=macro)

    def source_type(self):
        """Consume and return the source register-type keyword.

        Raises:
            ParserError: if the keyword is not a valid source type.
        """
        token = self.consume(TT.keyword, "expected source type")
        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
        if token.keyword not in source_keywords:
            raise ParserError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
        """Consume an identifier token and parse it as a source swizzle."""
        token = self.consume(TT.identifier, "expected source swizzle")
        return parse_source_swizzle(token)

    def source(self):
        """Parse one source operand, e.g. ``input[0].-y-_-0-_``, into a ``Source``."""
        source_type = self.source_type()
        offset = self.offset()
        self.consume(TT.dot, "expected source swizzle")
        source_swizzle = self.source_swizzle()
        return Source(source_type, offset, source_swizzle)

    def instruction(self):
        """Parse one instruction line, skipping any leading blank lines.

        Returns:
            A tuple ``(Instruction, (start, end))``. ``start`` is the
            ``start_ix`` of the first token, and ``end`` is the end of the
            last token's lexeme. ``source1`` and ``source2`` are ``None``
            when the line ends before them.
        """
        while self.match(TT.eol):
            self.advance()

        first_token = self.peek()
        destination_op = self.destination_op()
        source0 = self.source()

        if self.match(TT.eol) or self.match(TT.eof):
            source1 = None
        else:
            source1 = self.source()

        if self.match(TT.eol) or self.match(TT.eof):
            source2 = None
        else:
            source2 = self.source()

        last_token = self.peek(-1)
        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")

        return (
            Instruction(destination_op, source0, source1, source2),
            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
        )

    def instructions(self):
        """Yield ``instruction()`` results until end of input.

        Blank lines are skipped before the end-of-input check. Without this,
        trailing blank lines after the last instruction would be handed to
        ``instruction()`` and fail as a missing destination type.
        """
        while True:
            while self.match(TT.eol):
                self.advance()
            if self.match(TT.eof):
                return
            yield self.instruction()
|
|
|
|
if __name__ == "__main__":
    from pprint import pprint

    from assembler.lexer import Lexer

    # Smoke test: lex and parse one sample instruction, then pretty-print the
    # (Instruction, span) result.
    sample = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    token_list = list(Lexer(sample, find_keyword).lex_tokens())
    pprint(Parser(token_list).instruction())
|