For another project of mine, I decided that it would be handy to have a tree structure to dictate hierarchy between pieces of data, but I didn't want to unnecessarily keep that data alive if it would otherwise have expired. I searched the standard library and PyPI but didn't find any relevant packages, so I decided to create one, which I intend to release on the off chance that it's useful to anyone else.
The class WeakTreeNode (WTN) holds a weak reference to a piece of data, and tracks both the root and branches (parent and children, but keeping with the tree metaphor) to establish the tree. A WTN with a root of None is considered top-level, and would generally be the entry point for the tree. I've also included some helpful iteration methods for traversing the tree.
I'm pretty confident this is a solid design, and I've heavily referenced weak structures in the standard library like WeakValueDictionary, but I've only been working in python for ~1.5 years now in my free time after many years away from programming in general, so I'd appreciate some feedback regarding best practices and obvious pitfalls that I've missed.
For reference, I'm running Python 3.12.4, in case I'm using a feature that is deprecated in a later version, or a newer feature already handles something I've implemented by hand.
node.py:
from __future__ import annotations
from abc import ABC, abstractmethod
from collections import deque
from enum import auto, Enum
from typing import Generic, TypeVar, TYPE_CHECKING
from weakref import ref
if TYPE_CHECKING:
from collections.abc import Callable, Iterator
from typing import ClassVar
# Type of the data object that a WeakTreeNode weakly references.
T = TypeVar("T")
# Type of the items a TreeIterable produces for each visited node.
IterT = TypeVar("IterT")
class CleanupMode(Enum):
    """Strategies a node can apply to the tree once its data reference expires."""

    # Use the root node's strategy; a top-level node falls back to pruning.
    DEFAULT = 1
    # Detach the node from its root and drop all of its branches.
    PRUNE = 2
    # Detach the node and hand its branches over to the node's root.
    REPARENT = 3
    # Leave the tree untouched.
    NO_CLEANUP = 4
def _idle(node: WeakTreeNode[T]):
# Intentionally do nothing.
pass
def _prune(node: WeakTreeNode[T]):
if node.root:
node.root._branches.discard(node)
# This will allow the branch to unwind and be gc'd unless the user has another
# reference to any of the nodes somehwere.
node._branches.clear()
def _reparent(node: WeakTreeNode[T]):
if node.root:
node.root._branches.discard(node)
for subnode in node._branches.copy():
subnode.root = node.root
def _get_cleanup_method(
    node: WeakTreeNode[T],
    cleanup_method: CleanupMode,
) -> Callable[[WeakTreeNode], None]:
    """
    Resolve a cleanup mode to the function that implements it.

    Explicit modes map directly to their function. Any other mode (DEFAULT) is
    resolved by walking up the tree from ``node`` until a node with an explicit
    mode is found; if the top of the tree is reached first, pruning is used.

    :param node: The node the cleanup will be applied to.
    :param cleanup_method: The mode requested for ``node``.
    :return: A callable that takes the node and performs the cleanup.
    """
    explicit = {
        CleanupMode.PRUNE: _prune,
        CleanupMode.REPARENT: _reparent,
        CleanupMode.NO_CLEANUP: _idle,
    }
    current = node
    mode = cleanup_method
    # Walk iteratively so that a long chain of DEFAULT nodes cannot exhaust the
    # recursion limit.
    while mode not in explicit:
        parent = current.root
        if parent is None:
            # If we're top level and ask for default, default to pruning.
            return _prune
        # Otherwise, defer to the root's cleanup mode.
        current = parent
        mode = parent._cleanup_mode
    return explicit[mode]
