r500/regs/assembler/fs/alu_validator.py

494 lines
14 KiB
Python

from pprint import pprint
from dataclasses import dataclass
from enum import Enum, IntEnum, auto
from collections import OrderedDict
from typing import Union
from assembler.fs.parser import ALUMod
from assembler.fs.keywords import _keyword_to_string, KW
from assembler.error import print_error
from assembler.validator import ValidatorError
from assembler.fs.common_validator import RGBMask, AlphaMask, InstructionType
from assembler.fs.common_validator import validate_identifier_number, validate_dest_keyword, keywords_to_string
class SrcAddrType(Enum):
temp = auto()
const = auto()
float = auto()
@dataclass
class SrcAddr:
type: SrcAddrType
value: int
class SrcMod(IntEnum):
nop = 0
neg = 1
abs = 2
nab = 3
class SrcpOp(IntEnum):
neg2 = 0
sub = 1
add = 2
neg = 3
@dataclass
class Addr:
src0: SrcAddr = None
src1: SrcAddr = None
src2: SrcAddr = None
srcp: SrcpOp = None
@dataclass
class AddrRGBAlpha:
alpha: Addr
rgb: Addr
class Unit(Enum):
alpha = auto()
rgb = auto()
class RGBOp(IntEnum):
MAD = 0
DP3 = 1
DP4 = 2
D2A = 3
MIN = 4
MAX = 5
CND = 7
CMP = 8
FRC = 9
SOP = 10
MDH = 11
MDV = 12
class AlphaOp(IntEnum):
MAD = 0
DP = 1
MIN = 2
MAX = 3
CND = 5
CMP = 6
FRC = 7
EX2 = 8
LN2 = 9
RCP = 10
RSQ = 11
SIN = 12
COS = 13
MDH = 14
MDV = 15
@dataclass
class RGBDest:
addrd: int
target: int
wmask: RGBMask
omask: RGBMask
@dataclass
class AlphaDest:
addrd: int
target: int
wmask: AlphaMask
omask: AlphaMask
class SwizzleSelSrc(IntEnum):
src0 = 0
src1 = 1
src2 = 2
srcp = 3
class Swizzle(IntEnum):
r = 0
g = 1
b = 2
a = 3
zero = 4
half = 5
one = 6
unused = 7
@dataclass
class SwizzleSel:
src: SwizzleSelSrc
swizzle: list[Swizzle]
mod: ALUMod
@dataclass
class AlphaOperation:
dest: AlphaDest
opcode: AlphaOp
sels: list[SwizzleSel]
@dataclass
class RGBOperation:
dest: RGBDest
opcode: RGBOp
sels: list[SwizzleSel]
@dataclass
class Instruction:
tags: set[Union[KW.OUT, KW.TEX_SEM_WAIT, KW.NOP]]
addr: Addr
alpha_op: AlphaOperation
rgb_op: RGBOperation
def validate_instruction_let_expressions(let_expressions):
src_keywords = [KW.SRC0, KW.SRC1, KW.SRC2, KW.SRCP]
src_keyword_strs = keywords_to_string(src_keywords)
rgb_alpha_swizzles = [b"rgb", b"a"]
addr_rgb_alpha = AddrRGBAlpha(Addr(), Addr())
def set_src_by_keyword(addr, keyword, value):
if keyword == KW.SRC0:
addr.src0 = value
elif keyword == KW.SRC1:
addr.src1 = value
elif keyword == KW.SRC2:
addr.src2 = value
elif keyword == KW.SRCP:
addr.srcp = value
else:
assert False, keyword
def src_value(expr, src):
if src in {KW.SRC0, KW.SRC1, KW.SRC2}:
keyword_to_src_addr_type = OrderedDict([
(KW.TEMP, SrcAddrType.temp),
(KW.CONST, SrcAddrType.const),
(KW.FLOAT, SrcAddrType.float),
])
src_addr_type_strs = keywords_to_string(keyword_to_src_addr_type.keys())
type_kw = expr.addr_keyword.keyword
if type_kw not in keyword_to_src_addr_type:
raise ValidatorError(f"invalid src addr type, expected one of {src_addr_type_strs}", expr.addr_keyword)
type = keyword_to_src_addr_type[type_kw]
value = validate_identifier_number(expr.addr_value_identifier)
if type is SrcAddrType.float:
if value >= 128:
raise ValidatorError(f"invalid float value", expr.addr_value_identifier)
elif type is SrcAddrType.temp:
if value >= 128:
raise ValidatorError(f"invalid temp value", expr.addr_value_identifier)
elif type is SrcAddrType.const:
if value >= 256:
raise ValidatorError(f"invalid const value", expr.addr_value_identifier)
else:
assert False, (id(type), id(SrcAddrType.float))
return SrcAddr(
type,
value,
)
elif src == KW.SRCP:
keyword_to_srcp_op = OrderedDict([
(KW.NEG2, SrcpOp.neg2),
(KW.SUB, SrcpOp.sub),
(KW.ADD, SrcpOp.add),
(KW.NEG, SrcpOp.neg),
])
srcp_op_strs = keywords_to_string(keyword_to_srcp_op.keys())
op = expr.addr_keyword.keyword
if op not in keyword_to_srcp_op:
raise ValidatorError(f"invalid srcp op, expected one of {srcp_op_strs}", expr.addr_keyword)
return keyword_to_srcp_op[op]
else:
assert False, src
sources = set()
for expr in let_expressions:
src = expr.src_keyword.keyword
if src not in src_keywords:
raise ValidatorError(f"invalid src keyword, expected one of {src_keyword_strs}", expr.src_keyword)
src_swizzle = expr.src_swizzle_identifier.lexeme.lower()
if src_swizzle not in rgb_alpha_swizzles:
raise ValidatorError(f"invalid src swizzle, expected one of {rgb_alpha_swizzles}", expr.src_swizzle_identifier)
source = (_keyword_to_string[src].lower(), src_swizzle)
if source in sources:
raise ValidatorError(f"duplicate source/swizzle in let expressions", expr.src_swizzle_identifier)
sources.add(source)
value = src_value(expr, src)
if src_swizzle == b"a":
set_src_by_keyword(addr_rgb_alpha.alpha, src, value)
elif src_swizzle == b"rgb":
set_src_by_keyword(addr_rgb_alpha.rgb, src, value)
else:
assert False, src_swizzle
return addr_rgb_alpha
def prevalidate_mask(dest_addr_swizzle, valid_masks):
# we don't know yet whether this is an Alpha operation or an RGB operation
swizzle_str = dest_addr_swizzle.swizzle_identifier.lexeme
if swizzle_str.lower() not in valid_masks:
raise ValidatorError(f"invalid write mask, expected one of {valid_masks}", dest_addr_swizzle.swizzle_identifier)
mask = swizzle_str.lower()
return mask
rgb_op_kws = OrderedDict([
(KW.MAD, RGBOp.MAD),
(KW.DP3, RGBOp.DP3),
(KW.DP4, RGBOp.DP4),
(KW.D2A, RGBOp.D2A),
(KW.MIN, RGBOp.MIN),
(KW.MAX, RGBOp.MAX),
(KW.CND, RGBOp.CND),
(KW.CMP, RGBOp.CMP),
(KW.FRC, RGBOp.FRC),
(KW.SOP, RGBOp.SOP),
(KW.MDH, RGBOp.MDH),
(KW.MDV, RGBOp.MDV)
])
alpha_op_kws = OrderedDict([
(KW.MAD, AlphaOp.MAD),
(KW.DP, AlphaOp.DP),
(KW.MIN, AlphaOp.MIN),
(KW.MAX, AlphaOp.MAX),
(KW.CND, AlphaOp.CND),
(KW.CMP, AlphaOp.CMP),
(KW.FRC, AlphaOp.FRC),
(KW.EX2, AlphaOp.EX2),
(KW.LN2, AlphaOp.LN2),
(KW.RCP, AlphaOp.RCP),
(KW.RSQ, AlphaOp.RSQ),
(KW.SIN, AlphaOp.SIN),
(KW.COS, AlphaOp.COS),
(KW.MDH, AlphaOp.MDH),
(KW.MDV, AlphaOp.MDV)
])
rgb_masks = OrderedDict([
(b"none" , RGBMask.NONE),
(b"r" , RGBMask.R),
(b"g" , RGBMask.G),
(b"rg" , RGBMask.RG),
(b"b" , RGBMask.B),
(b"rb" , RGBMask.RB),
(b"gb" , RGBMask.GB),
(b"rgb" , RGBMask.RGB),
])
alpha_masks = OrderedDict([
(b"none" , AlphaMask.NONE),
(b"a" , AlphaMask.A),
])
alpha_only_ops = set(alpha_op_kws.keys()) - set(rgb_op_kws.keys())
rgb_only_ops = set(rgb_op_kws.keys()) - set(alpha_op_kws.keys())
all_ops = set(rgb_op_kws.keys()) | set(alpha_op_kws.keys())
alpha_only_masks = set(alpha_masks.keys()) - set(rgb_masks.keys())
rgb_only_masks = set(rgb_masks.keys()) - set(alpha_masks.keys())
all_masks = set(rgb_masks.keys()) | set(alpha_masks.keys())
def infer_operation_units(operations):
if len(operations) > 2:
raise ValidatorError("too many operations in instruction", operations[-1].opcode_keyword)
units = [None, None]
for i, operation in enumerate(operations):
opcode = operation.opcode_keyword.keyword
if opcode not in all_ops:
raise ValidatorError(f"invalid opcode keyword, expected one of {all_ops}", operation.opcode_keyword)
if len(operation.dest_addr_swizzles) > 2:
raise ValidationError("too many destinations in instruction", operation.dest_addr_swizzles[-1])
masks = set(prevalidate_mask(dest_addr_swizzle, all_masks) for dest_addr_swizzle in operation.dest_addr_swizzles)
def infer_opcode_unit():
if opcode in alpha_only_ops:
return Unit.alpha
if opcode in rgb_only_ops:
return Unit.rgb
return None
def infer_mask_unit():
if any(mask in alpha_only_masks for mask in masks):
return Unit.alpha
if any(mask in rgb_only_masks for mask in masks):
return Unit.rgb
return None
opcode_unit = infer_opcode_unit()
mask_unit = infer_mask_unit()
if opcode_unit is not None and mask_unit is not None and opcode_unit != mask_unit:
raise ValidatorError(f"contradictory {mask_unit.name} write mask for {opcode_unit.name} opcode", operation.opcode_keyword)
units[i] = opcode_unit or mask_unit
if units[0] == units[1]:
raise ValidatorError(f"invalid duplicate use of {units[1].name} unit", operations[1].opcode_keyword)
other_unit = {
Unit.alpha: Unit.rgb,
Unit.rgb: Unit.alpha,
}
if units[0] is None:
units[0] = other_unit[units[1]]
if units[1] is None:
units[1] = other_unit[units[0]]
assert units[0] is not None
assert units[1] is not None
assert units[0] != units[1]
for i, operation in enumerate(operations):
yield units[i], operation
def validate_instruction_operation_dest(dest_addr_swizzles, mask_lookup, type_cls):
addrd = None
target = None
wmask = None
omask = None
for dest_addr_swizzle in dest_addr_swizzles:
dest = validate_dest_keyword(dest_addr_swizzle.dest_keyword)
addr = validate_identifier_number(dest_addr_swizzle.addr_identifier)
mask = mask_lookup[dest_addr_swizzle.swizzle_identifier.lexeme.lower()]
if dest == KW.OUT:
omask = mask
target = addr
elif dest == KW.TEMP:
wmask = mask
addrd = addr
else:
assert False, dest
return type_cls(
addrd=addrd,
target=target,
wmask=wmask,
omask=omask
)
swizzle_sel_src_kws = OrderedDict([
(KW.SRC0, SwizzleSelSrc.src0),
(KW.SRC1, SwizzleSelSrc.src1),
(KW.SRC2, SwizzleSelSrc.src2),
(KW.SRCP, SwizzleSelSrc.srcp),
])
swizzle_kws = OrderedDict([
(ord("r"), Swizzle.r),
(ord("g"), Swizzle.g),
(ord("b"), Swizzle.b),
(ord("a"), Swizzle.a),
(ord("0"), Swizzle.zero),
(ord("h"), Swizzle.half),
(ord("1"), Swizzle.one),
(ord("_"), Swizzle.unused),
])
def validate_instruction_operation_sels(swizzle_sels, is_alpha):
if len(swizzle_sels) > 3:
raise ValidatorError("too many swizzle sels", swizzle_sels[-1].sel_keyword)
sels = []
for swizzle_sel in swizzle_sels:
if swizzle_sel.sel_keyword.keyword not in swizzle_sel_src_kws:
raise ValidatorError("invalid swizzle src", swizzle_sel.sel_keyword.keyword)
src = swizzle_sel_src_kws[swizzle_sel.sel_keyword.keyword]
swizzle_lexeme = swizzle_sel.swizzle_identifier.lexeme.lower()
swizzles_length = 1 if is_alpha else 3
if len(swizzle_lexeme) != swizzles_length:
raise ValidatorError("invalid swizzle length", swizzle_sel.swizzle_identifier)
if not all(c in swizzle_kws for c in swizzle_lexeme):
raise ValidatorError("invalid swizzle characters", swizzle_sel.swizzle_identifier)
swizzle = [
swizzle_kws[c] for c in swizzle_lexeme
]
mod = swizzle_sel.mod
sels.append(SwizzleSel(src, swizzle, mod))
return sels
def validate_alpha_instruction_operation(operation):
dest = validate_instruction_operation_dest(operation.dest_addr_swizzles,
mask_lookup=alpha_masks,
type_cls=AlphaDest)
opcode = alpha_op_kws[operation.opcode_keyword.keyword]
sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=True)
return AlphaOperation(
dest,
opcode,
sels
)
def validate_rgb_instruction_operation(operation):
dest = validate_instruction_operation_dest(operation.dest_addr_swizzles,
mask_lookup=rgb_masks,
type_cls=RGBDest)
opcode = rgb_op_kws[operation.opcode_keyword.keyword]
sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=False)
return RGBOperation(
dest,
opcode,
sels
)
def validate_instruction_operations(operations):
for unit, operation in infer_operation_units(operations):
if unit is Unit.alpha:
yield validate_alpha_instruction_operation(operation)
elif unit is Unit.rgb:
yield validate_rgb_instruction_operation(operation)
else:
assert False, unit
def validate_instruction(ins):
addr_rgb_alpha = validate_instruction_let_expressions(ins.let_expressions)
tags = set([tag.keyword for tag in ins.tags])
instruction = Instruction(
tags,
addr_rgb_alpha,
None,
None
)
for op in validate_instruction_operations(ins.operations):
if type(op) is RGBOperation:
instruction.rgb_op = op
elif type(op) is AlphaOperation:
instruction.alpha_op = op
else:
assert False, op
return instruction
if __name__ == "__main__":
from assembler.lexer import Lexer, LexerError
from assembler.fs.parser import Parser, ParserError
from assembler.fs.keywords import find_keyword
buf = b"""
src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg :
out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ;
"""
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=True)
tokens = list(lexer.lex_tokens())
parser = Parser(tokens)
try:
ins_ast = parser.instruction()
pprint(validate_instruction(ins_ast))
except LexerError as e:
print_error(None, buf, e)
raise
except ParserError as e:
print_error(None, buf, e)
raise
except ValidatorError as e:
print_error(None, buf, e)
raise