Tests passing

This commit is contained in:
Luke Murphy 2020-08-05 07:58:19 +02:00
parent f466714406
commit 2ea20165f4
No known key found for this signature in database
GPG Key ID: 5E2EF5A63E3718CC
2 changed files with 19 additions and 20 deletions

View File

@ -73,8 +73,8 @@ class SimpleMessageChannel:
else: else:
self.message = data[offset : offset + self.length] self.message = data[offset : offset + self.length]
self._next_state()
offset += self.length offset += self.length
self._next_state(data, offset)
return offset return offset
self.message += data self.message += data
@ -93,22 +93,19 @@ class SimpleMessageChannel:
self.consumed += 1 self.consumed += 1
if data[offset] < 128: if data[offset] < 128:
self._next_state(data, offset + 1) self._next_state()
offset += 1
return offset return offset
offset += 1 self.factor *= 128
if self.consumed >= 8: if self.consumed >= 8:
raise RuntimeError("Incoming varint is invalid") raise RuntimeError("Incoming varint is invalid")
return len(data) return len(data)
def _next_state(self, data: bytes, offset: int) -> bool: def _next_state(self) -> None:
"""Calculate the next state. """Calculate the next state."""
:param data: the message data
:param offset: the bytes offset
"""
if self.state == 0: if self.state == 0:
self.state = 1 self.state = 1
self.factor = 1 self.factor = 1
@ -126,9 +123,11 @@ class SimpleMessageChannel:
self.varint = 0 self.varint = 0
if self.length < 0 or self.length > self.max_size: if self.length < 0 or self.length > self.max_size:
raise RuntimeError("Incoming message too large") raise RuntimeError("Incoming message too large")
else: elif self.state == 2:
self.state = 0 self.state = 0
self.messages.append( channel = self.header >> 4
(self.header >> 4, self.header & 0b1111, self.message) type = self.header & 0b1111
) self.messages.append((channel, type, self.message))
self.message = b"" self.message = b""
else:
raise RuntimeError(f"Unknown state {self.state}")

View File

@ -7,7 +7,7 @@ def test_smc_recv(smc1, smc2):
smc2.recv(payload) smc2.recv(payload)
assert len(smc2.messages) == 1 assert len(smc2.messages) == 1
assert smc2.messages[0] == [(0, 1, b"foo")] assert smc2.messages == [(0, 1, b"foo")]
def test_smc_recv_multiple(smc1, smc2): def test_smc_recv_multiple(smc1, smc2):
@ -15,7 +15,7 @@ def test_smc_recv_multiple(smc1, smc2):
smc2.recv(payload) smc2.recv(payload)
assert len(smc2.messages) == 3 assert len(smc2.messages) == 3
assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")] assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]
def test_smc_recv_empty(smc1, smc2): def test_smc_recv_empty(smc1, smc2):
@ -23,24 +23,24 @@ def test_smc_recv_empty(smc1, smc2):
smc2.recv(payload) smc2.recv(payload)
assert len(smc2.messages) == 1 assert len(smc2.messages) == 1
assert smc2.messages[0] == [(0, 1, b"")] assert smc2.messages == [(0, 1, b"")]
def test_smc_recv_chunked(smc1, smc2): def test_smc_recv_chunked(smc1, smc2):
payload = smc1.send(0, 1, b"foo") payload = smc1.send(0, 1, b"foo")
for idx in range(0, len(payload)): for idx in range(0, len(payload)):
smc2.recv(payload[idx, idx + 1]) smc2.recv(payload[idx : idx + 1])
assert len(smc2.messages) == 1 assert len(smc2.messages) == 1
assert smc2.messages[0] == [(0, 1, b"foo")] assert smc2.messages == [(0, 1, b"foo")]
def test_smc_recv_chunked_multiple(smc1, smc2): def test_smc_recv_chunked_multiple(smc1, smc2):
payload = smc1.send_batch([(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]) payload = smc1.send_batch([(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")])
for idx in range(0, len(payload)): for idx in range(0, len(payload)):
smc2.recv(payload[idx, idx + 1]) smc2.recv(payload[idx : idx + 1])
assert len(smc2.messages) == 3 assert len(smc2.messages) == 3
assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")] assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]