regs: improve frontend for assembler and disassembler

This commit is contained in:
Zack Buhman 2025-10-16 17:17:16 -05:00
parent d903115964
commit be27c747ee
5 changed files with 151 additions and 160 deletions

View File

@ -1,5 +1,7 @@
from assembler.lexer import Lexer import sys
from assembler.parser import Parser
from assembler.lexer import Lexer, LexerError
from assembler.parser import Parser, ParserError
from assembler.emitter import emit_instruction from assembler.emitter import emit_instruction
sample = b""" sample = b"""
@ -15,15 +17,59 @@ out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1 out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
""" """
def frontend_inner(buf):
    """Lex, parse and emit *buf*, yielding (codewords, (start_ix, end_ix)).

    codewords is the list of 32-bit instruction words for one parsed
    instruction; start_ix/end_ix delimit its source span in *buf*.
    """
    lexer = Lexer(buf)
    tokens = list(lexer.lex_tokens())
    parser = Parser(tokens)
    for ins, start_end in parser.instructions():
        # Materialize the generator so emitter errors surface here,
        # inside frontend()'s try/except, not at the caller's print site.
        yield list(emit_instruction(ins)), start_end
def print_error(filename, buf, e):
    """Print a compiler-style diagnostic for a lexer/parser error to stderr.

    Expects ``e.args == (message, token)`` where token carries ``.line``
    (1-based), ``.col`` (0-based) and ``.lexeme`` (bytes).  The offending
    source line is echoed with the token highlighted in red, followed by a
    caret underline and the error name/message.
    """
    assert len(e.args) == 2, e
    message, token = e.args
    lines = buf.splitlines()
    line = lines[token.line - 1]
    error_name = type(e).__name__
    col_indent = ' ' * token.col
    col_pointer = '^' * len(token.lexeme)
    RED = "\033[0;31m"
    DEFAULT = "\033[0;0m"
    # Bug fix: the filename parameter was ignored and the literal string
    # "(unknown)" was printed instead.
    print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr)
    sys.stderr.write(' ')
    wrote_default = False
    for i, c in enumerate(line.decode('utf-8')):
        if i == token.col:
            sys.stderr.write(RED)
        # Bug fix: reset BEFORE writing the character one past the lexeme;
        # previously that character was also painted red.
        if i == token.col + len(token.lexeme):
            wrote_default = True
            sys.stderr.write(DEFAULT)
        sys.stderr.write(c)
    if not wrote_default:
        # Token runs to the end of the line; still reset the color.
        sys.stderr.write(DEFAULT)
    sys.stderr.write('\n')
    print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr)
    print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)
def frontend(filename, buf):
    """Assemble *buf*, yielding (codewords, (start_ix, end_ix)) per instruction.

    On a lexer or parser error, prints a formatted diagnostic (via
    print_error) to stderr and re-raises so the caller can abort.
    """
    try:
        yield from frontend_inner(buf)
    except (LexerError, ParserError) as e:
        # Bug fix: previously referenced the global `input_filename`, which
        # only exists when running as a script — a NameError on import use.
        print_error(filename, buf, e)
        raise
if __name__ == "__main__":
    # Source file to assemble, given on the command line.
    input_filename = sys.argv[1]
    #output_filename = sys.argv[2]
    with open(input_filename, 'rb') as f:
        buf = f.read()
    # Assemble everything up front so a lexer/parser error aborts
    # before any codewords are printed.
    output = list(frontend(input_filename, buf))
    for cw, (start_ix, end_ix) in output:
        # NOTE(review): hard-coded debug toggle — the dead else branch would
        # append the instruction's source text as a // comment; confirm
        # whether this should become a command-line flag.
        if True:
            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},")
        else:
            source = buf[start_ix:end_ix]
            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}")

View File

