assembler: feature parity with the disassembler

This commit is contained in:
Zack Buhman 2025-10-16 12:53:28 -05:00
parent e24b3ada5e
commit d903115964
7 changed files with 147 additions and 23 deletions

View File

@ -0,0 +1,29 @@
from assembler.lexer import Lexer
from assembler.parser import Parser
from assembler.emitter import emit_instruction
# Assembly source (bytes) used as the input when this file is run as a script.
# Each non-empty line has the form
#   destination = OPCODE source0 source1 source2
# The literal starts and ends with a newline, so the token stream contains
# blank lines around the instructions.
sample = b"""
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].yz = VE_MUL input[0]._xy_ temp[0]._yy_ temp[0].0000
out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_
out[0].yw = VE_MAD input[0]._x_0 temp[0]._x_0 temp[0]._z_1
"""
if __name__ == "__main__":
    # Script entry point: lex and parse the sample program, then print every
    # emitted word as 8 lowercase hex digits, one word per line.
    #
    # For a quick single-instruction experiment, assign something like
    #   b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
    # to `buf` instead of `sample`.
    buf = sample
    # The parser indexes into its token sequence, so materialise the lexer
    # output into a list first.
    tokens = list(Lexer(buf).lex_tokens())
    for ins in Parser(tokens).instructions():
        # Format all words of the instruction before printing, so nothing is
        # written for an instruction whose emission fails part-way.
        words = [f"{value:08x}" for value in emit_instruction(ins)]
        print("\n".join(words))

82
regs/assembler/emitter.py Normal file
View File

@ -0,0 +1,82 @@
from assembler.keywords import ME, VE, KW
from assembler.parser import Instruction, DestinationOp, Source
import pvs_dst
import pvs_src
import pvs_dst_bits
import pvs_src_bits
def we_x(s):
    """Return 1 if component index 0 (x) is a member of `s`, otherwise 0."""
    return 1 if 0 in s else 0
def we_y(s):
    """Return 1 if component index 1 (y) is a member of `s`, otherwise 0."""
    return 1 if 1 in s else 0
def we_z(s):
    """Return 1 if component index 2 (z) is a member of `s`, otherwise 0."""
    return 1 if 2 in s else 0
def we_w(s):
    """Return 1 if component index 3 (w) is a member of `s`, otherwise 0."""
    return 1 if 3 in s else 0
def dst_reg_type(kw):
    """Translate a destination register keyword into its PVS_DST_REG value.

    Args:
        kw: the destination keyword; one of KW.temporary, KW.a0, KW.out,
            KW.out_repl_x, KW.alt_temporary or KW.input.

    Returns:
        The entry of pvs_dst_bits.PVS_DST_REG_gen stored under the matching
        register name.

    Raises:
        AssertionError: if `kw` is not one of the keywords listed above.
    """
    # Pairs are compared with `==` in the same order as the original chain.
    names = (
        (KW.temporary, "TEMPORARY"),
        (KW.a0, "A0"),
        (KW.out, "OUT"),
        (KW.out_repl_x, "OUT_REPL_X"),
        (KW.alt_temporary, "ALT_TEMPORARY"),
        (KW.input, "INPUT"),
    )
    for keyword, name in names:
        if kw == keyword:
            return pvs_dst_bits.PVS_DST_REG_gen[name]
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let an unknown keyword fall
    # through and return None. The exception type is kept unchanged.
    raise AssertionError(f"Invalid PVS_DST_REG: {kw!r}")
