From 2ea20165f41bdb0f7f9134674d73129f8d1c426b Mon Sep 17 00:00:00 2001 From: Luke Murphy Date: Wed, 5 Aug 2020 07:58:19 +0200 Subject: [PATCH] Tests passing --- simple_message_channels/smc.py | 25 ++++++++++++------------- test/test_smc.py | 14 +++++++------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/simple_message_channels/smc.py b/simple_message_channels/smc.py index f02cf6d..06a2e28 100644 --- a/simple_message_channels/smc.py +++ b/simple_message_channels/smc.py @@ -73,8 +73,8 @@ class SimpleMessageChannel: else: self.message = data[offset : offset + self.length] + self._next_state() offset += self.length - self._next_state(data, offset) return offset self.message += data @@ -93,22 +93,19 @@ class SimpleMessageChannel: self.consumed += 1 if data[offset] < 128: - self._next_state(data, offset + 1) + self._next_state() + offset += 1 return offset - offset += 1 + self.factor *= 128 if self.consumed >= 8: raise RuntimeError("Incoming varint is invalid") return len(data) - def _next_state(self, data: bytes, offset: int) -> bool: - """Calculate the next state. - - :param data: the message data - :param offset: the bytes offset - """ + def _next_state(self) -> None: + """Calculate the next state.""" if self.state == 0: self.state = 1 self.factor = 1 @@ -126,9 +123,11 @@ class SimpleMessageChannel: self.varint = 0 if self.length < 0 or self.length > self.max_size: raise RuntimeError("Incoming message too large") - else: + elif self.state == 2: self.state = 0 - self.messages.append( - (self.header >> 4, self.header & 0b1111, self.message) - ) + channel = self.header >> 4 + type = self.header & 0b1111 + self.messages.append((channel, type, self.message)) self.message = b"" + else: + raise RuntimeError(f"Unknown state {self.state}") diff --git a/test/test_smc.py b/test/test_smc.py index 402f859..262d006 100644 --- a/test/test_smc.py +++ b/test/test_smc.py @@ -7,7 +7,7 @@ def test_smc_recv(smc1, smc2): smc2.recv(payload) assert len(smc2.messages) == 1 - assert smc2.messages[0] == [(0, 1, b"foo")] + assert smc2.messages == [(0, 1, b"foo")] def test_smc_recv_multiple(smc1, smc2): @@ -15,7 +15,7 @@ def test_smc_recv_multiple(smc1, smc2): smc2.recv(payload) assert len(smc2.messages) == 3 - assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")] + assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")] def test_smc_recv_empty(smc1, smc2): @@ -23,24 +23,24 @@ def test_smc_recv_empty(smc1, smc2): smc2.recv(payload) assert len(smc2.messages) == 1 - assert smc2.messages[0] == [(0, 1, b"")] + assert smc2.messages == [(0, 1, b"")] def test_smc_recv_chunked(smc1, smc2): payload = smc1.send(0, 1, b"foo") for idx in range(0, len(payload)): - smc2.recv(payload[idx, idx + 1]) + smc2.recv(payload[idx : idx + 1]) assert len(smc2.messages) == 1 - assert smc2.messages[0] == [(0, 1, b"foo")] + assert smc2.messages == [(0, 1, b"foo")] def test_smc_recv_chunked_multiple(smc1, smc2): payload = smc1.send_batch([(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]) for idx in range(0, len(payload)): - smc2.recv(payload[idx, idx + 1]) + smc2.recv(payload[idx : idx + 1]) assert len(smc2.messages) == 3 - assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")] + assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]