@ -60,7 +60,8 @@ def src_reg_type(kw):
else: else:
assert not "Invalid PVS_SRC_REG", kw assert not "Invalid PVS_SRC_REG", kw
def emit_source(src: Source): def emit_source(src: Source, prev: Source):
if src is not None:
value = ( value = (
pvs_src.REG_TYPE_gen(src_reg_type(src.type)) pvs_src.REG_TYPE_gen(src_reg_type(src.type))
| pvs_src.OFFSET_gen(src.offset) | pvs_src.OFFSET_gen(src.offset)
@ -73,10 +74,39 @@ def emit_source(src: Source):
| pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2])) | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2]))
| pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3])) | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3]))
) )
else:
assert prev is not None
value = (
pvs_src.REG_TYPE_gen(src_reg_type(prev.type))
| pvs_src.OFFSET_gen(prev.offset)
| pvs_src.SWIZZLE_X_gen(7)
| pvs_src.SWIZZLE_Y_gen(7)
| pvs_src.SWIZZLE_Z_gen(7)
| pvs_src.SWIZZLE_W_gen(7)
| pvs_src.MODIFIER_X_gen(0)
| pvs_src.MODIFIER_Y_gen(0)
| pvs_src.MODIFIER_Z_gen(0)
| pvs_src.MODIFIER_W_gen(0)
)
yield value yield value
def prev_source(ins, ix):
    """Return the nearest preceding non-None source operand for slot *ix*.

    Used to borrow a register reference when a later source slot is
    omitted: slot 1 falls back to source0; slot 2 prefers source1,
    else source0.  Slot 0 must itself be present.
    """
    assert ix in (0, 1, 2), ix
    if ix == 2 and ins.source1 is not None:
        return ins.source1
    if ix == 0:
        assert ins.source0 is not None
    return ins.source0
def emit_instruction(ins: Instruction): def emit_instruction(ins: Instruction):
yield from emit_destination_op(ins.destination_op) yield from emit_destination_op(ins.destination_op)
yield from emit_source(ins.source0)
yield from emit_source(ins.source1) yield from emit_source(ins.source0, prev_source(ins, 0))
yield from emit_source(ins.source2) yield from emit_source(ins.source1, prev_source(ins, 1))
yield from emit_source(ins.source2, prev_source(ins, 2))

View File