def emit_destination_op(dst_op: DestinationOp):
    """Yield the one encoded word for a destination/opcode operand.

    The word is the bitwise OR of the pvs_dst field encoders for the opcode,
    the math-instruction flag, the register type, the offset and the four
    per-component write enables.

    Args:
        dst_op: parsed destination operand; its `opcode` must be an ME or VE
            member.

    Yields:
        A single combined value.
    """
    opcode_kind = type(dst_op.opcode)
    assert opcode_kind in {ME, VE}
    # The flag is 1 for ME opcodes and 0 for VE opcodes.
    is_math = int(opcode_kind is ME)

    # Fields are encoded in a fixed order so that a failing field encoder
    # always reports the same (first) offending field.
    word = pvs_dst.OPCODE_gen(dst_op.opcode.value)
    word |= pvs_dst.MATH_INST_gen(is_math)
    word |= pvs_dst.REG_TYPE_gen(dst_reg_type(dst_op.type))
    word |= pvs_dst.OFFSET_gen(dst_op.offset)
    word |= pvs_dst.WE_X_gen(we_x(dst_op.write_enable))
    word |= pvs_dst.WE_Y_gen(we_y(dst_op.write_enable))
    word |= pvs_dst.WE_Z_gen(we_z(dst_op.write_enable))
    word |= pvs_dst.WE_W_gen(we_w(dst_op.write_enable))
    yield word
def src_reg_type(kw):
    """Translate a source register keyword into its PVS_SRC_REG_TYPE value.

    Args:
        kw: the source keyword; one of KW.temporary, KW.input, KW.constant
            or KW.alt_temporary.

    Returns:
        The entry of pvs_src_bits.PVS_SRC_REG_TYPE_gen stored under the
        matching register name.

    Raises:
        AssertionError: if `kw` is not one of the keywords listed above.
    """
    # Pairs are compared with `==` in the same order as the original chain.
    names = (
        (KW.temporary, "PVS_SRC_REG_TEMPORARY"),
        (KW.input, "PVS_SRC_REG_INPUT"),
        (KW.constant, "PVS_SRC_REG_CONSTANT"),
        (KW.alt_temporary, "PVS_SRC_REG_ALT_TEMPORARY"),
    )
    for keyword, name in names:
        if kw == keyword:
            return pvs_src_bits.PVS_SRC_REG_TYPE_gen[name]
    # Raise explicitly instead of using `assert`: assert statements are
    # removed under `python -O`, which would let an unknown keyword fall
    # through and return None. The exception type is kept unchanged.
    raise AssertionError(f"Invalid PVS_SRC_REG: {kw!r}")
def emit_source(src: Source):
    """Yield the one encoded word for a source operand.

    The word is the bitwise OR of the pvs_src field encoders for the register
    type, the offset, the four swizzle selects and the four modifiers.

    Args:
        src: parsed source operand with `type`, `offset` and a `swizzle`
            carrying four-element `select` and `modifier` sequences.

    Yields:
        A single combined value.
    """
    # Fields are encoded in a fixed order: register type, offset, the
    # selects (x, y, z, w), then the modifiers (x, y, z, w).
    word = pvs_src.REG_TYPE_gen(src_reg_type(src.type))
    word |= pvs_src.OFFSET_gen(src.offset)

    select = src.swizzle.select
    word |= pvs_src.SWIZZLE_X_gen(select[0])
    word |= pvs_src.SWIZZLE_Y_gen(select[1])
    word |= pvs_src.SWIZZLE_Z_gen(select[2])
    word |= pvs_src.SWIZZLE_W_gen(select[3])

    # Modifiers are converted with int() before being encoded.
    modifier = src.swizzle.modifier
    word |= pvs_src.MODIFIER_X_gen(int(modifier[0]))
    word |= pvs_src.MODIFIER_Y_gen(int(modifier[1]))
    word |= pvs_src.MODIFIER_Z_gen(int(modifier[2]))
    word |= pvs_src.MODIFIER_W_gen(int(modifier[3]))
    yield word
def emit_instruction(ins: Instruction):
    """Yield the encoded words of one instruction.

    The destination/opcode word comes first, followed by the words for
    source0, source1 and source2, in that order.

    Args:
        ins: parsed instruction.

    Yields:
        Every value produced by emit_destination_op and then by emit_source
        for each of the three sources.
    """
    yield from emit_destination_op(ins.destination_op)
    for source in (ins.source0, ins.source1, ins.source2):
        yield from emit_source(source)

View File

@ -3,7 +3,7 @@ from enum import Enum, auto
from itertools import chain from itertools import chain
from typing import Union from typing import Union
import keywords from assembler import keywords
DEBUG = True DEBUG = True

View File

