from enum import Enum, auto
from dataclasses import dataclass


def __bit_array(c):
    """Yield the eight bits of the byte value *c*, most significant first.

    Raises:
        AssertionError: if *c* is outside ``0..255``.
    """
    assert c <= 255 and c >= 0, c
    for i in range(7, -1, -1):
        yield (c >> i) & 1


def _bit_array(b):
    """Yield the bits of every byte value in *b*, MSB first within each byte."""
    for c in b:
        yield from __bit_array(c)


def bit_array(b):
    """Return the bits of the byte sequence *b* as a flat list of 0/1 ints."""
    return list(_bit_array(b))


def data_bit_pred(pn, n):
    """Return True if bit number *pn* (1-based, 1 = least significant) of *n* is set."""
    return (n & (1 << (pn - 1))) != 0


class ParityResult(Enum):
    """Outcome reported by ``HammingCode.check_parity``."""

    zero_bit = auto()
    one_bit = auto()
    two_bit = auto()

    def __repr__(self):
        cls_name = self.__class__.__name__
        return f'{cls_name}.{self.name}'


@dataclass
class HammingCode:
    """Extended Hamming code with ``bits`` Hamming parity bits.

    A codeword ("pba") is a list of ``block_length + 1`` bits:

    * index 0 holds the extended (overall) parity of indices 1..block_length;
    * indices that are powers of two (1, 2, 4, ...) hold the Hamming parity
      bits;
    * every other index holds a message bit, in order.
    """

    # Number of Hamming parity bits (the extended parity bit is extra).
    bits: int

    @property
    def block_length(self):
        """Number of codeword positions excluding the extended parity bit."""
        return 2 ** self.bits - 1

    @property
    def message_length(self):
        """Number of message bits carried by one codeword."""
        return 2 ** self.bits - self.bits - 1

    def all_data_bits(self):
        """Yield, in ascending order, the positions that carry message bits.

        These are the indices in ``1..block_length`` that are not powers of
        two.
        """
        power = 0
        for bit_ix in range(1, self.block_length + 1):
            if bit_ix == 2 ** power:
                # Power-of-two positions are reserved for parity bits.
                power += 1
            else:
                yield bit_ix

    def interleave_parity(self, ba):
        """Yield the codeword layout for message bits *ba*.

        Parity slots (index 0 and every power-of-two index) are yielded as
        ``None`` placeholders; the remaining slots receive the bits of *ba*
        in order.

        Raises:
            AssertionError: on first iteration, if ``len(ba)`` differs from
                ``message_length``.
        """
        assert len(ba) == self.message_length
        power = 0
        index = 0
        yield None  # extended parity
        for bit_ix in range(1, self.block_length + 1):
            if bit_ix == 2 ** power:
                yield None
                power += 1
            else:
                yield ba[index]
                index += 1

    def parity_data_bits(self, power):
        """Yield every position in ``1..block_length`` with bit *power* set.

        This is the full coverage group of parity bit ``2 ** power`` and
        includes that parity position itself.
        """
        mask = 2 ** power
        for bit_ix in range(1, self.block_length + 1):
            if bit_ix & mask:
                yield bit_ix

    def parity(self, pba, power):
        """Return the parity (0 or 1) of the group for *power*.

        The parity position ``2 ** power`` itself is left out of the sum, so
        the result is the value that position is expected to hold.
        """
        parity_ix = 2 ** power
        p = sum(
            pba[bit_ix]
            for bit_ix in self.parity_data_bits(power)
            if bit_ix != parity_ix
        )
        return p % 2

    def extended_parity(self, pba):
        """Return the parity (0 or 1) of positions ``1..block_length`` of *pba*."""
        p = sum(
            pba[bit_ix]
            for bit_ix in range(1, self.block_length + 1)
        )
        return p % 2

    def encode_bitarray(self, ba):
        """Encode the message bits *ba* and return the full codeword list.

        Raises:
            AssertionError: if ``len(ba)`` differs from ``message_length``.
        """
        pba = list(self.interleave_parity(ba))
        assert len(pba) == self.block_length + 1
        for power in range(self.bits):
            parity_ix = 2 ** power
            assert pba[parity_ix] is None
            pba[parity_ix] = self.parity(pba, power)
        # Must come last: the overall parity covers the Hamming parity bits.
        pba[0] = self.extended_parity(pba)
        return pba

    def encode_block(self, b):
        """Encode the byte sequence *b* (expanded MSB first) into a codeword.

        Raises:
            AssertionError: if ``len(b) * 8`` differs from ``message_length``.
        """
        assert len(b) * 8 == self.message_length, (len(b), self.message_length)
        ba = bit_array(b)
        return self.encode_bitarray(ba)

    def _check_parity(self, pba):
        """Yield, per Hamming parity bit, whether the stored value matches."""
        for power in range(self.bits):
            parity_ix = 2 ** power
            yield self.parity(pba, power) == pba[parity_ix]

    def check_parity(self, pba):
        """Classify the error state of the codeword *pba*.

        Returns:
            A ``(ParityResult, positions)`` tuple:

            * ``zero_bit`` with an empty set: every check passed.
            * ``one_bit`` with a one-element set: the index of the single
              flipped bit, which may be a message bit, a Hamming parity bit,
              or index 0 (the extended parity bit).
            * ``two_bit`` with the positions consistent with the failing
              Hamming checks: the overall parity still matches although a
              Hamming check failed, so an even number of bits flipped and
              the set does not identify them.
        """
        # Consider every position 1..block_length, not only message bits:
        # parity bits travel with the codeword and can be corrupted too.
        candidates = set(range(1, self.block_length + 1))
        failed = 0
        for power, match in enumerate(self._check_parity(pba)):
            group = set(self.parity_data_bits(power))
            if match:
                # A passing check clears every position it covers.
                candidates -= group
            else:
                failed += 1
                # A failing check means the culprit is inside its group.
                candidates &= group

        overall_ok = self.extended_parity(pba) == pba[0]

        if failed == 0:
            assert len(candidates) == 0, candidates
            if overall_ok:
                return ParityResult.zero_bit, candidates
            # All Hamming checks pass but the overall parity does not:
            # the extended parity bit itself was flipped.
            return ParityResult.one_bit, {0}

        if overall_ok:
            return ParityResult.two_bit, candidates

        assert len(candidates) == 1, candidates
        return ParityResult.one_bit, candidates


# Demo: encode an 11-bit message with a (15, 11) code plus extended parity
# and print the 16-bit codeword as a 4x4 grid.
hc = HammingCode(bits=4)
pba = hc.encode_bitarray([
    1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1
])
for y in range(4):
    for x in range(4):
        bit = pba[y * 4 + x]
        print(bit, end=' ')
    print()

# Hand-written codeword (identical to the one encoded above) run through
# the checker.
pba = [
    1, 1, 0, 1,
    0, 1, 0, 0,
    1, 1, 0, 1,
    1, 0, 1, 1,
]
print(hc.check_parity(pba))