# r500/regs/assembler/vs/validator.py
#
# 415 lines
# 15 KiB
# Python

from dataclasses import dataclass
from enum import IntEnum
from itertools import pairwise
from typing import Union
from assembler.lexer import Token
from assembler.validator import ValidatorError
from assembler.vs.keywords import KW
from assembler.vs import opcodes
from collections import OrderedDict
class SourceType(IntEnum):
    """Kind of register a source operand refers to.

    These are the source types that validate_source accepts. The integer
    values are presumably the hardware encoding of the field -- TODO confirm
    against the code that emits instructions.
    """

    temporary = 0
    input = 1
    constant = 2
    alt_temporary = 3
class SwizzleSelect(IntEnum):
    """Per-component selector used in a source swizzle.

    parse_swizzle_lexeme maps the swizzle characters ``x y z w 0 1 h _`` to
    these members, in this order.
    """

    # Component selectors.
    x = 0
    y = 1
    z = 2
    w = 3
    # Selectors spelled "0", "1" and "h" in a swizzle.
    zero = 4
    one = 5
    half = 6
    # Spelled "_" in a swizzle.
    unused = 7
@dataclass
class Source:
    """Validated source operand, as built by validate_source."""

    # Register kind named by the operand's type keyword.
    type: SourceType
    # Copied verbatim from the parsed operand's ``absolute`` attribute.
    absolute: bool
    # Register index, already range-checked against the per-type limit.
    offset: int
    # One selector per swizzle position.
    swizzle_selects: tuple[SwizzleSelect, SwizzleSelect, SwizzleSelect, SwizzleSelect]
    # True where the matching selector was prefixed with "-" in the swizzle.
    modifiers: tuple[bool, bool, bool, bool]
class DestinationType(IntEnum):
    """Kind of register a destination operand refers to.

    These are the destination types that validate_destination accepts. The
    integer values are presumably the hardware encoding of the field -- TODO
    confirm against the code that emits instructions.
    """

    temporary = 0
    a0 = 1
    out = 2
    out_repl_x = 3
    alt_temporary = 4
    input = 5
@dataclass
class Destination:
    """Validated destination operand, as built by validate_destination."""

    # Register kind named by the operand's type keyword.
    type: DestinationType
    # Register index, already range-checked against the per-type limit.
    offset: int
    # One flag per component, in x, y, z, w order.
    write_enable: tuple[bool, bool, bool, bool]
@dataclass
class Instruction:
    """Validated single-operation instruction (see validate_instruction_inner)."""

    destination: Destination
    # True when the opcode carried the saturation suffix keyword.
    saturation: bool
    # Validated source operands, in source order.
    sources: list[Source]
    # Opcode after validate_source_address_counts, which may substitute a
    # macro opcode for the one that was written.
    opcode: Union[opcodes.VE, opcodes.ME, opcodes.MVE]
@dataclass
class DualMathVEOperation:
    """VE half of a dual math instruction (see validate_dual_math_ve_operation)."""

    destination: Destination
    # True when the opcode carried the saturation suffix keyword.
    saturation: bool
    # Validated source operands; the validator allows at most two here.
    sources: list[Source]
    opcode: opcodes.VE
class DualMathMEWriteEnable(IntEnum):
    """The single component written by the ME half of a dual math instruction.

    validate_dual_math_me_destination accepts exactly one of the characters
    ``x y z w`` and maps it to the matching member.
    """

    x = 0
    y = 1
    z = 2
    w = 3
@dataclass
class DualMathMEDestination:
    """Destination of the ME half of a dual math instruction.

    Built by validate_dual_math_me_destination, which only accepts the
    alt_temporary destination keyword.
    """

    # Register index; the validator rejects values of 4 and above.
    offset: int
    # The one component that is written.
    write_enable: DualMathMEWriteEnable
