Source code for mrcrowbar.bits

from __future__ import annotations

from typing import Callable

from mrcrowbar.common import BytesReadType, BytesWriteType
from mrcrowbar.encoding import EndianEncoding

BYTE_REVERSE = bytes.fromhex(
    "008040c020a060e0109050d030b070f0"
    "088848c828a868e8189858d838b878f8"
    "048444c424a464e4149454d434b474f4"
    "0c8c4ccc2cac6cec1c9c5cdc3cbc7cfc"
    "028242c222a262e2129252d232b272f2"
    "0a8a4aca2aaa6aea1a9a5ada3aba7afa"
    "068646c626a666e6169656d636b676f6"
    "0e8e4ece2eae6eee1e9e5ede3ebe7efe"
    "018141c121a161e1119151d131b171f1"
    "098949c929a969e9199959d939b979f9"
    "058545c525a565e5159555d535b575f5"
    "0d8d4dcd2dad6ded1d9d5ddd3dbd7dfd"
    "038343c323a363e3139353d333b373f3"
    "0b8b4bcb2bab6beb1b9b5bdb3bbb7bfb"
    "078747c727a767e7179757d737b777f7"
    "0f8f4fcf2faf6fef1f9f5fdf3fbf7fff"
)

BIT_MASK = [(1 << size) - 1 for size in range( 0, 65 )]

mask: Callable[[int], int] = (
    lambda size: BIT_MASK[size] if size in range( 0, 65 ) else (1 << size) - 1
)