@ -1,15 +1,16 @@
import lexer
from lexer import TT
from keywords import KW, ME, VE
from itertools import pairwise from itertools import pairwise
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union from typing import Union
from assembler import lexer
from assembler.lexer import TT
from assembler.keywords import KW, ME, VE
""" """
temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 temp[0].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000 temp[1].xyzw = VE_ADD const[1].xyzw const[1].0000 const[1].0000
temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___ temp[0].x = VE_MAD const[0].x___ temp[1].x___ temp[0].y___
temp[0].x = VE_FRAC temp[0].x___ temp[0].0000 temp[0].0000 temp[0].x = VE_FRC temp[0].x___ temp[0].0000 temp[0].0000
temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___ temp[0].x = VE_MAD temp[0].x___ const[1].z___ const[1].w___
temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000 temp[0].y = ME_COS temp[0].xxxx temp[0].0000 temp[0].0000
temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000 temp[0].x = ME_SIN temp[0].xxxx temp[0].0000 temp[0].0000
@ -54,15 +55,21 @@ def identifier_to_number(token):
raise ParseError("expected number", token) raise ParseError("expected number", token)
return int(bytes(token.lexeme), 10) return int(bytes(token.lexeme), 10)
def we_ord(c):
    """Return the component index for a write-enable character code.

    Args:
        c: integer character code (e.g. an element of a bytes object).

    Returns:
        3 for "w"; otherwise the distance of `c` from "x", so "x", "y" and
        "z" give 0, 1 and 2. Other codes are not validated here.
    """
    # "w" sorts before "x" in ASCII but is the last component, so it is
    # special-cased rather than computed from the offset.
    return 3 if c == ord("w") else c - ord("x")