class WeakTreeNode(Generic[T]):
    """
    An object that allows for the formation of data trees that don't form strong
    references to their data. WeakTreeNodes can be configured to control the result
    when the data in the reference dies.

    A node holds strong references to its branches and a weak reference to its
    root.
    """

    # These are here to allow use without needing to import the enum
    DEFAULT: ClassVar[CleanupMode] = CleanupMode.DEFAULT
    PRUNE: ClassVar[CleanupMode] = CleanupMode.PRUNE
    REPARENT: ClassVar[CleanupMode] = CleanupMode.REPARENT
    NO_CLEANUP: ClassVar[CleanupMode] = CleanupMode.NO_CLEANUP

    def __init__(
        self,
        data: T,
        root: WeakTreeNode | None = None,
        cleanup_mode: CleanupMode = DEFAULT,
        callback: Callable | None = None,
    ) -> None:
        """
        Create a new node for a weakly-referencing tree.

        :param data: The data to be stored by the new WeakTreeNode
        :param root: The previous node in the tree for the new node, defaults to None,
            which indicates a top-level node.
        :param cleanup_mode: An enum indicating how the tree should cleanup after
            itself when the data reference expires, defaults to DEFAULT, which calls
            upon the root node, or prune if the high-level root is also DEFAULT.
        :param callback: An optional additional callback function, called with the
            expired weak reference when the data reference expires. Defaults to None.
        :raises ValueError: If ``root`` would make the node its own ancestor.
        """

        def _remove(wr: ref, selfref=ref(self)):
            # Only a weak reference to the node is bound here, so the callback
            # stored on the data's weakref never keeps the node itself alive.
            owner = selfref()
            try:
                if callback is not None:
                    callback(wr)
            finally:
                # Tidy the tree even when the user's callback raised.
                if owner is not None:
                    _get_cleanup_method(owner, owner._cleanup_mode)(owner)

        self._data: ref[T] = ref(data, _remove)
        self._branches: set[WeakTreeNode[T]] = set()
        self._cleanup_mode: CleanupMode = cleanup_mode
        self._root: ref[WeakTreeNode[T]] | None = None
        # Link into the tree last, so the node is fully initialised by the time it
        # becomes reachable from its root.
        self.root = root

    @property
    def branches(self) -> set[WeakTreeNode[T]]:
        """
        A set of nodes that descend from the current node.
        """
        return self._branches

    @property
    def root(self) -> WeakTreeNode[T] | None:
        """
        A node that sits higher in the tree than the current node.
        If None, the current node is considered top-level.
        """
        if self._root is None:
            return None
        # May also be None if the root node itself has been collected.
        return self._root()

    @root.setter
    def root(self, node: WeakTreeNode | None):
        if node is not None:
            # Refuse links that would make this node its own ancestor; such a
            # cycle would make every traversal of the tree loop forever.
            ancestor: WeakTreeNode | None = node
            while ancestor is not None:
                if ancestor is self:
                    raise ValueError(
                        "a node cannot be rooted on itself or one of its descendants"
                    )
                ancestor = ancestor.root
        # Dereference the current root once; it is only weakly held.
        previous = self.root
        if previous is not None:
            previous._branches.discard(self)
        if node is None:
            self._root = None
        else:
            self._root = ref(node)
            node._branches.add(self)

    @property
    def data(self) -> T | None:
        """
        The value stored by the node, or None once that value has expired.
        """
        # Dereference our data so the real object can be used.
        return self._data()

    def add_branch(
        self,
        data: T,
        cleanup_mode: CleanupMode = DEFAULT,
        callback: Callable | None = None,
    ) -> WeakTreeNode[T]:
        """
        Creates a new node as a child of the current node, with a weak reference to the
        passed value.

        Returns the new instance, so this can be chained without intermediate variables.

        :param data: The data to be stored by the new WeakTreeNode
        :param cleanup_mode: An enum indicating how the tree should cleanup after
            itself when the data reference expires, defaults to DEFAULT, which calls
            upon the root node, or prune if the high-level root is also DEFAULT.
        :param callback: An optional additional callback function, called when the
            data reference expires. Defaults to None.
        :return: The newly created node.
        """
        return WeakTreeNode(data, self, cleanup_mode, callback)

    def breadth(self) -> Iterator[WeakTreeNode[T]]:
        """
        Provides a generator that performs a breadth-first traversal of the tree
        starting at the current node.

        :yield: The next node in the tree, breadth-first.
        """
        # The .breadth() isn't strictly needed, but could cause an issue if we decide
        # to change what the default iteration mode is.
        yield from NodeIterable(self).breadth()

    def depth(self) -> Iterator[WeakTreeNode[T]]:
        """
        Provides a generator that performs a depth-first traversal of the tree
        starting at the current node.

        :yield: The next node in the tree, depth-first.
        """
        yield from NodeIterable(self).depth()

    def to_root(self) -> Iterator[WeakTreeNode[T]]:
        """
        Provides a generator that traces the tree back to the furthest root,
        starting with the current node.

        :yield: The current node, then the root of each previously yielded node.
        """
        yield from NodeIterable(self).to_root()

    def nodes(self) -> NodeIterable:
        """
        Returns an iterable that allows iteration over the nodes of the tree, starting
        from the calling node.
        """
        return NodeIterable(self)

    def values(self) -> ValueIterable[T]:
        """
        Returns an iterable that allows iteration over the values of the tree, starting
        from the calling node.
        """
        return ValueIterable[T](self)

    def items(self) -> ItemsIterable[T]:
        """
        Returns an iterable that allows iteration over both the nodes and values of the
        tree, starting from the calling node.
        """
        return ItemsIterable[T](self)

    def __iter__(self) -> Iterator[WeakTreeNode[T]]:
        """
        Default iteration method, in this case, breadth-first.

        :yield: The next node in the tree, breadth-first.
        """
        yield from self.breadth()

    def __repr__(self) -> str:
        return f"WeakTreeNode({self.data}, {self.root})"
