From be27c747eec1d4761945410edfa8569ff33b7bd3 Mon Sep 17 00:00:00 2001 From: Zack Buhman Date: Thu, 16 Oct 2025 17:17:16 -0500 Subject: [PATCH] regs: improve frontend for assembler and disassembler --- regs/assembler/__main__.py | 68 ++++++++++++++++---- regs/assembler/emitter.py | 62 ++++++++++++++----- regs/assembler/lexer.py | 5 +- regs/assembler/parser.py | 53 +++++++++------- regs/pvs_disassemble.py | 123 ++++--------------------------------- 5 files changed, 151 insertions(+), 160 deletions(-) diff --git a/regs/assembler/__main__.py b/regs/assembler/__main__.py index 9825c29..4ed3963 100644 --- a/regs/assembler/__main__.py +++ b/regs/assembler/__main__.py @@ -1,5 +1,7 @@ -from assembler.lexer import Lexer -from assembler.parser import Parser +import sys + +from assembler.lexer import Lexer, LexerError +from assembler.parser import Parser, ParserError from assembler.emitter import emit_instruction sample = b""" @@ -15,15 +17,59 @@ out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_ out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1 """ -if __name__ == "__main__": - - #buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_" - buf = sample +def frontend_inner(buf): lexer = Lexer(buf) tokens = list(lexer.lex_tokens()) parser = Parser(tokens) - for ins in parser.instructions(): - print("\n".join( - f"{value:08x}" - for value in emit_instruction(ins) - )) + for ins, start_end in parser.instructions(): + yield list(emit_instruction(ins)), start_end + +def print_error(filename, buf, e): + assert len(e.args) == 2, e + message, token = e.args + lines = buf.splitlines() + line = lines[token.line - 1] + + error_name = str(type(e).__name__) + col_indent = ' ' * token.col + col_pointer = '^' * len(token.lexeme) + RED = "\033[0;31m" + DEFAULT = "\033[0;0m" + print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr) + sys.stderr.write(' ') + wrote_default = False + for i, c in enumerate(line.decode('utf-8')): + if i == token.col: + sys.stderr.write(RED) + sys.stderr.write(c) + if i == token.col + len(token.lexeme): + wrote_default = True + sys.stderr.write(DEFAULT) + if not wrote_default: + sys.stderr.write(DEFAULT) + sys.stderr.write('\n') + print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr) + print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr) + +def frontend(filename, buf): + try: + yield from frontend_inner(buf) + except ParserError as e: + print_error(input_filename, buf, e) + raise + except LexerError as e: + print_error(input_filename, buf, e) + raise + +if __name__ == "__main__": + input_filename = sys.argv[1] + #output_filename = sys.argv[2] + with open(input_filename, 'rb') as f: + buf = f.read() + output = list(frontend(input_filename, buf)) + for cw, (start_ix, end_ix) in output: + if True: + print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},") + else: + source = buf[start_ix:end_ix] + print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}") diff --git a/regs/assembler/emitter.py b/regs/assembler/emitter.py index a81d84c..0e57c49 100644 --- a/regs/assembler/emitter.py +++ b/regs/assembler/emitter.py @@ -60,23 +60,53 @@ def src_reg_type(kw): else: assert not "Invalid PVS_SRC_REG", kw -def emit_source(src: Source): - value = ( - pvs_src.REG_TYPE_gen(src_reg_type(src.type)) - | pvs_src.OFFSET_gen(src.offset) - | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0]) - | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1]) - | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2]) - | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3]) - | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0])) - | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1])) - | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2])) - | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3])) - ) +def emit_source(src: Source, prev: Source): + if src is not None: + value = ( + pvs_src.REG_TYPE_gen(src_reg_type(src.type)) + | pvs_src.OFFSET_gen(src.offset) + | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0]) + | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1]) + | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2]) + | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3]) + | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0])) + | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1])) + | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2])) + | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3])) + ) + else: + assert prev is not None + value = ( + pvs_src.REG_TYPE_gen(src_reg_type(prev.type)) + | pvs_src.OFFSET_gen(prev.offset) + | pvs_src.SWIZZLE_X_gen(7) + | pvs_src.SWIZZLE_Y_gen(7) + | pvs_src.SWIZZLE_Z_gen(7) + | pvs_src.SWIZZLE_W_gen(7) + | pvs_src.MODIFIER_X_gen(0) + | pvs_src.MODIFIER_Y_gen(0) + | pvs_src.MODIFIER_Z_gen(0) + | pvs_src.MODIFIER_W_gen(0) + ) yield value +def prev_source(ins, ix): + if ix == 0: + assert ins.source0 is not None + return ins.source0 + elif ix == 1: + return ins.source0 + elif ix == 2: + if ins.source1 is not None: + return ins.source1 + else: + return ins.source0 + else: + assert False, ix + def emit_instruction(ins: Instruction): yield from emit_destination_op(ins.destination_op) - yield from emit_source(ins.source0) - yield from emit_source(ins.source1) - yield from emit_source(ins.source2) + + yield from emit_source(ins.source0, prev_source(ins, 0)) + yield from emit_source(ins.source1, prev_source(ins, 1)) + yield from emit_source(ins.source2, prev_source(ins, 2)) diff --git a/regs/assembler/lexer.py b/regs/assembler/lexer.py index c3523d1..122b6a0 100644 --- a/regs/assembler/lexer.py +++ b/regs/assembler/lexer.py @@ -21,6 +21,7 @@ class TT(Enum): @dataclass class Token: + start_ix: int line: int col: int type: TT @@ -64,7 +65,7 @@ class Lexer: return self.buf[self.current_ix] def pos(self): - return self.line, self.col - (self.current_ix - self.start_ix) + return self.start_ix, self.line, self.col - (self.current_ix - self.start_ix) def identifier(self): while not self.at_end_p() and self.peek() in identifier_characters: @@ -96,7 +97,7 @@ class Lexer: elif c == ord('.'): return Token(*self.pos(), TT.dot, self.lexeme()) elif c == ord(';'): - while not at_end_p() and peek() != ord('\n'): + while not self.at_end_p() and self.peek() != ord('\n'): self.advance() elif c == ord(' ') or c == ord('\r') or c == ord('\t'): pass diff --git a/regs/assembler/parser.py b/regs/assembler/parser.py index cf6805a..c08a90c 100644 --- a/regs/assembler/parser.py +++ b/regs/assembler/parser.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from typing import Union from assembler import lexer -from assembler.lexer import TT +from assembler.lexer import TT, Token from assembler.keywords import KW, ME, VE """ @@ -44,7 +44,7 @@ class Instruction: source1: Source source2: Source -class ParseError(Exception): +class ParserError(Exception): pass def identifier_to_number(token): @@ -52,7 +52,7 @@ def identifier_to_number(token): assert token.type is TT.identifier if not all(d in digits for d in token.lexeme): - raise ParseError("expected number", token) + raise ParserError("expected number", token) return int(bytes(token.lexeme), 10) def we_ord(c): @@ -66,9 +66,9 @@ def parse_dest_write_enable(token): assert token.type is TT.identifier we = bytes(token.lexeme).lower() if not all(c in we_chars for c in we): - raise ParseError("expected destination write enable", token) + raise ParserError("expected destination write enable", token) if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we): - raise ParseError("misleading non-sequential write enable", token) + raise ParserError("misleading non-sequential write enable", token) return set(we_ord(c) for c in we) def parse_source_swizzle(token): @@ -89,25 +89,25 @@ def parse_source_swizzle(token): swizzle_modifiers = [None] * 4 lexeme = bytes(token.lexeme).lower() while state < 4: - if ix > len(token.lexeme): - raise ParseError("invalid source swizzle", token) + if ix >= len(token.lexeme): + raise ParserError("invalid source swizzle", token) c = lexeme[ix] if c == ord('-'): if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None): - raise ParseError("invalid source swizzle modifier", token) + raise ParserError("invalid source swizzle modifier", token) swizzle_modifiers[state] = True elif c in select_mapping: if swizzle_selects[state] is not None: - raise ParseError("invalid source swizzle select", token) + raise ParserError("invalid source swizzle select", token) swizzle_selects[state] = select_mapping[c] if swizzle_modifiers[state] is None: swizzle_modifiers[state] = False state += 1 else: - raise ParseError("invalid source swizzle", token) + raise ParserError("invalid source swizzle", token) ix += 1 if ix != len(lexeme): - raise ParseError("invalid source swizzle", token) + raise ParserError("invalid source swizzle", token) return SourceSwizzle(swizzle_selects, swizzle_modifiers) class Parser: @@ -115,8 +115,8 @@ class Parser: self.current_ix = 0 self.tokens = tokens - def peek(self): - token = self.tokens[self.current_ix] + def peek(self, offset=0): + token = self.tokens[self.current_ix + offset] #print(token) return token @@ -135,20 +135,20 @@ class Parser: def consume(self, token_type, message): token = self.advance() if token.type != token_type: - raise ParseError(message, token) + raise ParserError(message, token) return token def consume_either(self, token_type1, token_type2, message): token = self.advance() if token.type != token_type1 and token.type != token_type2: - raise ParseError(message, token) + raise ParserError(message, token) return token def destination_type(self): token = self.consume(TT.keyword, "expected destination type") destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input} if token.keyword not in destination_keywords: - raise ParseError("expected destination type", token) + raise ParserError("expected destination type", token) return token.keyword def offset(self): @@ -161,7 +161,7 @@ class Parser: def opcode(self): token = self.consume(TT.keyword, "expected opcode") if type(token.keyword) != VE and type(token.keyword) != ME: - raise ParseError("expected opcode", token) + raise ParserError("expected opcode", token) return token.keyword def destination_op(self): @@ -178,7 +178,7 @@ class Parser: token = self.consume(TT.keyword, "expected source type") source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary} if token.keyword not in source_keywords: - raise ParseError("expected source type", token) + raise ParserError("expected source type", token) return token.keyword def source_swizzle(self): @@ -196,12 +196,23 @@ class Parser: def instruction(self): while self.match(TT.eol): self.advance() + first_token = self.peek() destination_op = self.destination_op() source0 = self.source() - source1 = self.source() - source2 = self.source() + if self.match(TT.eol) or self.match(TT.eof): + source1 = None + else: + source1 = self.source() + if self.match(TT.eol) or self.match(TT.eof): + source2 = None + else: + source2 = self.source() + last_token = self.peek(-1) self.consume_either(TT.eol, TT.eof, "expected newline or EOF") - return Instruction(destination_op, source0, source1, source2) + return ( + Instruction(destination_op, source0, source1, source2), + (first_token.start_ix, last_token.start_ix + len(last_token.lexeme)) + ) def instructions(self): while not self.match(TT.eof): diff --git a/regs/pvs_disassemble.py b/regs/pvs_disassemble.py index 0fac1c3..9d18c87 100644 --- a/regs/pvs_disassemble.py +++ b/regs/pvs_disassemble.py @@ -2,113 +2,9 @@ import pvs_src import pvs_src_bits import pvs_dst import pvs_dst_bits -from pprint import pprint import itertools from functools import partial - -code = [ - 0x00f00203, - 0x00d10001, - 0x01248001, - 0x01248001, -] - -# Radeon Compiler Program -# 0: MOV output[1].xyz, input[1].xyz_; -# 1: MOV output[0], input[0].xyz1; -# Final vertex program code: -# 0: op: 0x00702203 dst: 1o op: VE_ADD -# src0: 0x01d10021 reg: 1i swiz: X/ Y/ Z/ U -# src1: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0 -# src2: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0 -# 1: op: 0x00f00203 dst: 0o op: VE_ADD -# src0: 0x01510001 reg: 0i swiz: X/ Y/ Z/ 1 -# src1: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0 -# src2: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0 -code = [ - 0x00702203, - 0x01d10021, - 0x01248021, - 0x01248021, - 0x00f00203, - 0x01510001, - 0x01248001, - 0x01248001, -] - -code = [ - 0x00f00003, - 0x00d10022, - 0x01248022, - 0x01248022, - 0x00f02003, - 0x00d10022, - 0x01248022, - 0x01248022, - 0x00100004, - 0x01ff0002, - 0x01ff0020, - 0x01ff2000, - 0x00100006, - 0x01ff0000, - 0x01248000, - 0x01248000, - 0x00100004, - 0x01ff0000, - 0x01ff4022, - 0x01ff6022, - 0x00100050, - 0x00000000, - 0x01248000, - 0x01248000, - 0x00f00204, - 0x0165a000, - 0x01690001, - 0x01240000, -] - -code = [ - 0x00f00003, - 0x00d10022, - 0x01248022, - 0x01248022, - 0x00f02003, - 0x00d10022, - 0x01248022, - 0x01248022, - 0x00100004, - 0x01ff0002, - 0x01ff0020, - 0x01ff2000, - 0x00100006, - 0x01ff0000, - 0x01248000, - 0x01248000, - 0x00100004, - 0x01ff0000, - 0x01ff4022, - 0x01ff6022, - 0x00200051, - 0x00000000, - 0x01248000, - 0x01248000, - 0x00100050, - 0x00000000, - 0x01248000, - 0x01248000, - 0x00600002, - 0x01c8e001, - 0x01c9e000, - 0x01248000, - 0x00500204, - 0x1fe72001, - 0x01e70000, - 0x01e72000, - 0x00a00204, - 0x0138e001, - 0x0138e000, - 0x017ae000, -] +import sys def out(level, *args): sys.stdout.write(" " * level + " ".join(args)) @@ -151,8 +47,6 @@ def parse_code(code): ix += 4 -#parse_code(code) - def dst_swizzle_from_we(dst_op): table = [ (pvs_dst.WE_X, "x"), @@ -168,7 +62,7 @@ def dst_swizzle_from_we(dst_op): _op_substitutions = [ ("DOT_PRODUCT", "DOT"), ("MULTIPLY_ADD", "MAD"), - ("FRACTION", "FRAC"), + ("FRACTION", "FRC"), ("MULTIPLY", "MUL"), ("MAXMIUM", "MAX"), ("MINIMUM", "MIN"), @@ -280,5 +174,14 @@ def parse_instruction(instruction): print(dst.ljust(12), "=", op.ljust(9), " ".join(map(lambda s: s.ljust(17), rest))) -for i in range(len(code) // 4): - parse_instruction(code[i*4:i*4+4]) +def parse_hex(s): + assert s.startswith('0x') + return int(s.removeprefix('0x'), 16) + +if __name__ == "__main__": + filename = sys.argv[1] + with open(filename) as f: + buf = f.read() + code = [parse_hex(c.strip()) for c in buf.split(',') if c.strip()] + for i in range(len(code) // 4): + parse_instruction(code[i*4:i*4+4])