Tests passing
This commit is contained in:
parent
f466714406
commit
2ea20165f4
@ -73,8 +73,8 @@ class SimpleMessageChannel:
|
|||||||
else:
|
else:
|
||||||
self.message = data[offset : offset + self.length]
|
self.message = data[offset : offset + self.length]
|
||||||
|
|
||||||
|
self._next_state()
|
||||||
offset += self.length
|
offset += self.length
|
||||||
self._next_state(data, offset)
|
|
||||||
return offset
|
return offset
|
||||||
|
|
||||||
self.message += data
|
self.message += data
|
||||||
@ -93,22 +93,19 @@ class SimpleMessageChannel:
|
|||||||
self.consumed += 1
|
self.consumed += 1
|
||||||
|
|
||||||
if data[offset] < 128:
|
if data[offset] < 128:
|
||||||
self._next_state(data, offset + 1)
|
self._next_state()
|
||||||
|
offset += 1
|
||||||
return offset
|
return offset
|
||||||
|
|
||||||
offset += 1
|
self.factor *= 128
|
||||||
|
|
||||||
if self.consumed >= 8:
|
if self.consumed >= 8:
|
||||||
raise RuntimeError("Incoming varint is invalid")
|
raise RuntimeError("Incoming varint is invalid")
|
||||||
|
|
||||||
return len(data)
|
return len(data)
|
||||||
|
|
||||||
def _next_state(self, data: bytes, offset: int) -> bool:
|
def _next_state(self) -> None:
|
||||||
"""Calculate the next state.
|
"""Calculate the next state."""
|
||||||
|
|
||||||
:param data: the message data
|
|
||||||
:param offset: the bytes offset
|
|
||||||
"""
|
|
||||||
if self.state == 0:
|
if self.state == 0:
|
||||||
self.state = 1
|
self.state = 1
|
||||||
self.factor = 1
|
self.factor = 1
|
||||||
@ -126,9 +123,11 @@ class SimpleMessageChannel:
|
|||||||
self.varint = 0
|
self.varint = 0
|
||||||
if self.length < 0 or self.length > self.max_size:
|
if self.length < 0 or self.length > self.max_size:
|
||||||
raise RuntimeError("Incoming message too large")
|
raise RuntimeError("Incoming message too large")
|
||||||
else:
|
elif self.state == 2:
|
||||||
self.state = 0
|
self.state = 0
|
||||||
self.messages.append(
|
channel = self.header >> 4
|
||||||
(self.header >> 4, self.header & 0b1111, self.message)
|
type = self.header & 0b1111
|
||||||
)
|
self.messages.append((channel, type, self.message))
|
||||||
self.message = b""
|
self.message = b""
|
||||||
|
else:
|
||||||
|
raise RuntimeError(f"Unknown state {self.state}")
|
||||||
|
@ -7,7 +7,7 @@ def test_smc_recv(smc1, smc2):
|
|||||||
smc2.recv(payload)
|
smc2.recv(payload)
|
||||||
|
|
||||||
assert len(smc2.messages) == 1
|
assert len(smc2.messages) == 1
|
||||||
assert smc2.messages[0] == [(0, 1, b"foo")]
|
assert smc2.messages == [(0, 1, b"foo")]
|
||||||
|
|
||||||
|
|
||||||
def test_smc_recv_multiple(smc1, smc2):
|
def test_smc_recv_multiple(smc1, smc2):
|
||||||
@ -15,7 +15,7 @@ def test_smc_recv_multiple(smc1, smc2):
|
|||||||
smc2.recv(payload)
|
smc2.recv(payload)
|
||||||
|
|
||||||
assert len(smc2.messages) == 3
|
assert len(smc2.messages) == 3
|
||||||
assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]
|
assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]
|
||||||
|
|
||||||
|
|
||||||
def test_smc_recv_empty(smc1, smc2):
|
def test_smc_recv_empty(smc1, smc2):
|
||||||
@ -23,24 +23,24 @@ def test_smc_recv_empty(smc1, smc2):
|
|||||||
smc2.recv(payload)
|
smc2.recv(payload)
|
||||||
|
|
||||||
assert len(smc2.messages) == 1
|
assert len(smc2.messages) == 1
|
||||||
assert smc2.messages[0] == [(0, 1, b"")]
|
assert smc2.messages == [(0, 1, b"")]
|
||||||
|
|
||||||
|
|
||||||
def test_smc_recv_chunked(smc1, smc2):
|
def test_smc_recv_chunked(smc1, smc2):
|
||||||
payload = smc1.send(0, 1, b"foo")
|
payload = smc1.send(0, 1, b"foo")
|
||||||
|
|
||||||
for idx in range(0, len(payload)):
|
for idx in range(0, len(payload)):
|
||||||
smc2.recv(payload[idx, idx + 1])
|
smc2.recv(payload[idx : idx + 1])
|
||||||
|
|
||||||
assert len(smc2.messages) == 1
|
assert len(smc2.messages) == 1
|
||||||
assert smc2.messages[0] == [(0, 1, b"foo")]
|
assert smc2.messages == [(0, 1, b"foo")]
|
||||||
|
|
||||||
|
|
||||||
def test_smc_recv_chunked_multiple(smc1, smc2):
|
def test_smc_recv_chunked_multiple(smc1, smc2):
|
||||||
payload = smc1.send_batch([(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")])
|
payload = smc1.send_batch([(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")])
|
||||||
|
|
||||||
for idx in range(0, len(payload)):
|
for idx in range(0, len(payload)):
|
||||||
smc2.recv(payload[idx, idx + 1])
|
smc2.recv(payload[idx : idx + 1])
|
||||||
|
|
||||||
assert len(smc2.messages) == 3
|
assert len(smc2.messages) == 3
|
||||||
assert smc2.messages[0] == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]
|
assert smc2.messages == [(0, 1, b"foo"), (0, 1, b"bar"), (0, 1, b"baz")]
|
||||||
|
Loading…
Reference in New Issue
Block a user