r500/regs/assembler/fs/parser.py

303 lines
9.5 KiB
Python

from enum import IntEnum
from typing import Literal, Union
from dataclasses import dataclass
from assembler.parser import BaseParser, ParserError
from assembler.lexer import TT, Token
from assembler.fs.keywords import KW, find_keyword
from assembler.error import print_error
class ALUMod(IntEnum):
nop = 0
neg = 1
abs = 2
nab = 3
@dataclass
class ALULetExpression:
src_keyword: Token
src_swizzle_identifier: Token
addr_keyword: Token
addr_value_identifier: Token
@dataclass
class DestAddrSwizzle:
dest_keyword: Token
addr_identifier: Token
swizzle_identifier: Token
@dataclass
class ALUSwizzleSel:
sel_keyword: Token
swizzle_identifier: Token
mod: ALUMod
@dataclass
class ALUOperation:
dest_addr_swizzles: list[DestAddrSwizzle]
opcode_keyword: Token
swizzle_sels: list[ALUSwizzleSel]
@dataclass
class ALUInstruction:
tags: set[Token]
let_expressions: list[ALULetExpression]
operations: list[ALUOperation]
@dataclass
class TEXOperation:
dest_addr_swizzles: list[DestAddrSwizzle]
opcode_keyword: Token
tex_id_identifier: Token
tex_dst_swizzle_identifier: Token
tex_src_address_identifier: Token
tex_src_swizzle_identifier: Token
@dataclass
class TEXInstruction:
tags: set[Token]
operation: TEXOperation
class Parser(BaseParser):
def alu_let_expression(self):
src_keyword = self.consume(TT.keyword, "expected src keyword")
self.consume(TT.dot, "expected dot")
src_swizzle_identifier = self.consume(TT.identifier, "expected src swizzle identifier")
self.consume(TT.equal, "expected equal")
addr_keyword = self.consume(TT.keyword, "expected addr keyword")
if addr_keyword.keyword in {KW.NEG2, KW.SUB, KW.ADD, KW.NEG}:
addr_value_identifier = None
else:
if addr_keyword.keyword is KW.FLOAT:
self.consume(TT.left_paren, "expected left paren")
else:
self.consume(TT.left_square, "expected left square")
addr_value_identifier = self.consume(TT.identifier, "expected address identifier")
if addr_keyword.keyword is KW.FLOAT:
self.consume(TT.right_paren, "expected right paren")
else:
self.consume(TT.right_square, "expected right square")
return ALULetExpression(
src_keyword,
src_swizzle_identifier,
addr_keyword,
addr_value_identifier,
)
def dest_addr_swizzle(self):
dest_keyword = self.consume(TT.keyword, "expected dest keyword")
self.consume(TT.left_square, "expected left square")
addr_identifier = self.consume(TT.identifier, "expected dest addr identifier")
self.consume(TT.right_square, "expected left square")
self.consume(TT.dot, "expected dot")
swizzle_identifier = self.consume(TT.identifier, "expected dest swizzle identifier")
self.consume(TT.equal, "expected equal")
return DestAddrSwizzle(
dest_keyword,
addr_identifier,
swizzle_identifier,
)
def alu_is_opcode(self):
opcode_keywords = {
KW.CMP, KW.CND, KW.COS, KW.D2A,
KW.DP , KW.DP3, KW.DP4, KW.EX2,
KW.FRC, KW.LN2, KW.MAD, KW.MAX,
KW.MDH, KW.MDV, KW.MIN, KW.RCP,
KW.RSQ, KW.SIN, KW.SOP,
}
if self.match(TT.keyword):
token = self.peek()
return token.keyword in opcode_keywords
return False
def alu_is_neg(self):
result = self.match(TT.minus)
if result:
self.advance()
return result
def alu_is_abs(self):
result = self.match(TT.bar)
if result:
self.advance()
return result
def alu_swizzle_sel(self):
neg = self.alu_is_neg()
abs = self.alu_is_abs()
sel_keyword = self.consume(TT.keyword, "expected sel keyword")
self.consume(TT.dot, "expected dot")
swizzle_identifier = self.consume(TT.identifier, "expected swizzle identifier")
if abs:
self.consume(TT.bar, "expected vertical bar")
mod_table = {
# (neg, abs)
(False, False): ALUMod.nop,
(False, True): ALUMod.abs,
(True, False): ALUMod.neg,
(True, True): ALUMod.nab,
}
mod = mod_table[(neg, abs)]
return ALUSwizzleSel(
sel_keyword,
swizzle_identifier,
mod,
)
def alu_operation(self):
dest_addr_swizzles = []
while not self.alu_is_opcode():
dest_addr_swizzles.append(self.dest_addr_swizzle())
opcode_keyword = self.consume(TT.keyword, "expected opcode keyword")
swizzle_sels = []
while not (self.match(TT.comma) or self.match(TT.semicolon)):
swizzle_sels.append(self.alu_swizzle_sel())
return ALUOperation(
dest_addr_swizzles,
opcode_keyword,
swizzle_sels
)
def alu_instruction(self, out: bool):
tags = list()
tag_keywords = set([KW.OUT, KW.TEX_SEM_WAIT, KW.NOP, KW.LAST, KW.ALU_WAIT])
while True:
if self.match(TT.keyword):
token = self.peek()
if token.keyword in tag_keywords:
self.advance()
tag_keywords.remove(token.keyword)
tags.append(token)
continue
break
let_expressions = []
while not self.match(TT.colon):
let_expressions.append(self.alu_let_expression())
if not self.match(TT.colon):
self.consume(TT.comma, "expected comma")
self.consume(TT.colon, "expected colon")
operations = []
while not self.match(TT.semicolon):
operations.append(self.alu_operation())
if not self.match(TT.semicolon):
self.consume(TT.comma, "expected comma")
self.consume(TT.semicolon, "expected semicolon")
return ALUInstruction(
tags,
let_expressions,
operations,
)
def tex_is_opcode(self):
opcode_keywords = {
KW.NOP, KW.LD, KW.TEXKILL, KW.PROJ,
KW.LODBIAS, KW.LOD, KW.DXDY,
}
if self.match(TT.keyword):
token = self.peek()
return token.keyword in opcode_keywords
return False
def tex_operation(self):
dest_addr_swizzles = []
while not self.tex_is_opcode():
dest_addr_swizzles.append(self.dest_addr_swizzle())
opcode_keyword = self.consume(TT.keyword, "expected tex opcode keyword")
self.consume_keyword(KW.TEX, "expected tex keyword")
self.consume(TT.left_square, "expected left square")
tex_id_identifier = self.consume(TT.identifier, "expected tex ID identifier")
self.consume(TT.right_square, "expected left square")
self.consume(TT.dot, "expected dot")
tex_dst_swizzle_identifier = self.consume(TT.identifier, "expected src swizzle")
self.consume_keyword(KW.TEMP, "expected temp keyword")
self.consume(TT.left_square, "expected left square")
tex_src_address_identifier = self.consume(TT.identifier, "expected tex address identifier")
self.consume(TT.right_square, "expected left square")
self.consume(TT.dot, "expected dot")
tex_src_swizzle_identifier = self.consume(TT.identifier, "expected dst swizzle")
return TEXOperation(
dest_addr_swizzles,
opcode_keyword,
tex_id_identifier,
tex_dst_swizzle_identifier,
tex_src_address_identifier,
tex_src_swizzle_identifier,
)
def tex_instruction(self):
tags = list()
tag_keywords = set([KW.OUT, KW.TEX_SEM_WAIT, KW.TEX_SEM_ACQUIRE, KW.NOP, KW.ALU_WAIT, KW.LAST])
while True:
if self.match(TT.keyword):
token = self.peek()
if token.keyword in tag_keywords:
self.advance()
tag_keywords.remove(token.keyword)
tags.append(token)
continue
break
operation = self.tex_operation()
self.consume(TT.semicolon, "expected semicolon")
return TEXInstruction(
tags,
operation,
)
def instruction(self):
out = False
if self.match_keyword(KW.TEX):
self.advance()
return self.tex_instruction()
else:
return self.alu_instruction(out=False)
def instructions(self):
while not self.match(TT.eof):
yield self.instruction()
if __name__ == "__main__":
from assembler.lexer import Lexer
buf = b"""
src0.a = float(0), src0.rgb = temp[0] :
out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ;
"""
buf = b"""
OUT TEX_SEM_WAIT
src0.a = float(0), src1.a = float(0), src2.a = float(0), srcp.a = neg2, src0.rgb = temp[0], src1.rgb = float(0), src2.rgb = float(0), srcp.rgb = neg2 :
out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 src0.rrr ;
"""
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=True)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
from pprint import pprint
try:
pprint(parser.instruction())
except ParserError as e:
print_error(None, buf, e)
raise
print(parser.peek())