Tree Counting, A Use Case For Disjoint Sets
In an interview I was asked to implement an algorithm that counts the number of trees in a forest. My implementation was far from perfect, so I later tried to improve the algorithm as much as I could. The following is what I have for the moment, but before I begin, let's define some terms.
A Forest is a graph for which we have the number of nodes and a list of edges.
A Tree is a component of the forest graph where every node of the tree is connected to any other via only a single path.
So if a component of the graph has a loop in it, it cannot be counted as a tree. As an example, consider the graph expressed with these edges: 1 -> 2 / 1 -> 3 / 4 -> 5 / 4 -> 6 / 5 -> 6 and a node count of 7.
With this definition the component in yellow is a tree but not the one in red. Note that the green component made of a single node with no edge is considered a tree too.
The Naive Algorithm
Counting the number of trees should not be that hard. After all, we can always start with the first edge, add up all connected edges until we have the full component, and then move on to the next tree. To check whether the component has any loops, we have to verify, while adding edges incrementally, that no edge has both of its ends already inside the visited nodes of the current component. While this algorithm works fine for small graphs, we run into performance problems very quickly for bigger ones.
The problem with this algorithm is that we have to keep a list of trees as we process edges, and we have to check every single tree we have already found against every single edge. This increases the memory usage of the algorithm dramatically and makes it very slow. To give you an indication: for a test set of 6 million edges, the most optimized version of this algorithm I could come up with during the interview (partitioning the graph, hash lookups, ordering edges so that we never need to look behind, etc.) would crash with an out-of-memory error after around ten seconds, having consumed about 1.5 GB of memory.
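To make the naive approach concrete, here is a rough sketch of it in Python. The representation (a list of components, each a node set plus a cycle flag) and all the names are mine, not the interview code:

```python
def count_trees_naive(node_count, edges):
    """Naive tree counting: grow components edge by edge.

    components holds [node_set, has_cycle] pairs; an edge whose
    endpoints are both already inside one component closes a loop.
    """
    components = []
    for a, b in edges:
        # scan every component for either endpoint -- this is the
        # expensive lookup the article complains about
        hits = [c for c in components if a in c[0] or b in c[0]]
        if not hits:
            components.append([{a, b}, False])
        elif len(hits) == 1:
            comp = hits[0]
            if a in comp[0] and b in comp[0]:
                comp[1] = True          # both ends already inside: loop
            comp[0].update((a, b))
        else:                           # edge bridges two components
            merged = [hits[0][0] | hits[1][0], hits[0][1] or hits[1][1]]
            components.remove(hits[0])
            components.remove(hits[1])
            components.append(merged)
    trees = sum(1 for c in components if not c[1])
    isolated = node_count - sum(len(c[0]) for c in components)
    return trees + isolated             # isolated nodes are single-node trees
```

On the example graph above, `count_trees_naive(7, [(1, 2), (1, 3), (4, 5), (4, 6), (5, 6)])` returns 2: the 1-2-3 component and the isolated seventh node, while the looped 4-5-6 component is excluded.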
It turns out there is a great data structure dedicated to just that: partitioning sets so that each element belongs to exactly one partition. Here is the Wikipedia definition.
a disjoint-set data structure, also called a union–find data structure or merge–find set, is a data structure that keeps track of a set of elements partitioned into a number of disjoint (nonoverlapping) subsets.
The main idea is to create a data structure that is efficient for two operations: merging partitions, and finding which partition a certain element belongs to. In our case we want to merge two trees when we find an edge that connects one to the other. We also need to find which tree or trees a certain edge belongs to by checking each of its nodes.
This data structure represents the set as a graph. Let’s say at the beginning we have 5 elements in our set. We shall have 5 partitions since at the beginning each element is in its own partition.
Each partition now has exactly one element, and we call that element the root of the partition; it is what we use to identify the partition. These elements are roots since they don't point to any other element. Now let's say I want to merge 1 and 2. All I have to do is point 2's root to 1's root (for efficiency we try to point the smaller partition to the bigger one, but here they are both of size 1).
Now we have only 4 partitions, and the cost of the merge was a single change: making 2 point to 1. Note that 1 is considered the root and therefore the identity of the partition. We can merge 3 and 4 too and end up with this:
Now if we want to know which partition the element 4 belongs to, we have to travel all the way up to the root, which in this case is 1 (so the element belongs to partition 1). 5 is its own root, so it belongs to partition 5, and we have two distinct partitions in this graph. This implementation is called fast union, since the union is O(1) in complexity while find depends on the depth of the tree (still very fast). There is a variation of this data structure that makes find faster at the cost of a slower merge. In this alternative implementation, called fast find, each element in a partition points directly to the root, so the find operation is O(1). When merging two partitions, however, we have to update every single element of one partition to point to the new root, so the merge is slower.
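The fast-union variant can be sketched with a plain dictionary of parent pointers. The starting state below is my reading of the figures above (2 under 1, 4 under 3, and 3 under 1, which is what makes 4's root come out as 1 in the find example):

```python
# Parent pointers after the merges described above; roots point to themselves.
parent = {1: 1, 2: 1, 3: 1, 4: 3, 5: 5}

def find(x):
    # fast-union find: walk the parent chain up to the root
    while parent[x] != x:
        x = parent[x]
    return x

def union(a, b):
    # fast-union merge: a single pointer change once the roots are known
    parent[find(b)] = find(a)
```

Here `find(4)` walks 4 -> 3 -> 1 and returns 1, while `find(5)` returns 5 immediately; a subsequent `union(1, 5)` would collapse everything into one partition rooted at 1.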
If you are interested in more details about the design of this data structure, I recommend the great course by Jonathan Shewchuk from Berkeley, which you can find on YouTube. I am going to show you an ingenious array-based implementation of this data structure and then show you how we can use it to implement a solver for our forest tree counting problem.
The Array-Based Implementation
In the array-based implementation the partitions are stored in an array: to represent n elements we need an array of length n. All array elements are initially set to -1 to reflect the fact that they are all roots. For each root element we can choose to store the number of elements inside the partition in the root element itself; in that case we make the count negative to distinguish it from a link to another element. In other words, if the array element at index x contains -1, it is a root and the only element inside its partition; if it contains -10, it is the root of a partition that has 9 other elements besides the root; and if it contains 42, it points to index 42 as its parent.
This is the initial array, and here is what it will look like if we merge partitions 0 and 4:
The element 0 now contains -2 since the size of its partition has increased by one, and the element 4 contains the value 0, which points to index 0 as its parent. Now if we merge elements 1 and 4, index 0 becomes -3 since the size of the partition has increased again. What happens to element 1 depends on whether the implementation is fast-find (in which case we put 0 into element 1, so it points directly to the partition's root) or fast-merge (in which case we put 4 inside element 1).
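Putting the walkthrough above into code, a fast-merge (union-by-size) version of the array representation might look like the following sketch in Python. The function names are mine:

```python
def make_sets(n):
    # every element starts as the root of its own size-1 partition
    return [-1] * n

def find(parts, x):
    # non-negative entries are parent links; a negative entry marks a
    # root and stores the partition size as a negative number
    while parts[x] >= 0:
        x = parts[x]
    return x

def union(parts, a, b):
    ra, rb = find(parts, a), find(parts, b)
    if ra == rb:
        return                  # already in the same partition
    # point the smaller partition at the bigger one; sizes are negative,
    # so the bigger root holds the more negative value
    if parts[ra] > parts[rb]:
        ra, rb = rb, ra
    parts[ra] += parts[rb]      # the surviving root absorbs the size
    parts[rb] = ra              # the old root becomes a child

parts = make_sets(5)            # [-1, -1, -1, -1, -1]
union(parts, 0, 4)              # parts is now [-2, -1, -1, -1, 0]
union(parts, 1, 4)              # parts is now [-3, 0, -1, -1, 0]
```

Note that in this particular sequence element 1 ends up pointing directly at the root 0 anyway, so here the fast-find and fast-merge outcomes happen to coincide; they diverge once deeper chains appear.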