simple-message-channels/simple_message_channels/smc.py

134 lines
3.8 KiB
Python

"""Sans I/O wire protocol for Hypercore"""
from typing import List, Tuple
import attr
from pyvarint import encode, encoding_length
__all__ = ["SimpleMessageChannel"]
@attr.s(auto_attribs=True)
class SimpleMessageChannel:
"""A simple message channel."""
message: bytes = attr.Factory(bytes)
messages: List[Tuple[int, int, bytes]] = attr.Factory(list)
varint: int = 0
factor: int = 1
length: int = 0
header: int = 0
state: int = 0
consumed = 0
max_size = 8 * 1024 * 1024
def send(self, channel: int, type: int, message: bytes) -> bytes:
"""Encode a channel, type and message data to be sent.
:param channel: the message channel identifier
:param type: the type of message
:param message: the message data
"""
header = channel << 4 or type
length = len(message) + encoding_length(header)
return encode(length) + encode(header) + message
def send_batch(self, messages: List[Tuple[int, int, bytes]]) -> bytes:
"""Encodes a series of messages into a single payload of bytes.
:param messages: Several data messages
"""
payload = b""
for (channel, type, message) in messages:
payload += self.send(channel, type, message)
return payload
def recv(self, data: bytes) -> None:
"""Encode a channel, type, message to be sent.
:param data: the message data
"""
offset = 0
while offset < len(data):
if self.state == 2:
offset = self._read_msg(data, offset)
else:
offset = self._read_varint(data, offset)
if self.state == 2 and self.length == 0:
self._read_msg(data, offset)
def _read_msg(self, data: bytes, offset: int) -> int:
"""Read a message chunk.
:param data: the message data
:param offset: the bytes offset
"""
free = len(data) - offset
if free >= self.length:
if self.message:
self.message += data
else:
self.message = data[offset : offset + self.length]
self._next_state()
offset += self.length
return offset
self.message += data
self.length -= free
return len(data)
def _read_varint(self, data: bytes, offset: int) -> int:
"""Read a varint from a byte stream.
:param data: the message data
:param offset: the bytes offset
"""
while offset < len(data):
self.varint += (data[offset] & 127) * self.factor
self.consumed += 1
if data[offset] < 128:
self._next_state()
offset += 1
return offset
self.factor *= 128
if self.consumed >= 8:
raise RuntimeError("Incoming varint is invalid")
return len(data)
def _next_state(self) -> None:
"""Calculate the next state."""
if self.state == 0:
self.state = 1
self.factor = 1
self.length = self.varint
self.consumed = 0
self.varint = 0
if not self.length:
self.varint = 0
elif self.state == 1:
self.state = 2
self.factor = 1
self.header = self.varint
self.length -= self.consumed
self.consumed = 0
self.varint = 0
if self.length < 0 or self.length > self.max_size:
raise RuntimeError("Incoming message too large")
elif self.state == 2:
self.state = 0
channel = self.header >> 4
type = self.header & 0b1111
self.messages.append((channel, type, self.message))
self.message = b""
else:
raise RuntimeError(f"Unknown state {self.state}")