class TreeIterable(ABC, Generic[IterT]):
    """
    Abstract base for objects that walk a tree of WeakTreeNodes.

    Subclasses choose what is produced for each visited node by implementing
    ``_get_iter_output``.
    """

    def __init__(self, starting_node: WeakTreeNode[T]) -> None:
        # The node every traversal of this iterable begins at.
        self._root_node = starting_node

    @abstractmethod
    def _get_iter_output(self, node: WeakTreeNode) -> IterT:
        """Turn a visited node into the item handed to the caller."""

    def breadth(self) -> Iterator[IterT]:
        """
        Walk the tree breadth-first, beginning at the iterable's starting node.

        Order is not guaranteed.
        """
        pending: deque[WeakTreeNode] = deque()
        pending.append(self._root_node)
        while pending:
            current = pending.popleft()
            yield self._get_iter_output(current)
            # Queue the children behind everything already waiting.
            pending.extend(current.branches)

    def depth(self) -> Iterator[IterT]:
        """
        Walk the tree depth-first, beginning at the iterable's starting node.

        Order is not guaranteed.
        """
        frontier: list[WeakTreeNode] = [self._root_node]
        while frontier:
            current = frontier.pop()
            yield self._get_iter_output(current)
            # Children go on top of the stack, so they are visited next.
            frontier.extend(current.branches)

    def to_root(self) -> Iterator[IterT]:
        """
        Walk from the starting node up through each successive root until a
        top-level node has been produced.
        """
        current: WeakTreeNode | None = self._root_node
        while True:
            if not current:
                return
            yield self._get_iter_output(current)
            current = current.root

    def __iter__(self) -> Iterator[IterT]:
        """
        Default iteration for the iterable, which is breadth-first.
        """
        return self.breadth()
class NodeIterable(TreeIterable[WeakTreeNode]):
    """
    TreeIterable whose traversals yield the visited nodes themselves.
    """

    def _get_iter_output(self, node: WeakTreeNode[T]) -> WeakTreeNode[T]:
        # The node is already the desired output; pass it through unchanged.
        visited = node
        return visited
class ValueIterable(TreeIterable[T | None]):
    """
    TreeIterable whose traversals yield the value stored by each visited node.
    """

    def _get_iter_output(self, node: WeakTreeNode[T]) -> T | None:
        # ``data`` dereferences the node's weak reference.
        value = node.data
        return value
class ItemsIterable(TreeIterable[tuple[WeakTreeNode[T], T | None]]):
    """
    TreeIterable whose traversals yield ``(node, value)`` pairs, pairing each
    visited node with the value it stores.
    """

    def _get_iter_output(self, node: WeakTreeNode) -> tuple[WeakTreeNode, T | None]:
        value = node.data
        return (node, value)
