This commit is contained in:
Zack Buhman 2023-10-24 22:57:30 +00:00
commit d35acf99ff

144
hamming.py Normal file
View File

@ -0,0 +1,144 @@
from enum import Enum, auto
from dataclasses import dataclass
def __bit_array(c):
assert c <= 255 and c >= 0, c
for i in range(7, -1, -1):
yield (c >> i) & 1
def _bit_array(b):
for c in b:
yield from __bit_array(c)
def bit_array(b):
return list(_bit_array(b))
def data_bit_pred(pn, n):
return (n & (1 << (pn - 1))) != 0
class ParityResult(Enum):
zero_bit = auto()
one_bit = auto()
two_bit = auto()
def __repr__(self):
cls_name = self.__class__.__name__
return f'{cls_name}.{self.name}'
@dataclass
class HammingCode:
bits: int
@property
def block_length(self):
return 2 ** self.bits - 1
@property
def message_length(self):
return 2 ** self.bits - self.bits - 1
def all_data_bits(self):
power = 0
for bit_ix in range(1, self.block_length + 1):
if bit_ix == 2 ** power:
power += 1
else:
yield bit_ix
def interleave_parity(self, ba):
assert len(ba) == self.message_length
power = 0
index = 0
yield None # extended parity
for bit_ix in range(1, self.block_length + 1):
if bit_ix == 2 ** power:
yield None
power += 1
else:
yield ba[index]
index += 1
def parity_data_bits(self, power):
mask = 2 ** power
for bit_ix in range(1, self.block_length + 1):
if bit_ix & mask:
yield bit_ix
def parity(self, pba, power):
p = sum(
pba[bit_ix]
for bit_ix in self.parity_data_bits(power)
if bit_ix != 2 ** power
)
return p % 2
def extended_parity(self, pba):
p = sum(
pba[bit_ix]
for bit_ix in range(1, self.block_length + 1)
)
return p % 2
def encode_bitarray(self, ba):
pba = list(self.interleave_parity(ba))
assert len(pba) == self.block_length + 1
for power in range(self.bits):
parity_ix = 2 ** power
assert pba[parity_ix] is None
pba[parity_ix] = self.parity(pba, power)
pba[0] = self.extended_parity(pba)
return pba
def encode_block(self, b):
assert len(b) * 8 == self.message_length, (len(b), self.message_length)
ba = bit_array(b)
return self.encode_bitarray(ba)
def _check_parity(self, pba):
for power in range(self.bits):
parity_ix = 2 ** power
yield self.parity(pba, power) == pba[parity_ix]
def check_parity(self, pba):
error_bits = set(self.all_data_bits())
errors = 0
for power, match in enumerate(self._check_parity(pba)):
data_bits = set(self.parity_data_bits(power))
if match:
error_bits -= data_bits
else:
errors += 1
error_bits &= data_bits
if errors > 0:
if self.extended_parity(pba) == pba[0]:
return ParityResult.two_bit, error_bits
else:
assert len(error_bits) == 1, error_bits
return ParityResult.one_bit, error_bits
else:
assert len(error_bits) == 0, error_bits
return ParityResult.zero_bit, error_bits
hc = HammingCode(bits=4)
pba = hc.encode_bitarray([
1,
1, 0, 0,
1, 0, 1,
1, 0, 1, 1
])
for y in range(4):
for x in range(4):
bit = pba[y * 4 + x]
print(bit, end=' ')
print()
pba = [
1, 1, 0, 1,
0, 1, 0, 0,
1, 1, 0, 1,
1, 0, 1, 1,
]
print(hc.check_parity(pba))