[docs]def reverse_bits( number: int, size: int = 8 ) -> int: number &= mask( size ) if size == 8: return BYTE_REVERSE[number] result = 0 width = (size + 7) // 8 shift = (8 - size) % 8 for i in range( 0, width ): fragment = (number >> (i * 8)) & 0xff result |= BYTE_REVERSE[fragment] << ((width - i - 1) * 8) result >>= shift return result
[docs]def read_bits( buffer: BytesReadType, byte_offset: int, bit_offset: int, size: int, bytes_reverse: bool = False, bit_endian: EndianEncoding = "big", io_endian: EndianEncoding = "big", ) -> int: byte_start = byte_offset bit_start = bit_offset bit_diff = (bit_offset + size) if bit_endian == "big" else (7 - bit_offset + size) if bytes_reverse: byte_end = byte_offset - bit_diff // 8 else: byte_end = byte_offset + bit_diff // 8 bit_end = bit_diff % 8 if bit_endian == "big" else 7 - (bit_diff % 8) result = 0 first_byte = buffer[byte_start] middle_bytes = ( range( byte_start + 1, byte_end ) if not bytes_reverse else range( byte_start - 1, byte_end, -1 ) ) end_byte = buffer[byte_end] if byte_end in range( len( buffer ) ) else 0 if bit_endian == "big": # start span_mask = mask( 8 - bit_start ) if byte_start == byte_end: span_mask ^= mask( 8 - bit_end ) result |= first_byte & span_mask if byte_start != byte_end: # middle for i in middle_bytes: result <<= 8 result |= buffer[i] # end span_mask = 0xff ^ mask( 8 - bit_end ) result <<= 8 result |= end_byte & span_mask result >>= 8 - bit_end else: # start span_mask = 0xff ^ mask( 7 - bit_start ) if byte_start == byte_end: span_mask ^= 0xff ^ mask( 7 - bit_end ) result |= first_byte & span_mask result >>= 7 - bit_start if byte_start != byte_end: bit_offset = bit_start + 1 # middle for i in middle_bytes: result |= buffer[i] << bit_offset bit_offset += 8 # end span_mask = mask( 7 - bit_end ) result |= (end_byte & span_mask) << bit_offset if io_endian != bit_endian: result = reverse_bits( result, size ) return result
[docs]def write_bits( value: int, buffer: BytesWriteType, byte_offset: int, bit_offset: int, size: int, bytes_reverse: bool = False, bit_endian: EndianEncoding = "big", io_endian: EndianEncoding = "big", ) -> None: if value not in range( 1 << size ): raise ValueError( f"Value {value} does not fit into {size} bits" ) byte_start = byte_offset bit_start = bit_offset if io_endian != bit_endian: value = reverse_bits( value, size ) bit_diff = (bit_offset + size) if bit_endian == "big" else (7 - bit_offset + size) if bytes_reverse: byte_end = byte_offset - bit_diff // 8 else: byte_end = byte_offset + bit_diff // 8 bit_end = bit_diff % 8 if bit_endian == "big" else 7 - (bit_diff % 8) middle_bytes = ( range( byte_start + 1, byte_end ) if not bytes_reverse else range( byte_start - 1, byte_end, -1 ) ) if bit_endian == "big": # start span_mask = mask( 8 - bit_start ) if byte_start == byte_end: span_mask ^= mask( 8 - bit_end ) start_value = value << 8 - bit_end else: start_value = value >> size - (8 - bit_start) buffer[byte_start] = (0xff ^ span_mask) & buffer[byte_start] | ( start_value & span_mask ) if byte_start != byte_end: # middle for i, x in enumerate( middle_bytes ): buffer[x] = ( value >> ((len( middle_bytes ) - i - 1) * 8 + bit_end) ) & 0xff # end end_value = value << (8 - bit_end) span_mask = 0xff ^ mask( 8 - bit_end ) if span_mask: buffer[byte_end] = (0xff ^ span_mask) & buffer[byte_end] | ( end_value & span_mask ) else: # start span_mask = 0xff ^ mask( 7 - bit_start ) if byte_start == byte_end: span_mask ^= 0xff ^ mask( 7 - bit_end ) start_value = value << 7 - bit_start buffer[byte_start] = (0xff ^ span_mask) & buffer[byte_start] | ( start_value & span_mask ) if byte_start != byte_end: bit_offset = bit_start + 1 # middle for i, x in enumerate( middle_bytes ): buffer[x] = (value >> bit_offset) & 0xff bit_offset += 8 # end span_mask = mask( 7 - bit_end ) end_value = value >> bit_offset if span_mask: buffer[byte_end] = (0xff ^ span_mask) & buffer[byte_end] | ( end_value & span_mask ) return
[docs]def reverse_bytes( buffer: BytesReadType ) -> bytes: output = [reverse_bits( x ) for x in buffer] output.reverse() return bytes( output )
[docs]def unpack_bits( byte: int ) -> int: """Expand a bitfield into a 64-bit int (8 bool bytes).""" longbits = byte & (0x00000000000000ff) longbits = (longbits | (longbits << 28)) & (0x0000000f0000000f) longbits = (longbits | (longbits << 14)) & (0x0003000300030003) longbits = (longbits | (longbits << 7)) & (0x0101010101010101) return longbits
[docs]def pack_bits( longbits: int ) -> int: """Crunch a 64-bit int (8 bool bytes) into a bitfield.""" byte = longbits & (0x0101010101010101) byte = (byte | (byte >> 7)) & (0x0003000300030003) byte = (byte | (byte >> 14)) & (0x0000000f0000000f) byte = (byte | (byte >> 28)) & (0x00000000000000ff) return byte
[docs]class BitStream: buffer: bytearray byte_pos: int bit_pos: int bit_endian: EndianEncoding io_endian: EndianEncoding def __init__( self, buffer: BytesReadType | None = None, start_offset: int | tuple[int, int] | None = None, bytes_reverse: bool = False, bit_endian: EndianEncoding = "big", io_endian: EndianEncoding = "big", ) -> None: """Create a BitStream instance. buffer Target byte array to read/write from. Defaults to an empty array. start_offset Position in the target to start reading from. Can be an integer byte offset, or a tuple containing the byte and bit offsets. Defaults to the start of the stream, depending on the endianness and ordering options. bytes_reverse If enabled, fetch successive bytes from the source in reverse order. bit_endian Endianness of the backing storage; either 'big' or 'little'. Defaults to big (i.e. starting from the most-significant bit (0x80) through least-significant bit (0x10)). io_endian Endianness of data returned from read/write; either 'big' or 'little'. Defaults to big (i.e. starting from the most-significant bit (0x80) through least-significant bit (0x10)). """ if buffer is None: self.buffer = bytearray() else: self.buffer = bytearray( buffer ) self.bytes_reverse = bytes_reverse if bit_endian not in ("big", "little"): raise TypeError( "bit_endian should be either 'big' or 'little'" ) self.bit_endian = bit_endian if io_endian not in ("big", "little"): raise TypeError( "io_endian should be either 'big' or 'little'" ) self.io_endian = io_endian if start_offset is None: self.byte_pos = len( self.buffer ) - 1 if bytes_reverse else 0 self.bit_pos = 0 if bit_endian == "big" else 7 elif isinstance( start_offset, int ): self.byte_pos = start_offset self.bit_pos = 0 if bit_endian == "big" else 7 elif isinstance( start_offset, tuple ): self.byte_pos, self.bit_pos = start_offset else: raise TypeError( "start_offset should be of type int or tuple" )
[docs] def tell( self ) -> tuple[int, int]: """Get the current byte and bit position.""" return self.byte_pos, self.bit_pos
[docs] def read( self, count: int ) -> int: """Get an integer containing the next [count] bits from the source.""" """ x.read( 3 ) # 0bABC x.read( 3 ) # 0bDEF x.read( 3 ) # 0bGHI x.read( 3 ) # 0bJKL # default: # ABCDEFGH IJKLxxxx # bit_endian == 'little' # HGFEDCBA xxxxLKJI # bytes_reverse == True: # IJKLxxxx ABCDEFGH # io_endian == 'little': # CBAFEDIH GLKJxxxx """ result = read_bits( buffer=self.buffer, byte_offset=self.byte_pos, bit_offset=self.bit_pos, size=count, bytes_reverse=self.bytes_reverse, bit_endian=self.bit_endian, io_endian=self.io_endian, ) self.seek( (count // 8, count % 8), origin="current" ) return result
[docs] def write( self, value: int, count: int ) -> None: """Write an unsigned integer containing [count] bits to the source.""" """ x.write( 0bABC, 3 ) x.write( 0bDEF, 3 ) x.write( 0bGHI, 3 ) x.write( 0bJKL, 3 ) # default: # ABCDEFGH IJKLxxxx # bit_endian == 'little' # HGFEDCBA xxxxLKJI # bytes_reverse == True: # IJKLxxxx ABCDEFGH # io_endian == 'little': # CBAFEDIH GLKJxxxx """ bit_diff = ( (self.bit_pos + count - 1) if self.bit_endian == "big" else (7 - self.bit_pos + count - 1) ) new_byte_pos = self.byte_pos if self.bytes_reverse: new_byte_pos -= bit_diff // 8 else: new_byte_pos += bit_diff // 8 if new_byte_pos < 0: byte_count = -new_byte_pos self.buffer = bytearray( b"\x00" * byte_count ) + self.buffer self.byte_pos += byte_count elif new_byte_pos >= len( self.buffer ): byte_count = new_byte_pos - len( self.buffer ) + 1 self.buffer = self.buffer + bytearray( b"\x00" * byte_count ) write_bits( value=value, buffer=self.buffer, byte_offset=self.byte_pos, bit_offset=self.bit_pos, size=count, bytes_reverse=self.bytes_reverse, bit_endian=self.bit_endian, io_endian=self.io_endian, ) self.seek( (count // 8, count % 8), origin="current" )
[docs] def seek( self, offset: int | tuple[int, int], origin: str = "start" ) -> None: """Seek to a location in the target. offset Relative offset in the target to move to. Can be an integer byte offset, or a tuple containing the byte and bit offsets. origin Position to measure the offset from. Can be either "start", "current" or "end". Defaults to "start". """ count: int = 0 if isinstance( offset, int ): count = offset * 8 elif isinstance( offset, tuple ): count = offset[0] * 8 + offset[1] else: raise TypeError( "offset should be of type int or tuple" ) if origin not in ("start", "current", "end"): raise TypeError( 'origin should be one of "start", "current" or "end"' ) if origin in ("start", "end"): if (origin == "start") ^ (self.bytes_reverse == False): self.byte_pos = len( self.buffer ) self.bit_pos = 0 else: self.byte_pos = 0 self.bit_pos = 0 bit_diff = ( (self.bit_pos + count) if self.bit_endian == "big" else (7 - self.bit_pos + count) ) if self.bytes_reverse: self.byte_pos -= bit_diff // 8 else: self.byte_pos += bit_diff // 8 self.bit_pos = bit_diff % 8 if self.bit_endian == "big" else 7 - (bit_diff % 8)
[docs] def in_bounds( self ) -> bool: """Returns True if the current position is within the bounds of the target.""" return self.byte_pos in range( len( self.buffer ) )
[docs] def get_buffer( self ) -> bytes: """Return a byte string containing the target.""" return bytes( self.buffer )