I also have two files for unit tests. Apologies for the weird imports on those, I keep my unit tests in a separate folder from src, and this causes import issues that I solve with said weird imports.
test_iterables.py:
from __future__ import annotations
from collections import deque
import pathlib
import sys
import unittest
sys.path.append(str(pathlib.Path.cwd()))
from src.weaktree.node import ( # noqa: E402
WeakTreeNode,
NodeIterable,
ValueIterable,
ItemsIterable,
)
class TestObject:
    """Simple payload object used as node data in the tests."""

    def __init__(self, data):
        self.data = data

    def __repr__(self) -> str:
        label = str(self.data)
        return f"TestObject({label})"
# Strong references to every payload, so no node's data expires during the tests.
test_data: dict[str, TestObject] = {"root": TestObject("Root")}
test_data.update((str(n), TestObject(str(n))) for n in range(1, 10))

# Shared tree, built level by level:
#
#   root
#   |-- 1
#   |   |-- 4
#   |   |   `-- 8
#   |   |       `-- 9
#   |   `-- 5
#   |-- 2
#   |   `-- 6
#   `-- 3
#       `-- 7
root = WeakTreeNode(test_data["root"])

branch1 = root.add_branch(test_data["1"])
branch2 = root.add_branch(test_data["2"])
branch3 = root.add_branch(test_data["3"])

branch4 = branch1.add_branch(test_data["4"])
branch5 = branch1.add_branch(test_data["5"])
branch6 = branch2.add_branch(test_data["6"])
branch7 = branch3.add_branch(test_data["7"])

branch8 = branch4.add_branch(test_data["8"])

branch9 = branch8.add_branch(test_data["9"])
class Test_NodeIterable(unittest.TestCase):
    """Checks the traversal order of NodeIterable over the module-level tree."""

    def setUp(self) -> None:
        self.iterable = NodeIterable(root)

    def test_breadth_iterator(self):
        queue = deque()
        visited = 0
        for node in self.iterable.breadth():
            visited += 1
            # A node is either top level, or its root was yielded earlier and is
            # still waiting in the queue.
            no_root = node.root is None
            queued_root = node.root in queue
            self.assertTrue(no_root or queued_root)
            self.assertIsInstance(node, WeakTreeNode)

            if queued_root:
                # Discard roots we have moved past. If nodes arrive out of order,
                # a later node's root will no longer be queued and the assertion
                # above fails.
                while queue[0] is not node.root:
                    queue.popleft()
            queue.append(node)
        # Guard against passing vacuously when nothing (or too little) is yielded.
        self.assertEqual(visited, len(test_data))

    def test_depth_iterator(self):
        stack = []
        visited = 0
        for node in self.iterable.depth():
            visited += 1
            # Determine if our node is top level, or a descendant of one in the stack
            no_root = node.root is None
            stack_root = node.root in stack
            self.assertTrue(no_root or stack_root)
            self.assertIsInstance(node, WeakTreeNode)

            if stack_root:
                # For a descendant, remove the chain until we get to the node's
                # parent. If we end up out of order, this is what will cause our
                # test to fail
                while stack[-1] is not node.root:
                    stack.pop()
            stack.append(node)
        # Guard against passing vacuously when nothing (or too little) is yielded.
        self.assertEqual(visited, len(test_data))

    def test_to_root(self) -> None:
        iterable = NodeIterable(branch9)
        path = list(iterable.to_root())

        # branch9 -> branch8 -> branch4 -> branch1 -> root
        self.assertEqual(len(path), 5)
        self.assertIs(path[0], branch9)
        self.assertIs(path[-1], root)

        for node in path:
            self.assertIsInstance(node, WeakTreeNode)
        # Every step must move from a node to that node's root.
        for previous_node, node in zip(path, path[1:]):
            self.assertIs(node, previous_node.root)
