diff --git a/simple_message_channels/smc.py b/simple_message_channels/smc.py index 1063e2b..89f90c0 100644 --- a/simple_message_channels/smc.py +++ b/simple_message_channels/smc.py @@ -11,6 +11,7 @@ __all__ = ["SimpleMessageChannel"] class SimpleMessageChannel: """A simple message channel.""" + message: bytes = attr.Factory(bytes) messages: List[Tuple[int, int, bytes]] = attr.Factory(list) varint: int = 0 @@ -60,7 +61,22 @@ class SimpleMessageChannel: def _read_msg(self, data: bytes, offset: int) -> int: """TODO.""" - pass + free = len(data) - offset + + if free >= self.length: + if self.message: + self.message += data + else: + self.message = data[offset : offset + self.length] + + offset += self.length + self._next_state(data, offset) + return offset + + self.message += data + self.length -= free + + return len(data) def _read_varint(self, data: bytes, offset: int) -> int: """TODO.""" @@ -69,10 +85,8 @@ class SimpleMessageChannel: self.consumed += 1 if data[offset] < 128: - state = self._next_state(data, offset + 1) - if state: - return offset - return len(data) + self._next_state(data, offset + 1) + return offset offset += 1 @@ -83,4 +97,23 @@ class SimpleMessageChannel: def _next_state(self, data: bytes, offset: int) -> bool: """TODO.""" - pass + if self.state == 0: + self.state = 1 + self.factor = 1 + self.length = self.varint + self.consumed = 0 + self.varint = 0 + if not self.length: + self.varint = 0 + elif self.state == 1: + self.state = 2 + self.factor = 1 + self.header = self.varint + self.length -= self.consumed + self.consumed = 0 + self.varint = 0 + if self.length < 0 or self.length > self.max_size: + raise RuntimeError("Incoming message too large") + else: + self.state = 0 + self.message = b""