def parse_dest_write_enable(token): def parse_dest_write_enable(token):
we_chars = set(b"xyzw") we_chars = set(b"xyzw")
assert token.type is TT.identifier assert token.type is TT.identifier
we = bytes(token.lexeme).lower() we = bytes(token.lexeme).lower()
if not all(c in we_chars for c in we): if not all(c in we_chars for c in we):
raise ParseError("expected destination write enable", token) raise ParseError("expected destination write enable", token)
if not all(a < b for a, b in pairwise(we)) or len(set(we)) != len(we): if not all(we_ord(a) < we_ord(b) for a, b in pairwise(we)) or len(set(we)) != len(we):
raise ParseError("misleading non-sequential write enable", token) raise ParseError("misleading non-sequential write enable", token)
return set(c - ord('x') for c in we) return set(we_ord(c) for c in we)
def parse_source_swizzle(token): def parse_source_swizzle(token):
select_mapping = { select_mapping = {
@ -109,7 +116,9 @@ class Parser:
self.tokens = tokens self.tokens = tokens
def peek(self): def peek(self):
return self.tokens[self.current_ix] token = self.tokens[self.current_ix]
#print(token)
return token
def at_end_p(self): def at_end_p(self):
return self.peek().type == TT.eof return self.peek().type == TT.eof
@ -119,8 +128,8 @@ class Parser:
self.current_ix += 1 self.current_ix += 1
return token return token
def match(self, token_type, message): def match(self, token_type):
token = self.advance() token = self.peek()
return token.type == token_type return token.type == token_type
def consume(self, token_type, message): def consume(self, token_type, message):
@ -135,15 +144,6 @@ class Parser:
raise ParseError(message, token) raise ParseError(message, token)
return token return token
"""
def consume_keyword(self, keyword, message):
token = self.consume(TT.keyword, message)
assert token.keyword is not None
if token.keyword != keyword:
raise ParseError(message, token)
"""
def destination_type(self): def destination_type(self):
token = self.consume(TT.keyword, "expected destination type") token = self.consume(TT.keyword, "expected destination type")
destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input} destination_keywords = {KW.temporary, KW.a0, KW.out, KW.out_repl_x, KW.alt_temporary, KW.input}
@ -194,6 +194,8 @@ class Parser:
return Source(source_type, offset, source_swizzle) return Source(source_type, offset, source_swizzle)
def instruction(self): def instruction(self):
while self.match(TT.eol):
self.advance()
destination_op = self.destination_op() destination_op = self.destination_op()
source0 = self.source() source0 = self.source()
source1 = self.source() source1 = self.source()
@ -201,8 +203,12 @@ class Parser:
self.consume_either(TT.eol, TT.eof, "expected newline or EOF") self.consume_either(TT.eol, TT.eof, "expected newline or EOF")
return Instruction(destination_op, source0, source1, source2) return Instruction(destination_op, source0, source1, source2)
def instructions(self):
    """Yield one parsed instruction at a time until the next token is EOF."""
    while True:
        # Stop as soon as the upcoming token is the end-of-file marker.
        if self.match(TT.eof):
            break
        yield self.instruction()
if __name__ == "__main__": if __name__ == "__main__":
from lexer import Lexer from assembler.lexer import Lexer
buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_" buf = b"out[0].xz = VE_MAD input[0].-y-_-0-_ temp[0].x_0_ temp[0].y_0_"
lexer = Lexer(buf) lexer = Lexer(buf)
tokens = list(lexer.lex_tokens()) tokens = list(lexer.lex_tokens())

View File

@ -67,12 +67,15 @@ def low_from_bits(bits):
return bits return bits
def generate_python(prefix, fields): def generate_python(prefix, fields):
#out(0, f"class {prefix}:")
fields = list(fields) fields = list(fields)
for field_name, bits, description in fields: for field_name, bits, description in fields:
#out(1, f"@staticmethod")
out(0, f"def {field_name}(n):") out(0, f"def {field_name}(n):")
out(1, f"return (n >> {low_from_bits(bits)}) & {mask_from_bits(bits)}") out(1, f"return (n >> {low_from_bits(bits)}) & {hex(mask_from_bits(bits))}")
out(0, "")
out(0, f"def {field_name}_gen(n):")
out(1, f"assert ({hex(mask_from_bits(bits))} & n) == n, (n, {hex(mask_from_bits(bits))})")
out(1, f"return n << {low_from_bits(bits)}")
out(0, "") out(0, "")
out(0, "table = [") out(0, "table = [")

View File

@ -27,4 +27,5 @@ while ix < len(lines):
print(f' {value.strip()}: "{key.strip()}",') print(f' {value.strip()}: "{key.strip()}",')
ix += 1 ix += 1
print("}") print("}")
print(f"{name}_gen = dict((v, k) for k, v in {name}.items())")
print() print()

View File

@ -4,6 +4,7 @@ PVS_SRC_REG_TYPE = {
2: "PVS_SRC_REG_CONSTANT", 2: "PVS_SRC_REG_CONSTANT",
3: "PVS_SRC_REG_ALT_TEMPORARY", 3: "PVS_SRC_REG_ALT_TEMPORARY",
} }
PVS_SRC_REG_TYPE_gen = dict((v, k) for k, v in PVS_SRC_REG_TYPE.items())
PVS_SRC_SWIZZLE_SEL = { PVS_SRC_SWIZZLE_SEL = {
0: "PVS_SRC_SELECT_X", 0: "PVS_SRC_SELECT_X",
@ -13,10 +14,12 @@ PVS_SRC_SWIZZLE_SEL = {
4: "PVS_SRC_SELECT_FORCE_0", 4: "PVS_SRC_SELECT_FORCE_0",
5: "PVS_SRC_SELECT_FORCE_1", 5: "PVS_SRC_SELECT_FORCE_1",
} }
PVS_SRC_SWIZZLE_SEL_gen = dict((v, k) for k, v in PVS_SRC_SWIZZLE_SEL.items())
PVS_SRC_ADDR_MODE = { PVS_SRC_ADDR_MODE = {
0: "Absolute addressing", 0: "Absolute addressing",
1: "Relative addressing using A0 register", 1: "Relative addressing using A0 register",
2: "Relative addressing using I0 register", 2: "Relative addressing using I0 register",
} }
PVS_SRC_ADDR_MODE_gen = dict((v, k) for k, v in PVS_SRC_ADDR_MODE.items())