494 lines
14 KiB
Python
494 lines
14 KiB
Python
from pprint import pprint
|
|
from dataclasses import dataclass
|
|
from enum import Enum, IntEnum, auto
|
|
from collections import OrderedDict
|
|
from typing import Union
|
|
|
|
from assembler.fs.parser import ALUMod
|
|
from assembler.fs.keywords import _keyword_to_string, KW
|
|
from assembler.error import print_error
|
|
from assembler.validator import ValidatorError
|
|
from assembler.fs.common_validator import RGBMask, AlphaMask, InstructionType
|
|
from assembler.fs.common_validator import validate_identifier_number, validate_dest_keyword, keywords_to_string
|
|
|
|
class SrcAddrType(Enum):
|
|
temp = auto()
|
|
const = auto()
|
|
float = auto()
|
|
|
|
@dataclass
|
|
class SrcAddr:
|
|
type: SrcAddrType
|
|
value: int
|
|
|
|
class SrcMod(IntEnum):
|
|
nop = 0
|
|
neg = 1
|
|
abs = 2
|
|
nab = 3
|
|
|
|
class SrcpOp(IntEnum):
|
|
neg2 = 0
|
|
sub = 1
|
|
add = 2
|
|
neg = 3
|
|
|
|
@dataclass
|
|
class Addr:
|
|
src0: SrcAddr = None
|
|
src1: SrcAddr = None
|
|
src2: SrcAddr = None
|
|
srcp: SrcpOp = None
|
|
|
|
@dataclass
|
|
class AddrRGBAlpha:
|
|
alpha: Addr
|
|
rgb: Addr
|
|
|
|
class Unit(Enum):
|
|
alpha = auto()
|
|
rgb = auto()
|
|
|
|
class RGBOp(IntEnum):
|
|
MAD = 0
|
|
DP3 = 1
|
|
DP4 = 2
|
|
D2A = 3
|
|
MIN = 4
|
|
MAX = 5
|
|
CND = 7
|
|
CMP = 8
|
|
FRC = 9
|
|
SOP = 10
|
|
MDH = 11
|
|
MDV = 12
|
|
|
|
class AlphaOp(IntEnum):
|
|
MAD = 0
|
|
DP = 1
|
|
MIN = 2
|
|
MAX = 3
|
|
CND = 5
|
|
CMP = 6
|
|
FRC = 7
|
|
EX2 = 8
|
|
LN2 = 9
|
|
RCP = 10
|
|
RSQ = 11
|
|
SIN = 12
|
|
COS = 13
|
|
MDH = 14
|
|
MDV = 15
|
|
|
|
@dataclass
|
|
class RGBDest:
|
|
addrd: int
|
|
target: int
|
|
wmask: RGBMask
|
|
omask: RGBMask
|
|
|
|
@dataclass
|
|
class AlphaDest:
|
|
addrd: int
|
|
target: int
|
|
wmask: AlphaMask
|
|
omask: AlphaMask
|
|
|
|
class SwizzleSelSrc(IntEnum):
|
|
src0 = 0
|
|
src1 = 1
|
|
src2 = 2
|
|
srcp = 3
|
|
|
|
class Swizzle(IntEnum):
|
|
r = 0
|
|
g = 1
|
|
b = 2
|
|
a = 3
|
|
zero = 4
|
|
half = 5
|
|
one = 6
|
|
unused = 7
|
|
|
|
@dataclass
|
|
class SwizzleSel:
|
|
src: SwizzleSelSrc
|
|
swizzle: list[Swizzle]
|
|
mod: ALUMod
|
|
|
|
@dataclass
|
|
class AlphaOperation:
|
|
dest: AlphaDest
|
|
opcode: AlphaOp
|
|
sels: list[SwizzleSel]
|
|
|
|
@dataclass
|
|
class RGBOperation:
|
|
dest: RGBDest
|
|
opcode: RGBOp
|
|
sels: list[SwizzleSel]
|
|
|
|
@dataclass
|
|
class Instruction:
|
|
tags: set[Union[KW.OUT, KW.TEX_SEM_WAIT, KW.NOP]]
|
|
addr: Addr
|
|
alpha_op: AlphaOperation
|
|
rgb_op: RGBOperation
|
|
|
|
def validate_instruction_let_expressions(let_expressions):
|
|
src_keywords = [KW.SRC0, KW.SRC1, KW.SRC2, KW.SRCP]
|
|
src_keyword_strs = keywords_to_string(src_keywords)
|
|
rgb_alpha_swizzles = [b"rgb", b"a"]
|
|
|
|
addr_rgb_alpha = AddrRGBAlpha(Addr(), Addr())
|
|
|
|
def set_src_by_keyword(addr, keyword, value):
|
|
if keyword == KW.SRC0:
|
|
addr.src0 = value
|
|
elif keyword == KW.SRC1:
|
|
addr.src1 = value
|
|
elif keyword == KW.SRC2:
|
|
addr.src2 = value
|
|
elif keyword == KW.SRCP:
|
|
addr.srcp = value
|
|
else:
|
|
assert False, keyword
|
|
|
|
def src_value(expr, src):
|
|
if src in {KW.SRC0, KW.SRC1, KW.SRC2}:
|
|
keyword_to_src_addr_type = OrderedDict([
|
|
(KW.TEMP, SrcAddrType.temp),
|
|
(KW.CONST, SrcAddrType.const),
|
|
(KW.FLOAT, SrcAddrType.float),
|
|
])
|
|
src_addr_type_strs = keywords_to_string(keyword_to_src_addr_type.keys())
|
|
type_kw = expr.addr_keyword.keyword
|
|
if type_kw not in keyword_to_src_addr_type:
|
|
raise ValidatorError(f"invalid src addr type, expected one of {src_addr_type_strs}", expr.addr_keyword)
|
|
|
|
type = keyword_to_src_addr_type[type_kw]
|
|
value = validate_identifier_number(expr.addr_value_identifier)
|
|
if type is SrcAddrType.float:
|
|
if value >= 128:
|
|
raise ValidatorError(f"invalid float value", expr.addr_value_identifier)
|
|
elif type is SrcAddrType.temp:
|
|
if value >= 128:
|
|
raise ValidatorError(f"invalid temp value", expr.addr_value_identifier)
|
|
elif type is SrcAddrType.const:
|
|
if value >= 256:
|
|
raise ValidatorError(f"invalid const value", expr.addr_value_identifier)
|
|
else:
|
|
assert False, (id(type), id(SrcAddrType.float))
|
|
|
|
return SrcAddr(
|
|
type,
|
|
value,
|
|
)
|
|
elif src == KW.SRCP:
|
|
keyword_to_srcp_op = OrderedDict([
|
|
(KW.NEG2, SrcpOp.neg2),
|
|
(KW.SUB, SrcpOp.sub),
|
|
(KW.ADD, SrcpOp.add),
|
|
(KW.NEG, SrcpOp.neg),
|
|
])
|
|
srcp_op_strs = keywords_to_string(keyword_to_srcp_op.keys())
|
|
op = expr.addr_keyword.keyword
|
|
if op not in keyword_to_srcp_op:
|
|
raise ValidatorError(f"invalid srcp op, expected one of {srcp_op_strs}", expr.addr_keyword)
|
|
return keyword_to_srcp_op[op]
|
|
else:
|
|
assert False, src
|
|
|
|
sources = set()
|
|
|
|
for expr in let_expressions:
|
|
src = expr.src_keyword.keyword
|
|
if src not in src_keywords:
|
|
raise ValidatorError(f"invalid src keyword, expected one of {src_keyword_strs}", expr.src_keyword)
|
|
|
|
src_swizzle = expr.src_swizzle_identifier.lexeme.lower()
|
|
if src_swizzle not in rgb_alpha_swizzles:
|
|
raise ValidatorError(f"invalid src swizzle, expected one of {rgb_alpha_swizzles}", expr.src_swizzle_identifier)
|
|
|
|
source = (_keyword_to_string[src].lower(), src_swizzle)
|
|
if source in sources:
|
|
raise ValidatorError(f"duplicate source/swizzle in let expressions", expr.src_swizzle_identifier)
|
|
sources.add(source)
|
|
|
|
value = src_value(expr, src)
|
|
|
|
if src_swizzle == b"a":
|
|
set_src_by_keyword(addr_rgb_alpha.alpha, src, value)
|
|
elif src_swizzle == b"rgb":
|
|
set_src_by_keyword(addr_rgb_alpha.rgb, src, value)
|
|
else:
|
|
assert False, src_swizzle
|
|
|
|
return addr_rgb_alpha
|
|
|
|
def prevalidate_mask(dest_addr_swizzle, valid_masks):
|
|
# we don't know yet whether this is an Alpha operation or an RGB operation
|
|
swizzle_str = dest_addr_swizzle.swizzle_identifier.lexeme
|
|
if swizzle_str.lower() not in valid_masks:
|
|
raise ValidatorError(f"invalid write mask, expected one of {valid_masks}", dest_addr_swizzle.swizzle_identifier)
|
|
|
|
mask = swizzle_str.lower()
|
|
return mask
|
|
|
|
rgb_op_kws = OrderedDict([
|
|
(KW.MAD, RGBOp.MAD),
|
|
(KW.DP3, RGBOp.DP3),
|
|
(KW.DP4, RGBOp.DP4),
|
|
(KW.D2A, RGBOp.D2A),
|
|
(KW.MIN, RGBOp.MIN),
|
|
(KW.MAX, RGBOp.MAX),
|
|
(KW.CND, RGBOp.CND),
|
|
(KW.CMP, RGBOp.CMP),
|
|
(KW.FRC, RGBOp.FRC),
|
|
(KW.SOP, RGBOp.SOP),
|
|
(KW.MDH, RGBOp.MDH),
|
|
(KW.MDV, RGBOp.MDV)
|
|
])
|
|
alpha_op_kws = OrderedDict([
|
|
(KW.MAD, AlphaOp.MAD),
|
|
(KW.DP, AlphaOp.DP),
|
|
(KW.MIN, AlphaOp.MIN),
|
|
(KW.MAX, AlphaOp.MAX),
|
|
(KW.CND, AlphaOp.CND),
|
|
(KW.CMP, AlphaOp.CMP),
|
|
(KW.FRC, AlphaOp.FRC),
|
|
(KW.EX2, AlphaOp.EX2),
|
|
(KW.LN2, AlphaOp.LN2),
|
|
(KW.RCP, AlphaOp.RCP),
|
|
(KW.RSQ, AlphaOp.RSQ),
|
|
(KW.SIN, AlphaOp.SIN),
|
|
(KW.COS, AlphaOp.COS),
|
|
(KW.MDH, AlphaOp.MDH),
|
|
(KW.MDV, AlphaOp.MDV)
|
|
])
|
|
|
|
rgb_masks = OrderedDict([
|
|
(b"none" , RGBMask.NONE),
|
|
(b"r" , RGBMask.R),
|
|
(b"g" , RGBMask.G),
|
|
(b"rg" , RGBMask.RG),
|
|
(b"b" , RGBMask.B),
|
|
(b"rb" , RGBMask.RB),
|
|
(b"gb" , RGBMask.GB),
|
|
(b"rgb" , RGBMask.RGB),
|
|
])
|
|
|
|
alpha_masks = OrderedDict([
|
|
(b"none" , AlphaMask.NONE),
|
|
(b"a" , AlphaMask.A),
|
|
])
|
|
|
|
alpha_only_ops = set(alpha_op_kws.keys()) - set(rgb_op_kws.keys())
|
|
rgb_only_ops = set(rgb_op_kws.keys()) - set(alpha_op_kws.keys())
|
|
all_ops = set(rgb_op_kws.keys()) | set(alpha_op_kws.keys())
|
|
|
|
alpha_only_masks = set(alpha_masks.keys()) - set(rgb_masks.keys())
|
|
rgb_only_masks = set(rgb_masks.keys()) - set(alpha_masks.keys())
|
|
all_masks = set(rgb_masks.keys()) | set(alpha_masks.keys())
|
|
|
|
def infer_operation_units(operations):
|
|
if len(operations) > 2:
|
|
raise ValidatorError("too many operations in instruction", operations[-1].opcode_keyword)
|
|
|
|
units = [None, None]
|
|
for i, operation in enumerate(operations):
|
|
opcode = operation.opcode_keyword.keyword
|
|
if opcode not in all_ops:
|
|
raise ValidatorError(f"invalid opcode keyword, expected one of {all_ops}", operation.opcode_keyword)
|
|
|
|
if len(operation.dest_addr_swizzles) > 2:
|
|
raise ValidationError("too many destinations in instruction", operation.dest_addr_swizzles[-1])
|
|
|
|
masks = set(prevalidate_mask(dest_addr_swizzle, all_masks) for dest_addr_swizzle in operation.dest_addr_swizzles)
|
|
|
|
def infer_opcode_unit():
|
|
if opcode in alpha_only_ops:
|
|
return Unit.alpha
|
|
if opcode in rgb_only_ops:
|
|
return Unit.rgb
|
|
return None
|
|
|
|
def infer_mask_unit():
|
|
if any(mask in alpha_only_masks for mask in masks):
|
|
return Unit.alpha
|
|
if any(mask in rgb_only_masks for mask in masks):
|
|
return Unit.rgb
|
|
return None
|
|
|
|
opcode_unit = infer_opcode_unit()
|
|
mask_unit = infer_mask_unit()
|
|
if opcode_unit is not None and mask_unit is not None and opcode_unit != mask_unit:
|
|
raise ValidatorError(f"contradictory {mask_unit.name} write mask for {opcode_unit.name} opcode", operation.opcode_keyword)
|
|
units[i] = opcode_unit or mask_unit
|
|
|
|
if units[0] == units[1]:
|
|
raise ValidatorError(f"invalid duplicate use of {units[1].name} unit", operations[1].opcode_keyword)
|
|
|
|
other_unit = {
|
|
Unit.alpha: Unit.rgb,
|
|
Unit.rgb: Unit.alpha,
|
|
}
|
|
|
|
if units[0] is None:
|
|
units[0] = other_unit[units[1]]
|
|
if units[1] is None:
|
|
units[1] = other_unit[units[0]]
|
|
|
|
assert units[0] is not None
|
|
assert units[1] is not None
|
|
assert units[0] != units[1]
|
|
|
|
for i, operation in enumerate(operations):
|
|
yield units[i], operation
|
|
|
|
def validate_instruction_operation_dest(dest_addr_swizzles, mask_lookup, type_cls):
|
|
addrd = None
|
|
target = None
|
|
wmask = None
|
|
omask = None
|
|
for dest_addr_swizzle in dest_addr_swizzles:
|
|
dest = validate_dest_keyword(dest_addr_swizzle.dest_keyword)
|
|
addr = validate_identifier_number(dest_addr_swizzle.addr_identifier)
|
|
mask = mask_lookup[dest_addr_swizzle.swizzle_identifier.lexeme.lower()]
|
|
if dest == KW.OUT:
|
|
omask = mask
|
|
target = addr
|
|
elif dest == KW.TEMP:
|
|
wmask = mask
|
|
addrd = addr
|
|
else:
|
|
assert False, dest
|
|
|
|
return type_cls(
|
|
addrd=addrd,
|
|
target=target,
|
|
wmask=wmask,
|
|
omask=omask
|
|
)
|
|
|
|
swizzle_sel_src_kws = OrderedDict([
|
|
(KW.SRC0, SwizzleSelSrc.src0),
|
|
(KW.SRC1, SwizzleSelSrc.src1),
|
|
(KW.SRC2, SwizzleSelSrc.src2),
|
|
(KW.SRCP, SwizzleSelSrc.srcp),
|
|
])
|
|
|
|
swizzle_kws = OrderedDict([
|
|
(ord("r"), Swizzle.r),
|
|
(ord("g"), Swizzle.g),
|
|
(ord("b"), Swizzle.b),
|
|
(ord("a"), Swizzle.a),
|
|
(ord("0"), Swizzle.zero),
|
|
(ord("h"), Swizzle.half),
|
|
(ord("1"), Swizzle.one),
|
|
(ord("_"), Swizzle.unused),
|
|
])
|
|
|
|
def validate_instruction_operation_sels(swizzle_sels, is_alpha):
|
|
if len(swizzle_sels) > 3:
|
|
raise ValidatorError("too many swizzle sels", swizzle_sels[-1].sel_keyword)
|
|
|
|
sels = []
|
|
for swizzle_sel in swizzle_sels:
|
|
if swizzle_sel.sel_keyword.keyword not in swizzle_sel_src_kws:
|
|
raise ValidatorError("invalid swizzle src", swizzle_sel.sel_keyword.keyword)
|
|
src = swizzle_sel_src_kws[swizzle_sel.sel_keyword.keyword]
|
|
|
|
swizzle_lexeme = swizzle_sel.swizzle_identifier.lexeme.lower()
|
|
swizzles_length = 1 if is_alpha else 3
|
|
if len(swizzle_lexeme) != swizzles_length:
|
|
raise ValidatorError("invalid swizzle length", swizzle_sel.swizzle_identifier)
|
|
if not all(c in swizzle_kws for c in swizzle_lexeme):
|
|
raise ValidatorError("invalid swizzle characters", swizzle_sel.swizzle_identifier)
|
|
swizzle = [
|
|
swizzle_kws[c] for c in swizzle_lexeme
|
|
]
|
|
|
|
mod = swizzle_sel.mod
|
|
sels.append(SwizzleSel(src, swizzle, mod))
|
|
return sels
|
|
|
|
def validate_alpha_instruction_operation(operation):
|
|
dest = validate_instruction_operation_dest(operation.dest_addr_swizzles,
|
|
mask_lookup=alpha_masks,
|
|
type_cls=AlphaDest)
|
|
opcode = alpha_op_kws[operation.opcode_keyword.keyword]
|
|
sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=True)
|
|
return AlphaOperation(
|
|
dest,
|
|
opcode,
|
|
sels
|
|
)
|
|
|
|
def validate_rgb_instruction_operation(operation):
|
|
dest = validate_instruction_operation_dest(operation.dest_addr_swizzles,
|
|
mask_lookup=rgb_masks,
|
|
type_cls=RGBDest)
|
|
opcode = rgb_op_kws[operation.opcode_keyword.keyword]
|
|
sels = validate_instruction_operation_sels(operation.swizzle_sels, is_alpha=False)
|
|
return RGBOperation(
|
|
dest,
|
|
opcode,
|
|
sels
|
|
)
|
|
|
|
def validate_instruction_operations(operations):
|
|
for unit, operation in infer_operation_units(operations):
|
|
if unit is Unit.alpha:
|
|
yield validate_alpha_instruction_operation(operation)
|
|
elif unit is Unit.rgb:
|
|
yield validate_rgb_instruction_operation(operation)
|
|
else:
|
|
assert False, unit
|
|
|
|
def validate_instruction(ins):
|
|
addr_rgb_alpha = validate_instruction_let_expressions(ins.let_expressions)
|
|
|
|
tags = set([tag.keyword for tag in ins.tags])
|
|
|
|
instruction = Instruction(
|
|
tags,
|
|
addr_rgb_alpha,
|
|
None,
|
|
None
|
|
)
|
|
for op in validate_instruction_operations(ins.operations):
|
|
if type(op) is RGBOperation:
|
|
instruction.rgb_op = op
|
|
elif type(op) is AlphaOperation:
|
|
instruction.alpha_op = op
|
|
else:
|
|
assert False, op
|
|
return instruction
|
|
|
|
if __name__ == "__main__":
|
|
from assembler.lexer import Lexer, LexerError
|
|
from assembler.fs.parser import Parser, ParserError
|
|
from assembler.fs.keywords import find_keyword
|
|
|
|
buf = b"""
|
|
src0.a = float(0), src0.rgb = temp[0] , srcp.a = neg :
|
|
out[0].none = temp[0].none = MAD src0.r src0.r src0.r ,
|
|
out[0].none = temp[0].r = DP3 src0.rg0 src0.rg0 ;
|
|
"""
|
|
lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=True)
|
|
tokens = list(lexer.lex_tokens())
|
|
parser = Parser(tokens)
|
|
try:
|
|
ins_ast = parser.instruction()
|
|
pprint(validate_instruction(ins_ast))
|
|
except LexerError as e:
|
|
print_error(None, buf, e)
|
|
raise
|
|
except ParserError as e:
|
|
print_error(None, buf, e)
|
|
raise
|
|
except ValidatorError as e:
|
|
print_error(None, buf, e)
|
|
raise
|