simple-message-channels/simple_message_channels/smc.py

120 lines
3.4 KiB
Python
Raw Normal View History

"""Sans I/O wire protocol for Hypercore"""
2020-08-03 20:23:24 +00:00
from typing import List, Optional, Tuple
import attr
from pyvarint import encode, encoding_length
2020-07-07 22:20:39 +00:00
# Public API of this module: the only name exported by a star-import.
__all__ = ["SimpleMessageChannel"]
@attr.s(auto_attribs=True)
class SimpleMessageChannel:
    """Sans I/O encoder and incremental decoder for channel-tagged messages.

    One frame on the wire is laid out as::

        varint(length) | varint(header) | payload

    where ``header = (channel << 4) | type`` and ``length`` is the size of
    the encoded header plus the size of the payload.

    Decoding is a small state machine driven by :meth:`recv`:

    * ``state == 0`` -- reading the length varint
    * ``state == 1`` -- reading the header varint
    * ``state == 2`` -- reading the payload bytes

    Every completed frame is appended to ``messages`` as a
    ``(channel, type, payload)`` tuple.  The parser keeps its position
    between calls, so a frame may be split across any number of ``recv``
    calls.
    """

    # Payload bytes collected so far for the frame currently being read.
    message: bytes = attr.Factory(bytes)
    # Fully decoded frames, in arrival order: (channel, type, payload).
    messages: List[Tuple[int, int, bytes]] = attr.Factory(list)

    # Value of the varint currently being decoded.
    varint: int = 0
    # Multiplier applied to the next 7-bit varint group (1, 128, 128**2, ...).
    factor: int = 1
    # In state 1: header + payload size; in state 2: payload bytes still missing.
    length: int = 0
    # Decoded header of the current frame: (channel << 4) | type.
    header: int = 0
    # Parser state (0 = length, 1 = header, 2 = payload).
    state: int = 0
    # Number of bytes consumed by the varint currently being decoded.
    # NOTE: not annotated, so with ``auto_attribs=True`` this is a plain class
    # attribute, not an attrs field; the first ``self.consumed += 1`` shadows
    # it with a per-instance value.
    consumed = 0
    # Largest accepted payload size in bytes (class attribute, 8 MiB).
    max_size = 8 * 1024 * 1024

    def send(self, channel: int, type: int, message: bytes) -> bytes:
        """Encode a channel, type and message data to be sent.

        The type is stored in the low 4 bits of the header, so it is
        expected to be in the range 0-15; larger values would overlap the
        channel bits.

        :param channel: the message channel identifier
        :param type: the type of message
        :param message: the message data
        :returns: the encoded frame
        """
        # Bitwise OR packs both values into one integer; a boolean ``or``
        # would silently drop ``type`` whenever ``channel`` is non-zero.
        header = (channel << 4) | type
        length = len(message) + encoding_length(header)
        return encode(length) + encode(header) + message

    def send_batch(self, messages: List[Tuple[int, int, bytes]]) -> bytes:
        """Encode a series of messages into a single payload of bytes.

        :param messages: ``(channel, type, message)`` tuples to encode
        :returns: the encoded frames, concatenated in order
        """
        return b"".join(
            self.send(channel, msg_type, message)
            for channel, msg_type, message in messages
        )

    def recv(self, data: bytes) -> None:
        """Feed received bytes into the decoder.

        Completed frames are appended to ``messages``; incomplete trailing
        data is remembered and continued on the next call.

        :param data: the bytes that were received
        :raises RuntimeError: if a varint is invalid or a frame is too large
        """
        offset = 0
        while offset < len(data):
            if self.state == 2:
                offset = self._read_msg(data, offset)
            else:
                offset = self._read_varint(data, offset)

        # A frame with an empty payload whose header ended exactly at the end
        # of ``data`` never re-enters the loop above; complete it here.
        if self.state == 2 and self.length == 0:
            self._read_msg(data, offset)

    def _read_msg(self, data: bytes, offset: int) -> int:
        """Consume payload bytes of the current frame starting at ``offset``.

        :param data: the chunk being parsed
        :param offset: index of the first unread byte in ``data``
        :returns: index of the first byte after what was consumed
        """
        available = len(data) - offset

        if available >= self.length:
            # The rest of the payload is in this chunk: take exactly the
            # missing bytes and leave what follows for the next frame.
            end = offset + self.length
            self.message += data[offset:end]
            self._next_state(data, end)
            return end

        # Only part of the payload has arrived; keep it and wait for more.
        self.message += data[offset:]
        self.length -= available
        return len(data)

    def _read_varint(self, data: bytes, offset: int) -> int:
        """Decode (part of) a base-128 varint starting at ``offset``.

        :param data: the chunk being parsed
        :param offset: index of the first unread byte in ``data``
        :returns: index of the first byte after what was consumed
        :raises RuntimeError: if the chunk ends inside a varint that already
            spans 8 or more bytes
        """
        while offset < len(data):
            byte = data[offset]
            offset += 1
            self.varint += (byte & 127) * self.factor
            self.consumed += 1
            if byte < 128:
                # High bit clear: this was the last group of the varint.
                self._next_state(data, offset)
                return offset
            # Each further group is worth 7 more bits.
            self.factor *= 128

        if self.consumed >= 8:
            raise RuntimeError("Incoming varint is invalid")
        return len(data)

    def _next_state(self, data: bytes, offset: int) -> bool:
        """Advance the parser after a varint or a payload has been completed.

        :param data: the chunk being parsed (currently unused)
        :param offset: index of the next unread byte (currently unused)
        :returns: always ``True``; failures are reported by raising
        :raises RuntimeError: if the payload size is negative or exceeds
            ``max_size``
        """
        if self.state == 0:
            # Length varint finished; the header comes next.
            self.state = 1
            self.factor = 1
            self.length = self.varint
            self.consumed = 0
            self.varint = 0
            if not self.length:
                # Zero-length frame: no header follows, so stay on "length".
                self.state = 0
        elif self.state == 1:
            # Header varint finished; what remains of ``length`` is payload.
            self.state = 2
            self.factor = 1
            self.header = self.varint
            self.length -= self.consumed
            self.consumed = 0
            self.varint = 0
            if self.length < 0 or self.length > self.max_size:
                raise RuntimeError("Incoming message too large")
        else:
            # Payload finished: publish the frame and reset for the next one.
            self.state = 0
            self.messages.append(
                (self.header >> 4, self.header & 0b1111, self.message)
            )
            self.message = b""
        return True