"""A merkle tree generator."""

from typing import Any, Callable, Iterator, List

import attr
from flat_tree import FlatTreeAccessor

from merkle_tree_stream.node import MerkleTreeNode

# Data stored on parent (non-leaf) nodes, which carry no payload of their own.
EMPTY_DATA = b""
# Hash placeholder given to a new leaf until ``leaf`` computes its real hash.
EMPTY_HASH = None

__all__ = ["MerkleTreeGenerator"]

flat_tree = FlatTreeAccessor()


@attr.s(auto_attribs=True)
class MerkleTreeGenerator:
    """A merkle tree generator.

    Every node created by :meth:`write` (leaves and merged parents) is
    buffered internally and handed out by iterating over the generator.
    Iteration is cursor based: nodes already yielded are not yielded again,
    and nodes written after the iterator is exhausted are yielded by the next
    iteration.

    :param leaf: The leaf hash generation function
    :param parent: The parent hash generation function
    :param roots: The tree roots; this list is updated in place as data is
                  written
    """

    leaf: Callable[[MerkleTreeNode], bytes]
    parent: Callable[[MerkleTreeNode, MerkleTreeNode], bytes]
    roots: List[MerkleTreeNode] = attr.Factory(list)

    # Read cursor into ``_nodes``; advanced by ``__next__``.
    _position: int = 0
    # Every node produced by ``write``, in the order it was produced.
    _nodes: List[MerkleTreeNode] = attr.Factory(list)

    def __attrs_post_init__(self) -> None:
        """Initialise parent and block defaults.

        ``blocks`` counts the leaves written so far. Leaves are placed at even
        flat-tree indices (``2 * blocks``, see :meth:`write`). So when roots
        are supplied, the count is derived from the right span of the last
        root. This assumes ``flat_tree.right_span`` returns the index of the
        right-most leaf under that root.
        """
        if self.roots:
            right_span = flat_tree.right_span(self.roots[-1].index)
            # Floor division keeps ``blocks`` an int. True division would
            # make it a float, and every later leaf index (``2 * blocks``)
            # would then be a float as well.
            self.blocks = 1 + right_span // 2
        else:
            self.blocks = 0

        # Fill in any missing parent index on the supplied roots.
        for root in self.roots:
            if not root.parent:
                root.parent = flat_tree.parent(root.index)

    def __iter__(self) -> Iterator:
        """Return the generator itself as its own iterator."""
        return self

    def __next__(self) -> MerkleTreeNode:
        """Return the next buffered node that has not been yielded yet.

        :raises StopIteration: when every buffered node has been yielded
        """
        try:
            node = self._nodes[self._position]
        except IndexError:
            raise StopIteration from None
        self._position += 1
        return node

    def __len__(self) -> int:
        """The number of nodes stored in the tree."""
        return len(self._nodes)

    # TODO(decentral1se): we need to take a pass at async capability. Please
    # see https://datprotocol.github.io/book/ch02-02-merkle-tree-stream.html#async
    def write(self, data: bytes) -> None:
        """Write a new node to the tree and compute the new hashes.

        A leaf node is created for ``data`` and appended to the roots. Then,
        while the two right-most roots share a parent, they are merged into a
        new parent node that replaces them. Every node created is appended to
        the buffer consumed by iteration.

        :param data: The new tree data
        """
        index = 2 * self.blocks
        self.blocks += 1

        leaf_node = MerkleTreeNode(
            index=index,
            parent=flat_tree.parent(index),
            hash=EMPTY_HASH,
            data=data,
            size=len(data),
        )
        leaf_node.hash = self.leaf(leaf_node)

        self.roots.append(leaf_node)
        self._nodes.append(leaf_node)

        # Merge sibling roots bottom-up until the two right-most roots no
        # longer share a parent.
        while len(self.roots) > 1:
            left, right = self.roots[-2], self.roots[-1]
            if left.parent != right.parent:
                break

            self.roots.pop()
            new_node = MerkleTreeNode(
                index=left.parent,
                parent=flat_tree.parent(left.parent),
                hash=self.parent(left, right),
                size=left.size + right.size,
                data=EMPTY_DATA,
            )
            # ``left`` is now the last root; replace it with the merged parent.
            self.roots[-1] = new_node
            self._nodes.append(new_node)