@dataclass
class DualMathMESource:
    """Two-component source operand for the ME half of a dual math instruction.

    NOTE(review): nothing visible in this module constructs this class;
    validate_dual_math_me_operation stores the Source objects returned by
    validate_source instead -- confirm which one is intended.
    """

    type: SourceType
    absolute: bool
    offset: int
    # One selector per swizzle position (two positions for this operand).
    swizzle_selects: tuple[SwizzleSelect, SwizzleSelect]
    # One flag per swizzle position, parallel to swizzle_selects.
    modifiers: tuple[bool, bool]
@dataclass
class DualMathMEOperation:
    """ME half of a dual math instruction (see validate_dual_math_me_operation)."""

    # validate_dual_math_me_destination builds a DualMathMEDestination, not a
    # general Destination, so the field is annotated accordingly.
    destination: DualMathMEDestination
    # True when the opcode carried the saturation suffix keyword.
    saturation: bool
    # NOTE(review): validate_dual_math_me_operation fills this with the Source
    # objects returned by validate_source (two-entry swizzles), not with
    # DualMathMESource instances -- confirm which type is intended.
    sources: list[DualMathMESource]
    opcode: opcodes.ME
@dataclass
class DualMathInstruction:
    """Validated instruction that pairs one VE operation with one ME operation."""

    ve_operation: DualMathVEOperation
    me_operation: DualMathMEOperation
def validate_opcode(opcode_keyword: Token):
    """Return the opcode carried by an opcode keyword token.

    The token's ``keyword`` must be exactly of type opcodes.ME or opcodes.VE.

    Raises:
        ValidatorError: the keyword is of any other type.
    """
    keyword = opcode_keyword.keyword
    keyword_type = type(keyword)
    if keyword_type is opcodes.ME or keyword_type is opcodes.VE:
        return keyword
    raise ValidatorError("invalid opcode keyword", opcode_keyword)
def validate_identifier_number(token):
    """Parse the token's lexeme as a base-10 integer and return it.

    Raises:
        ValidatorError: ``int`` rejected the lexeme with ValueError.
    """
    try:
        number = int(token.lexeme, 10)
    except ValueError:
        raise ValidatorError("invalid number", token)
    return number
