Teaching Kids Programming – Tree Detection Algorithm via Union Find + Disjoint Set (Determine a Binary Tree)


Teaching Kids Programming: Videos on Data Structures and Algorithms

You are given two lists of integers, left and right, both of the same length, representing a directed graph. left[i] is the index of node i's left child and right[i] is the index of node i's right child. A null child is represented by -1. Return whether left and right represent a binary tree.

Constraints
n ≤ 100,000 where n is the length of left and right
Example 1
Input

left = [1, -1, 3, -1]
right = [2, -1, -1, -1]

Output
True

Example 2
Input

left = [0]
right = [0]

Output
False
Explanation
Node 0 is its own left and right child, so the graph contains a cycle (a self-loop).

Hints:
Is a tree a DAG (Directed Acyclic Graph)?
A tree is connected: it is not split into disjoint pieces.
The trick here is to find the root of the tree!
A tree is a DAG with exactly one source (the root) and no node with in-degree greater than 1.
What about topological sorting?
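Following the topological sorting hint, here is a minimal sketch (the function name is_binary_tree_topo is hypothetical, and treating an empty input as "not a tree" is an assumption). It rejects any node with in-degree greater than 1, requires exactly one source, and then runs Kahn's algorithm: if every node gets removed, there is no cycle, and with a single root and at most one parent per node, every node is reachable from that root.

```python
from collections import deque
from typing import List


def is_binary_tree_topo(left: List[int], right: List[int]) -> bool:
    """Hypothetical helper: tree check via Kahn's topological sort."""
    n = len(left)
    if n == 0:
        return False  # assumption: an empty input is not treated as a tree
    indegree = [0] * n
    for i in range(n):
        for child in (left[i], right[i]):
            if child != -1:
                indegree[child] += 1
                if indegree[child] > 1:
                    return False  # a node with two parents
    sources = [v for v in range(n) if indegree[v] == 0]
    if len(sources) != 1:
        return False  # a tree has exactly one root
    # Kahn's algorithm: repeatedly remove nodes whose in-degree has dropped to 0.
    queue = deque(sources)
    processed = 0
    while queue:
        node = queue.popleft()
        processed += 1
        for child in (left[node], right[node]):
            if child != -1:
                indegree[child] -= 1
                if indegree[child] == 0:
                    queue.append(child)
    # Nodes that were never removed lie on a cycle.
    return processed == n
```

For example, is_binary_tree_topo([1, -1, 3, -1], [2, -1, -1, -1]) accepts Example 1, and is_binary_tree_topo([0], [0]) rejects Example 2.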

Tree Detection Algorithm via Disjoint Set + Union Find Algorithm (Determine a Binary Tree)

A Disjoint Set is a data structure that is handy for checking whether an undirected graph has a cycle and for counting its connected components. The following DisjointSet class compresses paths during find, and it merges the smaller group into the bigger one (union by size) to keep paths short. Together, these make each operation amortized O(α(n)), where α is the inverse Ackermann function, which is effectively constant in practice.

The Disjoint Set data structure is the foundation of the Union Find algorithm, which can be used to solve many interesting graph problems.

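class DisjointSet(object):
    def __init__(self, n):
        self.par = list(range(n))  # par[i] is the parent of i; a root is its own parent
        self.count = n             # number of disjoint groups (connected components)
        self.sz = [1] * n          # sz[r] is the size of the group rooted at r

    def union(self, a, b):
        # Merge the groups of a and b; return False if they were already in the same group
        pa = self.find(a)
        pb = self.find(b)
        if pa == pb:
            return False
        # Union by size: attach the smaller group under the bigger one
        if self.sz[pa] < self.sz[pb]:
            pa, pb = pb, pa
        self.par[pb] = pa
        self.sz[pa] += self.sz[pb]
        self.count -= 1
        return True

    def find(self, a):
        # Path compression: point every node on the path directly at the root
        if self.par[a] != a:
            self.par[a] = self.find(self.par[a])
        return self.par[a]

    @property
    def size(self):
        # Number of groups, not the number of elements
        return self.count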
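To see the two uses of a Disjoint Set in isolation (detecting a cycle in an undirected graph and counting connected components), here is a compact, function-based sketch. The helper count_components_and_cycles is hypothetical, and it swaps the recursive path compression of the class above for iterative path halving, which gives the same amortized bound without recursion.

```python
from typing import List, Tuple


def find(parent: List[int], x: int) -> int:
    # Iterative path halving: point every other node on the path to its grandparent.
    while parent[x] != x:
        parent[x] = parent[parent[x]]
        x = parent[x]
    return x


def count_components_and_cycles(n: int, edges: List[Tuple[int, int]]) -> Tuple[int, bool]:
    """Hypothetical helper: returns (number of components, whether an undirected cycle exists)."""
    parent = list(range(n))
    size = [1] * n
    components = n
    has_cycle = False
    for a, b in edges:
        ra, rb = find(parent, a), find(parent, b)
        if ra == rb:
            has_cycle = True  # a and b were already connected, so this edge closes a cycle
            continue
        if size[ra] < size[rb]:
            ra, rb = rb, ra  # union by size: attach the smaller group under the bigger one
        parent[rb] = ra
        size[ra] += size[rb]
        components -= 1
    return components, has_cycle


print(count_components_and_cycles(4, [(0, 1), (1, 2)]))          # (2, False)
print(count_components_and_cycles(3, [(0, 1), (1, 2), (2, 0)]))  # (1, True)
```

Node 3 in the first call is never joined to anything, which is why two components remain.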

Since union ignores edge direction, the Disjoint Set can catch two violations: the graph not being connected in one piece, and the graph having a cycle. However, it cannot detect a node that has more than one parent. We count the in-degree of each node (stored in a dictionary) to rule this out.
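A small counterexample shows why the in-degree check is needed. In the sketch below (union_find_only and max_indegree are hypothetical helpers; union by size is omitted for brevity), nodes 0 and 2 both point at node 1. Ignoring directions, the shape is the path 0-1-2, which is connected and acyclic, so a union-find-only check accepts it even though node 1 has two parents.

```python
from collections import Counter
from typing import List


def union_find_only(left: List[int], right: List[int]) -> bool:
    """Connectivity + undirected acyclicity only, with no in-degree check."""
    n = len(left)
    parent = list(range(n))

    def find(x: int) -> int:
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    components = n
    for i in range(n):
        for child in (left[i], right[i]):
            if child == -1:
                continue
            ri, rc = find(i), find(child)
            if ri == rc:
                return False  # undirected cycle
            parent[rc] = ri
            components -= 1
    return components == 1


def max_indegree(left: List[int], right: List[int]) -> int:
    # Count how many times each node appears as someone's child.
    counts = Counter(c for c in left + right if c != -1)
    return max(counts.values(), default=0)


left = [1, -1, 1]
right = [-1, -1, -1]
print(union_find_only(left, right))  # True: connected and acyclic when directions are ignored
print(max_indegree(left, right))     # 2: node 1 has two parents, so this is not a tree
```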

from collections import defaultdict


class Solution:
    def solve(self, left, right):
        n = len(left)
        ds = DisjointSet(n)
        ind = defaultdict(int)  # ind[v] = number of parents (in-degree) of node v
        for i in range(n):
            # no cycles: joining a child that is already connected to i (ignoring direction) closes a cycle
            if left[i] != -1 and not ds.union(left[i], i):
                return False
            if right[i] != -1 and not ds.union(right[i], i):
                return False
            # in-degree check: every node can have at most one parent
            if left[i] != -1:
                ind[left[i]] += 1
                if ind[left[i]] > 1:
                    return False
            if right[i] != -1:
                ind[right[i]] += 1
                if ind[right[i]] > 1:
                    return False
        # connectivity: all n nodes must end up in a single group
        return ds.size == 1
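As a cross-check on the Union Find solution, the "find the root" hint can be sketched directly: mark every node that has a parent, require exactly one unmarked node, then run a BFS from it and make sure all n nodes are reached. This is a different technique from Union Find, and the function name is_binary_tree_bfs is hypothetical.

```python
from collections import deque
from typing import List


def is_binary_tree_bfs(left: List[int], right: List[int]) -> bool:
    """Hypothetical helper: find the unique root, then BFS and check every node is reached once."""
    n = len(left)
    has_parent = [False] * n
    for i in range(n):
        for child in (left[i], right[i]):
            if child != -1:
                if has_parent[child]:
                    return False  # second parent for the same node
                has_parent[child] = True
    roots = [v for v in range(n) if not has_parent[v]]
    if len(roots) != 1:
        return False  # zero roots means a cycle; several roots means more than one piece
    seen = [False] * n
    seen[roots[0]] = True
    queue = deque(roots)
    visited = 1
    while queue:
        node = queue.popleft()
        for child in (left[node], right[node]):
            if child != -1:
                if seen[child]:
                    return False  # reached a node twice
                seen[child] = True
                visited += 1
                queue.append(child)
    # Nodes the BFS never reached sit on a cycle disconnected from the root.
    return visited == n
```

Both approaches run in O(n) time (Union Find up to the near-constant α(n) factor) and should agree on every input.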


–EOF (The Ultimate Computing & Technology Blog) —
