regs: improve frontend for assembler and disassembler

Zack Buhman 2025-10-16 17:17:16 -05:00
parent d903115964
commit be27c747ee
5 changed files with 151 additions and 160 deletions

View File

@ -1,5 +1,7 @@
-from assembler.lexer import Lexer
-from assembler.parser import Parser
+import sys
+from assembler.lexer import Lexer, LexerError
+from assembler.parser import Parser, ParserError
from assembler.emitter import emit_instruction
sample = b"""
@ -15,15 +17,59 @@ out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
-if __name__ == "__main__":
-    #buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
-    buf = sample
+def frontend_inner(buf):
    lexer = Lexer(buf)
    tokens = list(lexer.lex_tokens())
    parser = Parser(tokens)
-    for ins in parser.instructions():
-        print("\n".join(
-            f"{value:08x}"
-            for value in emit_instruction(ins)
-        ))
+    for ins, start_end in parser.instructions():
+        yield list(emit_instruction(ins)), start_end
+
+def print_error(filename, buf, e):
+    assert len(e.args) == 2, e
+    message, token = e.args
+    lines = buf.splitlines()
+    line = lines[token.line - 1]
+    error_name = str(type(e).__name__)
+    col_indent = ' ' * token.col
+    col_pointer = '^' * len(token.lexeme)
+    RED = "\033[0;31m"
+    DEFAULT = "\033[0;0m"
+    print(f'File: "{filename}", line {token.line}, column {token.col}\n', file=sys.stderr)
+    sys.stderr.write(' ')
+    wrote_default = False
+    for i, c in enumerate(line.decode('utf-8')):
+        if i == token.col:
+            sys.stderr.write(RED)
+        sys.stderr.write(c)
+        if i == token.col + len(token.lexeme):
+            wrote_default = True
+            sys.stderr.write(DEFAULT)
+    if not wrote_default:
+        sys.stderr.write(DEFAULT)
+    sys.stderr.write('\n')
+    print(f" {RED}{col_indent}{col_pointer}{DEFAULT}", file=sys.stderr)
+    print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)
+
+def frontend(filename, buf):
+    try:
+        yield from frontend_inner(buf)
+    except ParserError as e:
+        print_error(input_filename, buf, e)
+        raise
+    except LexerError as e:
+        print_error(input_filename, buf, e)
+        raise
+
+if __name__ == "__main__":
+    input_filename = sys.argv[1]
+    #output_filename = sys.argv[2]
+    with open(input_filename, 'rb') as f:
+        buf = f.read()
+    output = list(frontend(input_filename, buf))
+    for cw, (start_ix, end_ix) in output:
+        if True:
+            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x},")
+        else:
+            source = buf[start_ix:end_ix]
+            print(f"0x{cw[0]:08x}, 0x{cw[1]:08x}, 0x{cw[2]:08x}, 0x{cw[3]:08x}, // {source.decode('utf-8')}")

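Note on the new diagnostics: print_error renders a compiler-style message — the source line, a run of red carets under the offending token, and the exception name with its message. Below is a rough standalone sketch of that output format; the filename, source text, token position, and the show_error helper are made-up for illustration and simplify the real function (which also colors the token inside the echoed line).

import sys

RED = "\033[0;31m"      # ANSI: red foreground
DEFAULT = "\033[0;0m"   # ANSI: reset attributes

def show_error(filename, source, line, col, lexeme, message, error_name):
    text = source.splitlines()[line - 1]
    print(f'File: "{filename}", line {line}, column {col}\n', file=sys.stderr)
    sys.stderr.write(' ' + text + '\n')
    # caret run under the offending lexeme, mirroring col_indent/col_pointer above
    sys.stderr.write(' ' + ' ' * col + RED + '^' * len(lexeme) + DEFAULT + '\n')
    print(f'{RED}{error_name}{DEFAULT}: {message}', file=sys.stderr)

show_error("example.vs", "out[0].q = VE_MAD temp[0].x_0_ temp[0].y_0_",
           1, 7, "q", "expected destination write enable", "ParserError")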
View File

@ -60,23 +60,53 @@ def src_reg_type(kw):
    else:
        assert not "Invalid PVS_SRC_REG", kw

-def emit_source(src: Source):
-    value = (
-        pvs_src.REG_TYPE_gen(src_reg_type(src.type))
-        | pvs_src.OFFSET_gen(src.offset)
-        | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0])
-        | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1])
-        | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2])
-        | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3])
-        | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0]))
-        | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1]))
-        | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2]))
-        | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3]))
-    )
+def emit_source(src: Source, prev: Source):
+    if src is not None:
+        value = (
+            pvs_src.REG_TYPE_gen(src_reg_type(src.type))
+            | pvs_src.OFFSET_gen(src.offset)
+            | pvs_src.SWIZZLE_X_gen(src.swizzle.select[0])
+            | pvs_src.SWIZZLE_Y_gen(src.swizzle.select[1])
+            | pvs_src.SWIZZLE_Z_gen(src.swizzle.select[2])
+            | pvs_src.SWIZZLE_W_gen(src.swizzle.select[3])
+            | pvs_src.MODIFIER_X_gen(int(src.swizzle.modifier[0]))
+            | pvs_src.MODIFIER_Y_gen(int(src.swizzle.modifier[1]))
+            | pvs_src.MODIFIER_Z_gen(int(src.swizzle.modifier[2]))
+            | pvs_src.MODIFIER_W_gen(int(src.swizzle.modifier[3]))
+        )
+    else:
+        assert prev is not None
+        value = (
+            pvs_src.REG_TYPE_gen(src_reg_type(prev.type))
+            | pvs_src.OFFSET_gen(prev.offset)
+            | pvs_src.SWIZZLE_X_gen(7)
+            | pvs_src.SWIZZLE_Y_gen(7)
+            | pvs_src.SWIZZLE_Z_gen(7)
+            | pvs_src.SWIZZLE_W_gen(7)
+            | pvs_src.MODIFIER_X_gen(0)
+            | pvs_src.MODIFIER_Y_gen(0)
+            | pvs_src.MODIFIER_Z_gen(0)
+            | pvs_src.MODIFIER_W_gen(0)
+        )
    yield value

+def prev_source(ins, ix):
+    if ix == 0:
+        assert ins.source0 is not None
+        return ins.source0
+    elif ix == 1:
+        return ins.source0
+    elif ix == 2:
+        if ins.source1 is not None:
+            return ins.source1
+        else:
+            return ins.source0
+    else:
+        assert False, ix
+
def emit_instruction(ins: Instruction):
    yield from emit_destination_op(ins.destination_op)
-    yield from emit_source(ins.source0)
-    yield from emit_source(ins.source1)
-    yield from emit_source(ins.source2)
+    yield from emit_source(ins.source0, prev_source(ins, 0))
+    yield from emit_source(ins.source1, prev_source(ins, 1))
+    yield from emit_source(ins.source2, prev_source(ins, 2))

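Note on omitted operands: emit_source now takes a fallback operand. When a trailing source is left out of the assembly text, prev_source picks an earlier operand and the encoder reuses its register type and offset while forcing all four swizzle selects to 7 (the unused select) and clearing the modifiers. A minimal sketch of that fallback follows; Src, pick_prev, and fallback_fields are illustrative names, not the real pvs_src helpers.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Src:
    type: str
    offset: int

def pick_prev(source0: Src, source1: Optional[Src], ix: int) -> Src:
    # mirrors prev_source(): source1 falls back to source0,
    # source2 prefers source1 and otherwise falls back to source0
    if ix in (0, 1):
        return source0
    return source1 if source1 is not None else source0

def fallback_fields(prev: Src):
    # what an omitted operand encodes: the previous operand's register,
    # every swizzle select forced to 7 ("unused"), all modifiers cleared
    return dict(reg_type=prev.type, offset=prev.offset,
                swizzle=(7, 7, 7, 7), modifier=(0, 0, 0, 0))

print(fallback_fields(pick_prev(Src("temp", 0), None, ix=2)))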
View File

@ -21,6 +21,7 @@ class TT(Enum):
@dataclass
class Token:
+    start_ix: int
    line: int
    col: int
    type: TT
@ -64,7 +65,7 @@ class Lexer:
        return self.buf[self.current_ix]

    def pos(self):
-        return self.line, self.col - (self.current_ix - self.start_ix)
+        return self.start_ix, self.line, self.col - (self.current_ix - self.start_ix)

    def identifier(self):
        while not self.at_end_p() and self.peek() in identifier_characters:
@ -96,7 +97,7 @@ class Lexer:
        elif c == ord('.'):
            return Token(*self.pos(), TT.dot, self.lexeme())
        elif c == ord(';'):
-            while not at_end_p() and peek() != ord('\n'):
+            while not self.at_end_p() and self.peek() != ord('\n'):
                self.advance()
        elif c == ord(' ') or c == ord('\r') or c == ord('\t'):
            pass

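Note on token spans: with start_ix recorded on every Token (and threaded through pos()), a token's span in the input buffer is simply [start_ix, start_ix + len(lexeme)). A small sketch with a hypothetical token; the reduced Token below omits the lexer's type field.

from dataclasses import dataclass

@dataclass
class Token:          # reduced version of the lexer's Token, for illustration only
    start_ix: int
    line: int
    col: int
    lexeme: bytes

def token_span(tok):
    return tok.start_ix, tok.start_ix + len(tok.lexeme)

buf = b"out[0].xz = VE_MAD temp[0].x_0_ temp[0].y_0_"
tok = Token(start_ix=12, line=1, col=12, lexeme=b"VE_MAD")
start, end = token_span(tok)
print(buf[start:end])   # b'VE_MAD'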
View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from typing import Union
from assembler import lexer
-from assembler.lexer import TT
+from assembler.lexer import TT, Token
from assembler.keywords import KW, ME, VE
"""
@ -44,7 +44,7 @@ class Instruction:
    source1: Source
    source2: Source

-class ParseError(Exception):
+class ParserError(Exception):
    pass
def identifier_to_number(token):
@ -52,7 +52,7 @@ def identifier_to_number(token):
    assert token.type is TT.identifier
    if not all(d in digits for d in token.lexeme):
-        raise ParseError("expected number", token)
+        raise ParserError("expected number", token)
    return int(bytes(token.lexeme), 10)
def we_ord(c):
@ -66,9 +66,9 @@ def parse_dest_write_enable(token):
    assert token.type is TT.identifier
    we = bytes(token.lexeme).lower()
    if not all(c in we_chars for c in we):
-        raise ParseError("expected destination write enable", token)
+        raise ParserError("expected destination write enable", token)
    if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
-        raise ParseError("misleading non-sequential write enable", token)
+        raise ParserError("misleading non-sequential write enable", token)
    return set(we_ord(c) for c in we)
def parse_source_swizzle(token):
@ -89,25 +89,25 @@ def parse_source_swizzle(token):
    swizzle_modifiers = [None] * 4
    lexeme = bytes(token.lexeme).lower()
    while state < 4:
-        if ix > len(token.lexeme):
-            raise ParseError("invalid source swizzle", token)
+        if ix >= len(token.lexeme):
+            raise ParserError("invalid source swizzle", token)
        c = lexeme[ix]
        if c == ord('-'):
            if (swizzle_modifiers[state] is not None) or (swizzle_selects[state] is not None):
-                raise ParseError("invalid source swizzle modifier", token)
+                raise ParserError("invalid source swizzle modifier", token)
            swizzle_modifiers[state] = True
        elif c in select_mapping:
            if swizzle_selects[state] is not None:
-                raise ParseError("invalid source swizzle select", token)
+                raise ParserError("invalid source swizzle select", token)
            swizzle_selects[state] = select_mapping[c]
            if swizzle_modifiers[state] is None:
                swizzle_modifiers[state] = False
            state += 1
        else:
-            raise ParseError("invalid source swizzle", token)
+            raise ParserError("invalid source swizzle", token)
        ix += 1
    if ix != len(lexeme):
-        raise ParseError("invalid source swizzle", token)
+        raise ParserError("invalid source swizzle", token)
    return SourceSwizzle(swizzle_selects, swizzle_modifiers)
class Parser:
@ -115,8 +115,8 @@ class Parser:
        self.current_ix = 0
        self.tokens = tokens

-    def peek(self):
-        token = self.tokens[self.current_ix]
+    def peek(self, offset=0):
+        token = self.tokens[self.current_ix + offset]
        #print(token)
        return token
@ -135,20 +135,20 @@ class Parser:
    def consume(self, token_type, message):
        token = self.advance()
        if token.type != token_type:
-            raise ParseError(message, token)
+            raise ParserError(message, token)
        return token

    def consume_either(self, token_type1, token_type2, message):
        token = self.advance()
        if token.type != token_type1 and token.type != token_type2:
-            raise ParseError(message, token)
+            raise ParserError(message, token)
        return token

    def destination_type(self):
        token = self.consume(TT.keyword, "expected destination type")
        destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
        if token.keyword not in destination_keywords:
-            raise ParseError("expected destination type", token)
+            raise ParserError("expected destination type", token)
        return token.keyword

    def offset(self):
@ -161,7 +161,7 @@ class Parser:
    def opcode(self):
        token = self.consume(TT.keyword, "expected opcode")
        if type(token.keyword) != VE and type(token.keyword) != ME:
-            raise ParseError("expected opcode", token)
+            raise ParserError("expected opcode", token)
        return token.keyword

    def destination_op(self):
@ -178,7 +178,7 @@ class Parser:
        token = self.consume(TT.keyword, "expected source type")
        source_keywords = {KW.temporary, KW.input, KW.constant, KW.alt_temporary}
        if token.keyword not in source_keywords:
-            raise ParseError("expected source type", token)
+            raise ParserError("expected source type", token)
        return token.keyword

    def source_swizzle(self):
@ -196,12 +196,23 @@ class Parser:
    def instruction(self):
        while self.match(TT.eol):
            self.advance()
+        first_token = self.peek()
        destination_op = self.destination_op()
        source0 = self.source()
-        source1 = self.source()
-        source2 = self.source()
+        if self.match(TT.eol) or self.match(TT.eof):
+            source1 = None
+        else:
+            source1 = self.source()
+        if self.match(TT.eol) or self.match(TT.eof):
+            source2 = None
+        else:
+            source2 = self.source()
+        last_token = self.peek(-1)
        self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
-        return Instruction(destination_op, source0, source1, source2)
+        return (
+            Instruction(destination_op, source0, source1, source2),
+            (first_token.start_ix, last_token.start_ix + len(last_token.lexeme))
+        )

    def instructions(self):
        while not self.match(TT.eof):

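Note on instruction spans: instruction() now returns the parsed Instruction together with the byte span from the first token to the end of the last token before the terminating newline, which is what the frontend's commented-out branch would use to echo the source text beside the emitted words. A sketch of how a caller might consume the span; the annotate helper and the all-zero code words are placeholders, not a real encoding.

def annotate(buf, words, span):
    # join the four code words and append the original source text as a comment
    start_ix, end_ix = span
    source = buf[start_ix:end_ix].decode('utf-8')
    return ", ".join(f"0x{w:08x}" for w in words) + f", // {source}"

line = b"out[0].xyzw = VE_ADD temp[0].xyzw temp[1].xyzw"
print(annotate(line, [0, 0, 0, 0], (0, len(line))))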
View File

@ -2,113 +2,9 @@ import pvs_src
import pvs_src_bits
import pvs_dst
import pvs_dst_bits
from pprint import pprint
import itertools
from functools import partial
code = [
0x00f00203,
0x00d10001,
0x01248001,
0x01248001,
]
# Radeon Compiler Program
# 0: MOV output[1].xyz, input[1].xyz_;
# 1: MOV output[0], input[0].xyz1;
# Final vertex program code:
# 0: op: 0x00702203 dst: 1o op: VE_ADD
# src0: 0x01d10021 reg: 1i swiz: X/ Y/ Z/ U
# src1: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0
# src2: 0x01248021 reg: 1i swiz: 0/ 0/ 0/ 0
# 1: op: 0x00f00203 dst: 0o op: VE_ADD
# src0: 0x01510001 reg: 0i swiz: X/ Y/ Z/ 1
# src1: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0
# src2: 0x01248001 reg: 0i swiz: 0/ 0/ 0/ 0
code = [
0x00702203,
0x01d10021,
0x01248021,
0x01248021,
0x00f00203,
0x01510001,
0x01248001,
0x01248001,
]
code = [
0x00f00003,
0x00d10022,
0x01248022,
0x01248022,
0x00f02003,
0x00d10022,
0x01248022,
0x01248022,
0x00100004,
0x01ff0002,
0x01ff0020,
0x01ff2000,
0x00100006,
0x01ff0000,
0x01248000,
0x01248000,
0x00100004,
0x01ff0000,
0x01ff4022,
0x01ff6022,
0x00100050,
0x00000000,
0x01248000,
0x01248000,
0x00f00204,
0x0165a000,
0x01690001,
0x01240000,
]
code = [
0x00f00003,
0x00d10022,
0x01248022,
0x01248022,
0x00f02003,
0x00d10022,
0x01248022,
0x01248022,
0x00100004,
0x01ff0002,
0x01ff0020,
0x01ff2000,
0x00100006,
0x01ff0000,
0x01248000,
0x01248000,
0x00100004,
0x01ff0000,
0x01ff4022,
0x01ff6022,
0x00200051,
0x00000000,
0x01248000,
0x01248000,
0x00100050,
0x00000000,
0x01248000,
0x01248000,
0x00600002,
0x01c8e001,
0x01c9e000,
0x01248000,
0x00500204,
0x1fe72001,
0x01e70000,
0x01e72000,
0x00a00204,
0x0138e001,
0x0138e000,
0x017ae000,
]
import sys
def out(level, *args):
    sys.stdout.write(" " * level + " ".join(args))
@ -151,8 +47,6 @@ def parse_code(code):
        ix += 4

-#parse_code(code)

def dst_swizzle_from_we(dst_op):
    table = [
        (pvs_dst.WE_X, "x"),
@ -168,7 +62,7 @@ def dst_swizzle_from_we(dst_op):
_op_substitutions = [
("DOT_PRODUCT", "DOT"),
("MULTIPLY_ADD", "MAD"),
("FRACTION", "FRAC"),
("FRACTION", "FRC"),
("MULTIPLY", "MUL"),
("MAXMIUM", "MAX"),
("MINIMUM", "MIN"),
@ -280,5 +174,14 @@ def parse_instruction(instruction):
    print(dst.ljust(12), "=", op.ljust(9), " ".join(map(lambda s: s.ljust(17), rest)))

-for i in range(len(code) // 4):
-    parse_instruction(code[i*4:i*4+4])
+def parse_hex(s):
+    assert s.startswith('0x')
+    return int(s.removeprefix('0x'), 16)
+
+if __name__ == "__main__":
+    filename = sys.argv[1]
+    with open(filename) as f:
+        buf = f.read()
+    code = [parse_hex(c.strip()) for c in buf.split(',') if c.strip()]
+    for i in range(len(code) // 4):
+        parse_instruction(code[i*4:i*4+4])
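Note on the disassembler entry point: the new __main__ consumes exactly what the assembler frontend prints — a comma-separated list of 32-bit hex words, four per instruction, with trailing commas and newlines tolerated. A standalone sketch of that parsing step, reusing hex words from the sample dumps deleted above.

text = """
0x00f00203, 0x00d10001, 0x01248001, 0x01248001,
0x00702203, 0x01d10021, 0x01248021, 0x01248021,
"""

def parse_hex(s):
    assert s.startswith('0x')
    return int(s.removeprefix('0x'), 16)

code = [parse_hex(c.strip()) for c in text.split(',') if c.strip()]
groups = [code[i*4:i*4+4] for i in range(len(code) // 4)]
print(len(groups), "instructions")   # 2 instructions, each a list of four words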