def validate_destination_write_enable(write_enable_identifier):
we_chars_s = b"xyzw"
we_chars = {c: i for i, c in enumerate(we_chars_s)}
we = bytes(write_enable_identifier.lexeme).lower()
if not all(c in we_chars for c in we):
raise ParserError("invalid character in destination write enable", write_enable_identifier)
if not all(we_chars[a] < we_chars[b] for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ParserError("misleading non-sequential destination write enable", write_enable_identifier)
we = set(we)
return tuple(c in we for c in we_chars_s)
def validate_destination(destination):
    """Validate a parsed destination operand and return a Destination.

    Checks the type keyword against the known destination types, parses the
    offset and range-checks it against that type's limit, then validates the
    write-enable mask.

    Raises:
        ValidatorError: unknown type keyword, non-numeric or out-of-range
            offset, or an invalid write-enable identifier.
    """
    # keyword -> (destination type, exclusive upper bound on the offset).
    # The trailing comments are the original author's notes.
    destination_type_keywords = OrderedDict([
        (KW.temporary     , (DestinationType.temporary    , 128)), # 32
        (KW.a0            , (DestinationType.a0           , 1  )), # ??
        (KW.out           , (DestinationType.out          , 128)), # ?
        (KW.out_repl_x    , (DestinationType.out_repl_x   , 128)), # ?
        (KW.alt_temporary , (DestinationType.alt_temporary, 20 )), # 20
        (KW.input         , (DestinationType.input        , 128)), # 32
    ])
    keyword = destination.type_keyword.keyword
    if keyword not in destination_type_keywords:
        raise ValidatorError("invalid destination type keyword", destination.type_keyword)
    destination_type, max_offset = destination_type_keywords[keyword]

    offset = validate_identifier_number(destination.offset_identifier)
    # int() accepts a leading "-", so the lower bound is checked as well.
    if not 0 <= offset < max_offset:
        # Anchor the error at the destination's own offset token; the
        # original referenced an undefined name `source` here.
        raise ValidatorError("invalid offset value", destination.offset_identifier)

    write_enable = validate_destination_write_enable(destination.write_enable_identifier)
    return Destination(destination_type, offset, write_enable)
def parse_swizzle_lexeme(token):
    """Split a swizzle lexeme into parallel selector and modifier tuples.

    The lexeme is lower-cased. Each of ``x y z w 0 1 h _`` yields one
    selector. A ``-`` sets the modifier flag for the next selector only; a
    ``-`` that is not followed by a selector has no effect.

    Returns:
        ``(selects, modifiers)`` as two tuples of equal length, or an empty
        tuple when the lexeme yields no selectors.

    Raises:
        ValueError: the lexeme contains any other byte (the byte value is
            the exception argument).
    """
    select_by_char = {
        ord(b"x"): SwizzleSelect.x,
        ord(b"y"): SwizzleSelect.y,
        ord(b"z"): SwizzleSelect.z,
        ord(b"w"): SwizzleSelect.w,
        ord(b"0"): SwizzleSelect.zero,
        ord(b"1"): SwizzleSelect.one,
        ord(b"h"): SwizzleSelect.half,
        ord(b"_"): SwizzleSelect.unused,
    }
    minus = ord(b"-")

    pairs = []
    pending_modifier = False
    for byte in bytes(token.lexeme).lower():
        if byte == minus:
            pending_modifier = True
            continue
        select = select_by_char.get(byte)
        if select is None:
            raise ValueError(byte)
        pairs.append((select, pending_modifier))
        pending_modifier = False
    # Transpose [(select, modifier), ...] into (selects, modifiers).
    return tuple(zip(*pairs))
def validate_source(source, swizzle_select_length):
    """Validate a parsed source operand and return a Source.

    Args:
        source: parsed operand providing ``type_keyword``, ``absolute``,
            ``offset_identifier`` and ``swizzle_identifier``.
        swizzle_select_length: exact number of selectors the swizzle must
            contain.

    Raises:
        ValidatorError: unknown type keyword, non-numeric or out-of-range
            offset, or a swizzle that is malformed or of the wrong length.
    """
    # keyword -> (source type, exclusive upper bound on the offset).
    # The trailing comments are the original author's notes.
    source_type_keywords = OrderedDict([
        (KW.temporary     , (SourceType.temporary    , 128)), # 32
        (KW.input         , (SourceType.input        , 128)), # 32
        (KW.constant      , (SourceType.constant     , 256)), # 256
        (KW.alt_temporary , (SourceType.alt_temporary, 20)),  # 20
    ])
    keyword = source.type_keyword.keyword
    if keyword not in source_type_keywords:
        raise ValidatorError("invalid source type keyword", source.type_keyword)
    source_type, max_offset = source_type_keywords[keyword]

    absolute = source.absolute
    offset = validate_identifier_number(source.offset_identifier)
    # int() accepts a leading "-", so the lower bound is checked as well.
    if not 0 <= offset < max_offset:
        raise ValidatorError("invalid offset value", source.offset_identifier)

    try:
        # An empty result fails to unpack with ValueError, which is reported
        # the same way as an invalid swizzle character.
        swizzle_selects, modifiers = parse_swizzle_lexeme(source.swizzle_identifier)
    except ValueError:
        raise ValidatorError("invalid source swizzle", source.swizzle_identifier)
    assert len(swizzle_selects) == len(modifiers)
    if len(swizzle_selects) != swizzle_select_length:
        raise ValidatorError("invalid source swizzle", source.swizzle_identifier)

    return Source(source_type, absolute, offset, swizzle_selects, modifiers)
def addresses_by_type(sources, source_type):
    """Return the set of distinct integer offsets used by sources of one type."""
    matching = (source for source in sources if source.type == source_type)
    return {int(source.offset) for source in matching}
def source_ix_with_type_reversed(sources, source_type):
    """Return the index of the last source whose type equals source_type.

    Callers are expected to pass a type that is present; an assertion fires
    otherwise.
    """
    last_ix = None
    for ix, source in enumerate(sources):
        if source.type == source_type:
            last_ix = ix
    assert last_ix is not None, (sources, source_type)
    return last_ix
def validate_source_address_counts(sources_ast, sources, opcode):
    """Check how many distinct addresses of each source type are used.

    Args:
        sources_ast: parsed source operands, parallel to ``sources``; used
            only to anchor error messages.
        sources: validated sources for the operation(s).
        opcode: the VE or ME opcode that was written.

    Returns:
        ``opcode``, or the matching MACRO_OP_2CLK_* opcode when three
        distinct temporary addresses are used by VE_MULTIPLY_ADD or
        VE_MULTIPLYX2_ADD.

    Raises:
        ValidatorError: three distinct temporary addresses with any other
            opcode, or more than one distinct constant, input or
            alt-temporary address.
    """
    temporary_address_count = len(addresses_by_type(sources, SourceType.temporary))
    assert 0 <= temporary_address_count <= 3
    assert type(opcode) in {opcodes.VE, opcodes.ME}
    if temporary_address_count == 3:
        if opcode == opcodes.VE_MULTIPLY_ADD:
            opcode = opcodes.MACRO_OP_2CLK_MADD
        elif opcode == opcodes.VE_MULTIPLYX2_ADD:
            opcode = opcodes.MACRO_OP_2CLK_M2X_ADD
        else:
            raise ValidatorError("too many temporary addresses in non-macro operation(s)",
                                 sources_ast[-1].offset_identifier)

    # Each of these types may contribute at most one distinct address. The
    # label is the wording used in the error message.
    single_address_types = (
        (SourceType.constant, "constant"),
        (SourceType.input, "input"),
        (SourceType.alt_temporary, "alt temporary"),
    )
    for source_type, label in single_address_types:
        count = len(addresses_by_type(sources, source_type))
        if count > 1:
            # Anchor the error at the last operand of the offending type.
            # The original called an undefined `source_with_type_reversed`
            # for the input and alt-temporary cases.
            source_ix = source_ix_with_type_reversed(sources, source_type)
            raise ValidatorError(f"too many {label} addresses in operation(s); expected 1, have {count}",
                                 sources_ast[source_ix].offset_identifier)
    return opcode
def validate_instruction_inner(operation, opcode):
    """Validate a single-operation instruction and return an Instruction.

    Validates the destination, the optional saturation suffix, the number of
    sources (at most 3, and equal to ``opcode.operand_count``), each source
    with a four-entry swizzle, and finally the per-type address counts.

    Raises:
        ValidatorError: any of the above checks fails.
    """
    destination = validate_destination(operation.destination)

    suffix = operation.opcode_suffix_keyword
    if suffix is not None and suffix.keyword is not KW.saturation:
        raise ValidatorError("invalid opcode saturation suffix", suffix)
    saturation = suffix is not None

    source_asts = operation.sources
    if len(source_asts) > 3:
        raise ValidatorError("too many sources in operation", source_asts[-1].type_keyword)
    if len(source_asts) != opcode.operand_count:
        # With no sources there is no source token to point at, so fall back
        # to the opcode keyword instead of raising IndexError.
        anchor = source_asts[0].type_keyword if source_asts else operation.opcode_keyword
        raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", anchor)

    sources = [validate_source(source, swizzle_select_length=4) for source in source_asts]
    # May replace the opcode with a macro opcode.
    opcode = validate_source_address_counts(source_asts, sources, opcode)
    return Instruction(destination, saturation, sources, opcode)
def validate_dual_math_ve_operation(operation, opcode):
    """Validate the VE half of a dual math instruction.

    Validates the destination, the optional saturation suffix, the number of
    sources (at most 2, and equal to ``opcode.operand_count``) and each
    source with a four-entry swizzle. Address counts are checked by the
    caller across both halves.

    Raises:
        ValidatorError: any of the above checks fails, or the opcode takes
            more than two operands.
    """
    destination = validate_destination(operation.destination)

    suffix = operation.opcode_suffix_keyword
    if suffix is not None and suffix.keyword is not KW.saturation:
        raise ValidatorError("invalid opcode saturation suffix", suffix)
    saturation = suffix is not None

    source_asts = operation.sources
    if len(source_asts) > 2:
        raise ValidatorError("too many sources in dual math VE operation", source_asts[-1].type_keyword)
    if opcode.operand_count > 2:
        # With no sources there is no source token to point at, so fall back
        # to the opcode keyword instead of raising IndexError.
        anchor = source_asts[-1].type_keyword if source_asts else operation.opcode_keyword
        raise ValidatorError("3-operand opcode not valid in dual math VE operation", anchor)
    if len(source_asts) != opcode.operand_count:
        anchor = source_asts[0].type_keyword if source_asts else operation.opcode_keyword
        raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", anchor)

    sources = [validate_source(source, swizzle_select_length=4) for source in source_asts]
    return DualMathVEOperation(destination, saturation, sources, opcode)
def validate_dual_math_me_destination(destination):
    """Validate the destination of the ME half of a dual math instruction.

    The destination must use the alt_temporary keyword, an offset in 0..3
    and a write enable that is exactly one of ``x y z w`` (in either case).

    Returns:
        A DualMathMEDestination.

    Raises:
        ValidatorError: any of the above checks fails.
    """
    if destination.type_keyword.keyword is not KW.alt_temporary:
        raise ValidatorError("invalid dual math ME destination type keyword", destination.type_keyword)

    offset = validate_identifier_number(destination.offset_identifier)
    # int() accepts a leading "-", so the lower bound is checked as well.
    if not 0 <= offset < 4:
        # The original referenced an undefined name `source` here.
        raise ValidatorError("invalid dual math ME offset value", destination.offset_identifier)

    we = bytes(destination.write_enable_identifier.lexeme).lower()
    if len(we) != 1:
        raise ValidatorError("invalid dual math ME write enable", destination.write_enable_identifier)
    we_chars = dict(zip(b"xyzw", [
        DualMathMEWriteEnable.x,
        DualMathMEWriteEnable.y,
        DualMathMEWriteEnable.z,
        DualMathMEWriteEnable.w,
    ]))
    we_char = we[0]
    if we_char not in we_chars:
        # The original built an undefined ParserError here without raising
        # it, so the lookup below would have failed with KeyError.
        raise ValidatorError("invalid dual math ME write enable", destination.write_enable_identifier)
    write_enable = we_chars[we_char]

    return DualMathMEDestination(offset, write_enable)
def validate_dual_math_me_operation(operation, opcode):
    """Validate the ME half of a dual math instruction.

    Validates the ME destination, the optional saturation suffix, the number
    of sources (at most 1, and equal to ``opcode.operand_count``) and each
    source with a two-entry swizzle. Address counts are checked by the
    caller across both halves.

    Raises:
        ValidatorError: any of the above checks fails.
    """
    destination = validate_dual_math_me_destination(operation.destination)

    suffix = operation.opcode_suffix_keyword
    if suffix is not None and suffix.keyword is not KW.saturation:
        raise ValidatorError("invalid opcode saturation suffix", suffix)
    saturation = suffix is not None

    source_asts = operation.sources
    if len(source_asts) > 1:
        raise ValidatorError("too many sources in dual math ME operation", source_asts[-1].type_keyword)
    if len(source_asts) != opcode.operand_count:
        # With no sources there is no source token to point at, so fall back
        # to the opcode keyword instead of raising IndexError.
        anchor = source_asts[0].type_keyword if source_asts else operation.opcode_keyword
        raise ValidatorError(f"incorrect number of source operands; expected {opcode.operand_count}", anchor)

    # NOTE(review): validate_source returns Source objects; they are stored
    # as-is rather than converted to DualMathMESource.
    sources = [validate_source(source, swizzle_select_length=2) for source in source_asts]
    return DualMathMEOperation(destination, saturation, sources, opcode)
def validate_dual_math_instruction(operations, _opcodes):
    """Validate a two-operation instruction and return a DualMathInstruction.

    The operation whose opcode is of type opcodes.VE is treated as the VE
    half and the other as the ME half, whichever order they were written in.
    Each half is validated on its own, then the address-count check runs
    over the sources of both halves together.
    """
    # Slot 0 is the VE half when its opcode is a VE opcode, else slot 1 is.
    ve_ix = 0 if type(_opcodes[0]) is opcodes.VE else 1
    me_ix = 1 - ve_ix
    ve_operation_ast, ve_opcode = operations[ve_ix], _opcodes[ve_ix]
    me_operation_ast, me_opcode = operations[me_ix], _opcodes[me_ix]

    ve_operation = validate_dual_math_ve_operation(ve_operation_ast, ve_opcode)
    me_operation = validate_dual_math_me_operation(me_operation_ast, me_opcode)

    checked_opcode = validate_source_address_counts(
        ve_operation_ast.sources + me_operation_ast.sources,
        ve_operation.sources + me_operation.sources,
        ve_operation.opcode,
    )
    # The address-count check must not have substituted a macro opcode.
    assert checked_opcode == ve_operation.opcode

    return DualMathInstruction(ve_operation, me_operation)
def validate_instruction(ins):
    """Validate a parsed instruction.

    Returns:
        An Instruction for a single operation, or a DualMathInstruction for
        two operations with opcodes of different types.

    Raises:
        ValidatorError: more than two operations, an invalid opcode keyword,
            two opcodes of the same type, or any error from validating the
            operation(s).
    """
    operations = ins.operations
    if len(operations) > 2:
        raise ValidatorError("too many operations in instruction", operations[0].destination.type_keyword)

    validated_opcodes = [validate_opcode(operation.opcode_keyword) for operation in operations]
    distinct_types = {type(opcode) for opcode in validated_opcodes}
    if len(distinct_types) != len(validated_opcodes):
        # Fewer types than opcodes: both opcodes share the one type.
        (opcode_type,) = distinct_types
        raise ValidatorError(f"invalid dual math operation: too many opcodes of type {opcode_type}", operations[0].opcode_keyword)

    if len(validated_opcodes) == 2:
        return validate_dual_math_instruction(operations, validated_opcodes)
    assert len(validated_opcodes) == 1
    return validate_instruction_inner(operations[0], validated_opcodes[0])
if __name__ == "__main__":
    # Manual smoke test: lex, parse and validate one hard-coded instruction,
    # then pretty-print the result or report the error and re-raise it.
    from pprint import pprint

    from assembler.error import print_error
    from assembler.lexer import Lexer
    from assembler.parser import ParserError
    from assembler.vs.keywords import find_keyword
    from assembler.vs.parser import Parser

    # Alternative input with a single operation; assigned but then replaced
    # by the dual math sample below.
    buf = b"""
out[0].xz = VE_MAD.SAT |temp[1].-y-_0-_| const[2].x_0_ const[2].x_0_ ;
"""
    # Dual math sample (one VE and one ME operation); this is what runs.
    buf = b"""
out[0].xz = VE_MUL.SAT |temp[1].-y-_0-_| const[2].x_0_ ,
alt_temp[0].x = ME_SIN input[0].-y-_ ;
"""

    lexer = Lexer(buf, find_keyword, emit_newlines=False, minus_is_token=False)
    parser = Parser(list(lexer.lex_tokens()))
    try:
        ins = parser.instruction()
        pprint(validate_instruction(ins))
    except (ValidatorError, ParserError) as e:
        print_error(None, buf, e)
        raise