class Test_ValueIterable(unittest.TestCase):
    """Checks that ValueIterable yields the data objects stored in the tree."""

    def setUp(self) -> None:
        self.iterable = ValueIterable[TestObject](root)

    def test_breadth_iterator(self):
        # We already have a test for iterator order in Test_NodeIterable, so we're
        # just checking instances and completeness here.
        values = list(self.iterable.breadth())
        for value in values:
            self.assertIsInstance(value, TestObject)
        # Every stored object must be produced exactly once; this also stops the
        # test from passing vacuously on an empty iteration.
        self.assertCountEqual(values, test_data.values())

    def test_depth_iterator(self):
        values = list(self.iterable.depth())
        for value in values:
            self.assertIsInstance(value, TestObject)
        self.assertCountEqual(values, test_data.values())

    def test_to_root(self) -> None:
        iterable = ValueIterable[TestObject](branch9)
        values = list(iterable.to_root())
        for value in values:
            self.assertIsInstance(value, TestObject)
        # The path from branch9 up to the top of the tree is fully determined.
        expected = [test_data[key] for key in ("9", "8", "4", "1", "root")]
        self.assertEqual(values, expected)
class Test_ItemIterable(unittest.TestCase):
    """Checks that ItemsIterable yields matching (node, value) pairs."""

    def setUp(self) -> None:
        self.iterable = ItemsIterable[TestObject](root)

    def _check_pairs(self, pairs, expected_count: int) -> None:
        """Assert that ``pairs`` holds ``expected_count`` consistent node/value pairs."""
        # Unpacking will fail if an item isn't a 2-tuple, so we can go directly to
        # instance checking on the resultant values.
        for node, value in pairs:
            self.assertIsInstance(node, WeakTreeNode)
            self.assertIsInstance(value, TestObject)
            # The value must be the one stored by the node it is paired with.
            self.assertIs(value, node.data)
        # Stops the test from passing vacuously on an empty iteration.
        self.assertEqual(len(pairs), expected_count)

    def test_breadth_iterator(self):
        # We already have a test for iterator order in Test_NodeIterable, so we're
        # just checking instances and completeness here.
        self._check_pairs(list(self.iterable.breadth()), len(test_data))

    def test_depth_iterator(self):
        self._check_pairs(list(self.iterable.depth()), len(test_data))

    def test_to_root(self) -> None:
        iterable = ItemsIterable[TestObject](branch9)
        # branch9 -> branch8 -> branch4 -> branch1 -> root
        self._check_pairs(list(iterable.to_root()), 5)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
test_node.py:
from __future__ import annotations
import pathlib
import sys
import unittest
from weakref import ref
sys.path.append(str(pathlib.Path.cwd()))
from src.weaktree.node import ( # noqa: E402
WeakTreeNode,
ItemsIterable,
NodeIterable,
ValueIterable,
)
class TestObject:
    """Simple payload object used as node data in the tests."""

    def __init__(self, data):
        self.data = data

    def __repr__(self) -> str:
        label = str(self.data)
        return f"TestObject({label})"