@ -21,6 +21,7 @@ class TT(Enum):
@dataclass @dataclass
class Token: class Token:
start_ix: int
line: int line: int
col: int col: int
type: TT type: TT
@ -64,7 +65,7 @@ class Lexer:
return self.buf[self.current_ix] return self.buf[self.current_ix]
def pos(self): def pos(self):
return self.line, self.col - (self.current_ix - self.start_ix) return self.start_ix, self.line, self.col - (self.current_ix - self.start_ix)
def identifier(self): def identifier(self):
while not self.at_end_p() and self.peek() in identifier_characters: while not self.at_end_p() and self.peek() in identifier_characters:
@ -96,7 +97,7 @@ class Lexer:
elif c == ord('.'): elif c == ord('.'):
return Token(*self.pos(), TT.dot, self.lexeme()) return Token(*self.pos(), TT.dot, self.lexeme())
elif c == ord(';'): elif c == ord(';'):
while not at_end_p() and peek() != ord('\n'): while not self.at_end_p() and self.peek() != ord('\n'):
self.advance() self.advance()
elif c == ord(' ') or c == ord('\r') or c == ord('\t'): elif c == ord(' ') or c == ord('\r') or c == ord('\t'):
pass pass

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import Union from typing import Union
from assembler import lexer from assembler import lexer
from assembler.lexer import TT from assembler.lexer import TT, Token
from assembler.keywords import KW, ME, VE from assembler.keywords import KW, ME, VE
""" """
@ -44,7 +44,7 @@ class Instruction:
source1: Source source1: Source
source2: Source source2: Source
class ParseError(Exception): class ParserError(Exception):
pass pass
def identifier_to_number(token): def identifier_to_number(token):
@ -52,7 +52,7 @@ def identifier_to_number(token):
assert token.type is TT.identifier assert token.type is TT.identifier
if not all(d in digits for d in token.lexeme): if not all(d in digits for d in token.lexeme):
raise ParseError("expected number", token) raise ParserError("expected number", token)
return int(bytes(token.lexeme), 10) return int(bytes(token.lexeme), 10)
def we_ord(c): def we_ord(c):
@ -66,9 +66,9 @@ def parse_dest_write_enable(token):
assert token.type is TT.identifier assert token.type is TT.identifier
we = bytes(token.lexeme).lower() we = bytes(token.lexeme).lower()
if not all(c in we_chars for c in we): if not all(c in we_chars for c in we):
raise ParseError("expected destination write enable", token) raise ParserError("expected destination write enable", token)
if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we): if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ParseError("misleading non-sequential write enable", token) raise ParserError("misleading non-sequential write enable", token)
return set(we_ord(c) for c in we) return set(we_ord(c) for c in we)
def parse_source_swizzle(token): def parse_source_swizzle(token):
@ -89,25 +89,25 @@ def parse_source_swizzle(token):
swizzle_modifiers = [None] * 4 swizzle_modifiers = [None] * 4
lexeme = bytes(token.lexeme).lower() lexeme = bytes(token.lexeme).lower()
while state < 4: while state < 4:
if ix > len(token.lexeme): if ix >= len(token.lexeme):
raise ParseError("invalid source swizzle", token) raise ParserError("invalid source swizzle", token)
c = lexeme[ix] c = lexeme[ix]
if c == ord('-'): if c == ord('-'):
if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None): if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
raise ParseError("invalid source swizzle modifier", token) raise ParserError("invalid source swizzle modifier", token)
swizzle_modifiers[state] = True swizzle_modifiers[state] = True
elif c in select_mapping: elif c in select_mapping:
if swizzle_selects[state] is not None: if swizzle_selects[state] is not None:
raise ParseError("invalid source swizzle select", token) raise ParserError("invalid source swizzle select", token)
swizzle_selects[state] = select_mapping[c] swizzle_selects[state] = select_mapping[c]
if swizzle_modifiers[state] is None: if swizzle_modifiers[state] is None:
swizzle_modifiers[state] = False swizzle_modifiers[state] = False
state += 1 state += 1
else: else:
raise ParseError("invalid source swizzle", token) raise ParserError("invalid source swizzle", token)
ix += 1 ix += 1
if ix != len(lexeme): if ix != len(lexeme):
raise ParseError("invalid source swizzle", token) raise ParserError("invalid source swizzle", token)
return SourceSwizzle(swizzle_selects, swizzle_modifiers) return SourceSwizzle(swizzle_selects, swizzle_modifiers)
class Parser: class Parser:
@ -115,8 +115,8 @@ class Parser:
self.current_ix = 0 self.current_ix = 0
self.tokens = tokens self.tokens = tokens
def peek(self): def peek(self, offset=0):
token = self.tokens[self.current_ix] token = self.tokens[self.current_ix + offset]
#print(token) #print(token)
return token return token
@ -135,20 +135,20 @@ class Parser:
def consume(self, token_type, message): def consume(self, token_type, message):
token = self.advance() token = self.advance()
if token.type != token_type: if token.type != token_type:
raise ParseError(message, token) raise ParserError(message, token)
return token return token
def consume_either(self, token_type1, token_type2, message): def consume_either(self, token_type1, token_type2, message):
token = self.advance() token = self.advance()
if token.type != token_type1 and token.type != token_type2: if token.type != token_type1 and token.type != token_type2:
raise ParseError(message, token) raise ParserError(message, token)
return token return token
def destination_type(self): def destination_type(self):
token = self.consume(TT.keyword, "expected destination type") token = self.consume(TT.keyword, "expected destination type")
destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input} destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
if token.keyword not in destination_keywords: if token.keyword not in destination_keywords:
raise ParseError("expected destination type", token) raise ParserError("expected destination type", token)
return token.keyword return token.keyword
def offset(self): def offset(self):
@ -161,7 +161,7 @@ class Parser:
def opcode(self): def opcode(self):
token = self.consume(TT.keyword, "expected opcode") token = self.consume(TT.keyword, "expected opcode")
if type(token.keyword) != VE and type(token.keyword) != ME: if type(token.keyword) != VE and type(token.keyword) != ME:
raise ParseError("expected opcode", token) raise ParserError("expected opcode", token)
return token.keyword return token.keyword
def destination_op(self): def destination_op(self):
@ -178,7 +178,7 @@ class Parser:
token = self.consume(TT.keyword, "expected source type") token = self.consume(TT.keyword, "expected source type")
source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary} source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
if token.keyword not in source_keywords: if token.keyword not in source_keywords:
raise ParseError("expected source type", token) raise ParserError("expected source type", token)
return token.keyword return token.keyword
def source_swizzle(self): def source_swizzle(self):
@ -196,12 +196,23 @@ class Parser:
def instruction(self): def instruction(self):
while self.match(TT.eol): while self.match(TT.eol):
self.advance() self.advance()
first_token = self.peek()
destination_op = self.destination_op() destination_op = self.destination_op()
source0 = self.source() source0 = self.source()
if self.match(TT.eol) or self.match(TT.eof):
source1 = None
else:
source1 = self.source() source1 = self.source()
if self.match(TT.eol) or self.match(TT.eof):
source2 = None
else:
source2 = self.source() source2 = self.source()
last_token = self.peek(-1)
self.consume_either(TT.eol, TT.eof, "expected newline or EOF") self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
return Instruction(destination_op, source0, source1, source2) return (
Instruction(destination_op, source0, source1, source2),
(first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
)
def instructions(self): def instructions(self):
while not self.match(TT.eof): while not self.match(TT.eof):

View File

@ -2,113 +2,9 @@ import pvs_src
import pvs_src_bits import pvs_src_bits
import pvs_dst import pvs_dst
import pvs_dst_bits import pvs_dst_bits
from pprint import pprint
import itertools import itertools
from functools import partial from functools import partial
import sys
code = [
0x00f00203,
0x00d10001,
0x01248001,
0x01248001,
]
# Radeon Compiler Program
# 0: MOV output[1].xyz, input[1].xyz_;
# 1: MOV output[0], input[0].xyz1;
# Final vertex program code:
# 0: op: 0x00702203 dst: 1o op: VE_ADD
# src0: 0x01d10021 reg: 1i swiz: X/ Y/ Z/ U
# src1: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0
# src2: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0
# 1: op: 0x00f00203 dst: 0o op: VE_ADD
# src0: 0x01510001 reg: 0i swiz: X/ Y/ Z/ 1
# src1: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0
# src2: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0
code = [
0x00702203,
0x01d10021,
0x01248021,
0x01248021,
0x00f00203,
0x01510001,
0x01248001,
0x01248001,
]
code = [
0x00f00003,
0x00d10022,
0x01248022,
0x01248022,
0x00f02003,
0x00d10022,
0x01248022,
0x01248022,
0x00100004,
0x01ff0002,
0x01ff0020,
0x01ff2000,
0x00100006,
0x01ff0000,
0x01248000,
0x01248000,
0x00100004,
0x01ff0000,
0x01ff4022,
0x01ff6022,
0x00100050,
0x00000000,
0x01248000,
0x01248000,
0x00f00204,
0x0165a000,
0x01690001,
0x01240000,
]
code = [
0x00f00003,
0x00d10022,
0x01248022,
0x01248022,
0x00f02003,
0x00d10022,
0x01248022,
0x01248022,
0x00100004,
0x01ff0002,
0x01ff0020,
0x01ff2000,
0x00100006,
0x01ff0000,
0x01248000,
0x01248000,
0x00100004,
0x01ff0000,
0x01ff4022,
0x01ff6022,
0x00200051,
0x00000000,
0x01248000,
0x01248000,
0x00100050,
0x00000000,
0x01248000,
0x01248000,
0x00600002,
0x01c8e001,
0x01c9e000,
0x01248000,
0x00500204,
0x1fe72001,
0x01e70000,
0x01e72000,
0x00a00204,
0x0138e001,
0x0138e000,
0x017ae000,
]
def out(level, *args): def out(level, *args):
sys.stdout.write(" " * level + " ".join(args)) sys.stdout.write(" " * level + " ".join(args))
@ -151,8 +47,6 @@ def parse_code(code):
ix += 4 ix += 4
#parse_code(code)
def dst_swizzle_from_we(dst_op): def dst_swizzle_from_we(dst_op):
table = [ table = [
(pvs_dst.WE_X, "x"), (pvs_dst.WE_X, "x"),
@ -168,7 +62,7 @@ def dst_swizzle_from_we(dst_op):
_op_substitutions = [ _op_substitutions = [
("DOT_PRODUCT", "DOT"), ("DOT_PRODUCT", "DOT"),
("MULTIPLY_ADD", "MAD"), ("MULTIPLY_ADD", "MAD"),
("FRACTION", "FRAC"), ("FRACTION", "FRC"),
("MULTIPLY", "MUL"), ("MULTIPLY", "MUL"),
("MAXMIUM", "MAX"), ("MAXMIUM", "MAX"),
("MINIMUM", "MIN"), ("MINIMUM", "MIN"),
@ -280,5 +174,14 @@ def parse_instruction(instruction):
print(dst.ljust(12), "=", op.ljust(9), " ".join(map(lambda s: s.ljust(17), rest))) print(dst.ljust(12), "=", op.ljust(9), " ".join(map(lambda s: s.ljust(17), rest)))
def parse_hex(s):
    """Parse a '0x'-prefixed hexadecimal literal into an int."""
    assert s.startswith('0x')
    # int() with an explicit base accepts the '0x' prefix directly.
    return int(s, 16)
if __name__ == "__main__":
    # File of comma-separated hex words ("0x00f00203, ..."), e.g. the
    # assembler's output pasted into a text file.
    filename = sys.argv[1]
    with open(filename) as f:
        buf = f.read()
    # Strip each field once; empty fields (e.g. a trailing comma) are skipped.
    code = [parse_hex(w) for w in (c.strip() for c in buf.split(',')) if w]
    # Instructions are 4 codewords each: destination op + three sources.
    for i in range(len(code) // 4):
        parse_instruction(code[i*4:i*4+4])