diff --git a/merkle_tree_stream/__init__.py b/merkle_tree_stream/__init__.py index 00aa539..57b5331 100644 --- a/merkle_tree_stream/__init__.py +++ b/merkle_tree_stream/__init__.py @@ -1,7 +1,7 @@ """merkle-tree-stream module.""" -from merkle_tree_stream.generator import ( # noqa - MerkleTreeGenerator, +from merkle_tree_stream.generate import ( # noqa + MerkleTreeIterator, MerkleTreeNode, ) diff --git a/merkle_tree_stream/generator.py b/merkle_tree_stream/generate.py similarity index 59% rename from merkle_tree_stream/generator.py rename to merkle_tree_stream/generate.py index 20035e2..8a489cf 100644 --- a/merkle_tree_stream/generator.py +++ b/merkle_tree_stream/generate.py @@ -1,13 +1,13 @@ """The merkle tree stream generator.""" -from typing import Any, Callable, List, Optional +from typing import Any, Callable, Iterator, List, Optional import attr from flat_tree import FlatTreeAccessor Hash = str -__all__ = ['MerkleTreeGenerator', 'MerkleTreeNode'] +__all__ = ['MerkleTreeIterator', 'MerkleTreeNode'] flat_tree = FlatTreeAccessor() @@ -26,7 +26,7 @@ class MerkleTreeNode: index: int parent: int size: int - data: Optional[bytes] + data: bytes hash: Optional[str] = None def __attrs_post_init__(self) -> Any: @@ -35,29 +35,62 @@ class MerkleTreeNode: @attr.s(auto_attribs=True) -class MerkleTreeGenerator: - """A stream that generates a merkle tree based on the incoming data. +class MerkleTreeIterator: + """A merkle tree iterator based on incoming data. :param leaf: The leaf hash generation function :param parent: The parent hash generation function :param roots: The tree roots """ - leaf: Callable[[MerkleTreeNode, List[MerkleTreeNode]], Hash] - parent: Callable[[MerkleTreeNode, List[MerkleTreeNode]], Hash] + leaf: Callable[[MerkleTreeNode], Hash] + parent: Callable[[MerkleTreeNode, MerkleTreeNode], Hash] roots: List[MerkleTreeNode] = attr.Factory(list) - def next( - self, data: bytes, nodes: Optional[List[MerkleTreeNode]] = None - ) -> List[MerkleTreeNode]: - """Further generate the tree based on the incoming data. + _position: int = 0 + _nodes: List[MerkleTreeNode] = attr.Factory(list) - :param data: Incoming data - :param nodes: Pre-existing nodes + def __attrs_post_init__(self) -> Any: + """Initialise parent and block defaults.""" + try: + index = self.roots[len(self.roots) - 1].index + except IndexError: + index = 0 + + right_span = flat_tree.right_span(index) + self.blocks = (1 + (right_span / 2)) if self.roots else 0 + + for root in self.roots: + if not root.parent: + root.parent = flat_tree.parent(root.index) + + def __iter__(self) -> Iterator: + """The iterator initialisation.""" + return self + + def __next__(self) -> MerkleTreeNode: + """The following node.""" + try: + node = self._nodes[self._position] + except IndexError: + raise StopIteration + + self._position += 1 + + return node + + def __len__(self) -> int: + """The number of nodes stored in the tree.""" + return len(self._nodes) + + def write(self, data: bytes): + """Write a new node to the tree. + + :param data: The new tree data """ - nodes = nodes or [] + index = 2 * self.blocks - index = 2 * (self.blocks + 1) + self.blocks += 1 leaf_node = MerkleTreeNode( index=index, @@ -66,11 +99,10 @@ class MerkleTreeGenerator: data=data, size=len(data), ) - - leaf_node.hash = self.leaf(leaf_node, self.roots) + leaf_node.hash = self.leaf(leaf_node) self.roots.append(leaf_node) - nodes.append(leaf_node) + self._nodes.append(leaf_node) while len(self.roots) > 1: left = self.roots[len(self.roots) - 2] @@ -84,23 +116,11 @@ class MerkleTreeGenerator: new_node = MerkleTreeNode( index=left.parent, parent=flat_tree.parent(left.parent), - hash=self.parent(left, [right]), + hash=self.parent(left, right), size=left.size + right.size, - data=None, + data=b'', ) self.roots[len(self.roots) - 1] = new_node - nodes.append(new_node) - - return nodes - - def __attrs_post_init__(self) -> Any: - """Initialise parent and block defaults.""" - index = self.roots[len(self.roots) - 1].index - right_span = flat_tree.right_span(index) - self.blocks = (1 + (right_span / 2)) if self.roots else 0 - - for root in self.roots: - if not root.parent: - root.parent = flat_tree.parent(root.index) + self._nodes.append(new_node) diff --git a/test/conftest.py b/test/conftest.py index 8c40afc..f0fad3a 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -5,18 +5,18 @@ import pytest @pytest.fixture def leaf(): - def _leaf(node): - return hashlib.sha256(leaf.data).hexdigest() + def _leaf(node, roots=None): + return hashlib.sha256(node.data).hexdigest() return _leaf @pytest.fixture def parent(): - def _parent(left, right): + def _parent(first, second): sha256 = hashlib.sha256() - sha256.update(left) - sha256.update(right) + sha256.update(first.data) + sha256.update(second.data) return sha256.hexdigest() return _parent diff --git a/test/test_generate.py b/test/test_generate.py new file mode 100644 index 0000000..e916bae --- /dev/null +++ b/test/test_generate.py @@ -0,0 +1,64 @@ +"""Merkle tree generation test module.""" + +import hashlib + +import pytest + +from merkle_tree_stream import MerkleTreeIterator, MerkleTreeNode + + +def test_hashes(leaf, parent): + merkle_iter = MerkleTreeIterator(leaf=leaf, parent=parent) + + merkle_iter.write(b'a') + merkle_iter.write(b'b') + + expected_count = 2 + 1 # nodes plus parent + assert len(merkle_iter) == expected_count + + assert next(merkle_iter) == MerkleTreeNode( + index=0, + parent=1, + hash=hashlib.sha256(b'a').hexdigest(), + size=1, + data=b'a', + ) + + assert next(merkle_iter) == MerkleTreeNode( + index=2, + parent=1, + hash=hashlib.sha256(b'b').hexdigest(), + size=1, + data=b'b', + ) + + hashed = hashlib.sha256(b'a') + hashed.update(b'b') + + assert next(merkle_iter) == MerkleTreeNode( + index=1, parent=3, hash=hashed.hexdigest(), size=2, data=b'' + ) + + with pytest.raises(StopIteration): + next(merkle_iter) + + +def test_single_root(leaf, parent): + merkle_iter = MerkleTreeIterator(leaf=leaf, parent=parent) + + merkle_iter.write(b'a') + merkle_iter.write(b'b') + merkle_iter.write(b'c') + merkle_iter.write(b'd') + + assert len(merkle_iter.roots) == 1 + + +def multiple_roots(leaf, parent): + merkle_iter = MerkleTreeIterator(leaf=leaf, parent=parent) + + merkle_iter.write(b'a') + merkle_iter.write(b'b') + merkle_iter.write(b'c') + + assert len(merkle_iter.roots) > 1 diff --git a/test/test_generator.py b/test/test_generator.py deleted file mode 100644 index 7ea3a91..0000000 --- a/test/test_generator.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Generator test module.""" - -import hashlib - -from merkle_tree_stream import MerkleTreeGenerator, MerkleTreeNode - - -def test_hashes(leaf, parent): - stream = MerkleTreeGenerator(leaf=leaf, parent=parent) - - stream.next(b'a') - - first_node = ( - MerkleTreeNode( - index=0, - parent=1, - hash=hashlib.sha256(b'a').hexdigest(), - size=1, - data=b'a', - ), - ) - - stream.next(b'b') - - second_node = ( - MerkleTreeNode( - index=2, - parent=1, - hash=hashlib.sha256(b'b').hexdigest(), - size=1, - data=b'a', - ), - ) - - stream.next(b'c') - - third = hashlib.sha256(b'a') - third.update(b'b') - third_hash = third.hexdigest() - - third_node = ( - MerkleTreeNode(index=1, parent=3, hash=third_hash, size=2, data=b'a'), - ) - - assert stream.nodes == [first_node, second_node, third_node] - - -def test_single_root(leaf, parent): - stream = MerkleTreeGenerator(leaf=leaf, parent=parent) - - stream.next('a') - stream.next('b') - stream.next('c') - stream.next('d') - - assert stream.roots.length == 1 - - -def multiple_roots(leaf, parent): - stream = MerkleTreeGenerator(leaf=leaf, parent=parent) - - stream.next('a') - stream.next('b') - stream.next('c') - - assert stream.roots.length > 1