class TestNode(unittest.TestCase):
    """Exercises WeakTreeNode's accessors, callback, and each cleanup mode."""

    def setUp(self) -> None:
        test_data: dict[str, TestObject] = {
            "root": TestObject("Root"),
            "1": TestObject("1"),
            "2": TestObject("2"),
            "3": TestObject("3"),
            "4": TestObject("4"),
            "5": TestObject("5"),
            "6": TestObject("6"),
            "7": TestObject("7"),
            "8": TestObject("8"),
            "9": TestObject("9"),
        }
        # Keep the payloads alive for the duration of each test.
        self.test_data = test_data

        self.root = WeakTreeNode(test_data["root"])

        branch1 = self.root.add_branch(test_data["1"])

        branch4 = branch1.add_branch(test_data["4"])
        branch8 = branch4.add_branch(test_data["8"])
        branch8.add_branch(test_data["9"])

        branch1.add_branch(test_data["5"])

        branch2 = self.root.add_branch(test_data["2"])
        branch2.add_branch(test_data["6"])

        branch3 = self.root.add_branch(test_data["3"])
        branch3.add_branch(test_data["7"])

        # Note: These excess vars all descope after setup, so we don't have to worry
        # about excess references.
        # We could also do this by chaining add_branch, but that actually becomes
        # _less_ readable.

    @staticmethod
    def _collect() -> None:
        """Run a garbage collection pass so expired data is finalised promptly."""
        # Imported locally so the helper is self-contained.
        import gc

        # Without this the tests only pass on interpreters that free objects as
        # soon as their last reference disappears.
        gc.collect()

    def _build_ephemeral_chain(self, cleanup_mode=WeakTreeNode.DEFAULT):
        """
        Attach a chain e1 -> e2 -> e3 of short-lived data beneath ``self.root``.

        :param cleanup_mode: The cleanup mode given to the e1 node.
        :return: A tuple of the data dict, the e1 node, a weak reference to the e2
            node and a weak reference to the e3 node.
        """
        ephemeral_data = {
            "1": TestObject("E1"),
            "2": TestObject("E2"),
            "3": TestObject("E3"),
        }
        branch_e1 = WeakTreeNode(
            ephemeral_data["1"], self.root, cleanup_mode=cleanup_mode
        )
        branch_e2 = branch_e1.add_branch(ephemeral_data["2"])
        branch_e3_wr = ref(branch_e2.add_branch(ephemeral_data["3"]))
        branch_e2_wr = ref(branch_e2)
        # Only weak references to e2 and e3 are returned, so the tree itself holds
        # the only strong references to those nodes once this method returns.
        return ephemeral_data, branch_e1, branch_e2_wr, branch_e3_wr

    def test_nodes(self):
        self.assertIsInstance(self.root.nodes(), NodeIterable)

    def test_values(self):
        self.assertIsInstance(self.root.values(), ValueIterable)

    def test_items(self):
        self.assertIsInstance(self.root.items(), ItemsIterable)

    def test_callback(self):
        # Add a custom callback to a node that will get cleaned up. The calls are
        # recorded in a local list rather than in module-level global state.
        calls = []

        ephemeral_data = TestObject("NotLongForThisWorld")

        WeakTreeNode(ephemeral_data, self.root, callback=calls.append)

        del ephemeral_data
        self._collect()

        # The callback runs exactly once and receives the expired weak reference.
        self.assertEqual(len(calls), 1)
        self.assertIsInstance(calls[0], ref)

    def test_prune(self):
        ephemeral_data, branch_e1, branch_e2_wr, branch_e3_wr = (
            self._build_ephemeral_chain()
        )

        # Ensure our nodes exist
        self.assertIsNotNone(branch_e2_wr())
        self.assertIsNotNone(branch_e3_wr())

        del ephemeral_data["1"]
        self._collect()

        # This should cause branch_e1 to dissolve, and unwind branch_e2
        self.assertNotIn(branch_e1, self.root.branches)
        self.assertIsNone(branch_e2_wr())
        self.assertIsNone(branch_e3_wr())

    def test_reparent(self):
        ephemeral_data, branch_e1, branch_e2_wr, branch_e3_wr = (
            self._build_ephemeral_chain(WeakTreeNode.REPARENT)
        )

        # Ensure our nodes exist
        self.assertIsNotNone(branch_e2_wr())
        self.assertIsNotNone(branch_e3_wr())

        del ephemeral_data["1"]
        self._collect()

        # We want our nodes to still exist
        branch_e2 = branch_e2_wr()
        self.assertIsNotNone(branch_e2)
        self.assertIsNotNone(branch_e3_wr())

        assert branch_e2  # for the static type checker

        # e2 should now be a child of self.root, and e1 should be gone from it
        self.assertIs(branch_e2.root, self.root)
        self.assertNotIn(branch_e1, self.root.branches)

    def test_no_cleanup(self):
        ephemeral_data, branch_e1, branch_e2_wr, branch_e3_wr = (
            self._build_ephemeral_chain(WeakTreeNode.NO_CLEANUP)
        )

        # Ensure our nodes exist
        self.assertIsNotNone(branch_e2_wr())
        self.assertIsNotNone(branch_e3_wr())

        del ephemeral_data["1"]
        self._collect()

        # We want our nodes to still exist
        branch_e2 = branch_e2_wr()
        self.assertIsNotNone(branch_e2)
        self.assertIsNotNone(branch_e3_wr())

        assert branch_e2  # for the static type checker

        # e1 should still exist and still be the parent of e2
        self.assertIs(branch_e2.root, branch_e1)

        # e1 should be empty, or rather the weakref should return None
        self.assertIsNone(branch_e1.data)
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()