How to Count Univalue Subtrees in a Binary Tree?


Given a binary tree, count the number of uni-value subtrees. A Uni-value subtree means all nodes of the subtree have the same value.

Example :
Input: root = [5,1,5,5,5,null,5]

              5
             / \
            1   5
           / \   \
          5   5   5

Output: 4

Univalue Sub Binary Trees Algorithm using Depth First Search

To count the number of uni-value subtrees, we can use the depth-first search (DFS) algorithm. At each recursive call we also pass the parent node, so that each call can report back two things at once: whether its own subtree is uni-value, and whether its value matches the parent's. A subtree is uni-value exactly when both of its child subtrees are uni-value and every child has the same value as the subtree's root.

At the terminal calls of the recursion, we treat a NULL node as uni-value. The following C++ implementation runs in O(N) time and O(H) space, where N is the number of nodes and H is the height of the tree (in the worst case, a skewed tree, H becomes N).

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int countUnivalSubtrees(TreeNode* root) {
        dfs(root, NULL);
        return count;
    }
    
    bool dfs(TreeNode* root, TreeNode* parent) {
        if (root == NULL) return true; // null node is uni-value
        // use | instead of || to avoid boolean shortcut optimisation
        if (!dfs(root->left, root) | !dfs(root->right, root)) {
            return false;
        }
        count ++;
        // univalue check: the value with parent
        return parent == NULL || root->val == parent->val;
    }
 
private:
    int count = 0;
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    /**
     * Counts the subtrees of `root` in which every node holds the same value.
     * The counter is reset on every call, so one Solution object can be
     * reused for several trees (previously the result accumulated).
     * @param root root of the tree; nullptr yields 0
     * @return number of uni-value subtrees
     */
    int countUnivalSubtrees(TreeNode* root) {
        count = 0;
        dfs(root, nullptr);
        return count;
    }

    /**
     * Post-order helper. Adds one to `count` for every uni-value subtree at or
     * below `root`, and tells the caller whether this subtree can extend the
     * parent's uni-value subtree.
     * @param root   current node; nullptr is treated as uni-value
     * @param parent parent of `root`, or nullptr when `root` is the tree root
     * @return true if the subtree at `root` is uni-value and its value equals
     *         parent->val (always true for nullptr `root` or nullptr `parent`
     *         once the subtree is uni-value)
     */
    bool dfs(TreeNode* root, TreeNode* parent) {
        if (root == nullptr) return true;  // an empty subtree never breaks uniformity
        // Bitwise | instead of || so both children are always visited;
        // short-circuiting would skip counting inside the right subtree.
        if (!dfs(root->left, root) | !dfs(root->right, root)) {
            return false;
        }
        ++count;  // both children are uni-value and match root->val
        // Report whether this uni-value subtree also matches the parent.
        return parent == nullptr || root->val == parent->val;
    }

private:
    int count = 0;
};

The Java version for counting the uni-value subtrees in a binary tree is similar:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int countUnivalSubtrees(TreeNode root) {
        dfs(root, null);
        return count;
    }
    
    private boolean dfs(TreeNode root, TreeNode parent) {
        if (root == null) {
            return true;
        }
        if ((!dfs(root.left, root)) | (!dfs(root.right, root))) {
            return false;
        }
        
        count ++;
        return parent == null || parent.val == root.val;
    }
    
    private int count = 0;
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Counts the subtrees of {@code root} in which every node holds the same value.
     * The counter is reset on each call so the same instance can be reused
     * (previously results accumulated across calls).
     *
     * @param root root of the tree; {@code null} yields 0
     * @return number of uni-value subtrees
     */
    public int countUnivalSubtrees(TreeNode root) {
        count = 0;
        dfs(root, null);
        return count;
    }

    /**
     * Post-order helper: increments {@code count} for each uni-value subtree at
     * or below {@code root}.
     *
     * @param root   current node; {@code null} is treated as uni-value
     * @param parent parent of {@code root}, or {@code null} for the tree root
     * @return true if the subtree at {@code root} is uni-value and matches
     *         {@code parent.val} (or there is no parent)
     */
    private boolean dfs(TreeNode root, TreeNode parent) {
        if (root == null) {
            return true;
        }
        // Non-short-circuit | so both subtrees are always visited and counted.
        if ((!dfs(root.left, root)) | (!dfs(root.right, root))) {
            return false;
        }

        count ++;
        return parent == null || parent.val == root.val;
    }

    private int count = 0;
}

Note that we use the non-short-circuit bitwise OR (single |) instead of the logical OR (||). With ||, once the left subtree is found not to be uni-value, the right subtree would never be visited, and the uni-value subtrees inside it would not be counted.

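An alternative DFS solution, which is arguably more intuitive, skips NULL children and directly compares each existing child's value with the current node's value, as illustrated in the Java implementation below.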

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int countUnivalSubtrees(TreeNode root) {
        if (root == null) return 0;
        dfs(root);
        return count;
    }
    
    private boolean dfs(TreeNode root) {
        if (root.left == null && root.right == null) {
            count ++;
            return true;
        }
        boolean valid = true;
        if (root.left != null) {
            valid = dfs(root.left) && valid && root.left.val == root.val;
        }
        if (root.right != null) {
            valid = dfs(root.right) && valid && root.right.val == root.val;
        }
        if (valid) count ++;
        return valid;
    }
    
    private int count = 0;
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Counts the subtrees of {@code root} in which every node holds the same value.
     * The counter is reset on each call so the same instance can be reused
     * (previously results accumulated across calls).
     *
     * @param root root of the tree; {@code null} yields 0
     * @return number of uni-value subtrees
     */
    public int countUnivalSubtrees(TreeNode root) {
        count = 0;
        if (root == null) return 0;
        dfs(root);
        return count;
    }

    /**
     * Post-order helper over a non-null node: increments {@code count} for each
     * uni-value subtree at or below {@code root}.
     *
     * @param root current node; must not be {@code null}
     * @return true if every node in the subtree at {@code root} has the same value
     */
    private boolean dfs(TreeNode root) {
        if (root.left == null && root.right == null) {
            count ++;  // a leaf is always uni-value
            return true;
        }
        boolean valid = true;
        // dfs(...) is evaluated first so the child's subtrees are always counted,
        // even when `valid` is already false.
        if (root.left != null) {
            valid = dfs(root.left) && valid && root.left.val == root.val;
        }
        if (root.right != null) {
            valid = dfs(root.right) && valid && root.right.val == root.val;
        }
        if (valid) count ++;
        return valid;
    }

    private int count = 0;
}

And the C++ implementation:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int countUnivalSubtrees(TreeNode* root) {
        if (root == NULL) return 0;        
        int count = 0;
        dfs(root, count);
        return count;
    }
    
    bool dfs(TreeNode* root, int &count) {
        if (root->left == NULL && root->right == NULL) {
            count ++;
            return true;
        }
        bool valid = true;
        if (root->left != NULL) {
            valid = dfs(root->left, count) && valid && root->left->val == root->val;
        }
        if (root->right != NULL) {
            valid = dfs(root->right, count) && valid && root->right->val == root->val;
        }
        if (valid) {
            count ++;
        }
        return valid;
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int countUnivalSubtrees(TreeNode* root) {
        if (root == NULL) return 0;        
        int count = 0;
        dfs(root, count);
        return count;
    }
    
    bool dfs(TreeNode* root, int &count) {
        if (root->left == NULL && root->right == NULL) {
            count ++;
            return true;
        }
        bool valid = true;
        if (root->left != NULL) {
            valid = dfs(root->left, count) && valid && root->left->val == root->val;
        }
        if (root->right != NULL) {
            valid = dfs(root->right, count) && valid && root->right->val == root->val;
        }
        if (valid) {
            count ++;
        }
        return valid;
    }
};

This approach has the same O(N) time and O(H) space complexity as the first approach. All of the algorithms in this post are implemented using recursion.

Count the Uni-value Subtrees using a Helper function

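We can define a recursive function isUniTree that returns true if every value in a (possibly empty) tree equals the given parameter. The counting problem can then be solved easily, although less efficiently: isUniTree may re-scan a node once for each of its ancestors, so this approach takes O(N·H) time, which is O(N^2) in the worst case of a skewed tree.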

The unival subtrees of a tree fall into three disjoint groups: the subtree rooted at the current node itself (if it is unival), the unival subtrees inside the left child, and the unival subtrees inside the right child.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int countUnivalSubtrees(TreeNode* root) {
        if (root == nullptr) return 0;
        int res = 0;
        if (isUniTree(root->left, root->val) &&
           (isUniTree(root->right, root->val))) {
            res ++;
        }
        return countUnivalSubtrees(root->left) +
            countUnivalSubtrees(root->right) + res;
    }
    
private:
    bool isUniTree(TreeNode* root, int val) {
        if (root == nullptr) return true;
        return (root->val == val) &&
            isUniTree(root->left, val) &&
            isUniTree(root->right, val);
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    // Total uni-value subtrees = those inside the left child + those inside
    // the right child + 1 if the subtree rooted at `root` is itself uni-value.
    int countUnivalSubtrees(TreeNode* root) {
        if (root == nullptr) {
            return 0;
        }
        int total = countUnivalSubtrees(root->left) +
                    countUnivalSubtrees(root->right);
        const int v = root->val;
        if (isUniTree(root->left, v) && isUniTree(root->right, v)) {
            ++total;
        }
        return total;
    }

private:
    // Returns whether all nodes in the (possibly empty) tree `root` hold `val`.
    bool isUniTree(TreeNode* root, int val) {
        return root == nullptr ||
               (root->val == val &&
                isUniTree(root->left, val) &&
                isUniTree(root->right, val));
    }
};

–EOF (The Ultimate Computing & Technology Blog) —

GD Star Rating
loading...
857 words
Last Post: The Secure Shell (SSH) Chrome Extension developed by Google
Next Post: SQL Algorithm to Compute Shortest Distance in a Plane

The Permanent URL is: How to Count Univalue Subtrees in a Binary Tree?

Leave a Reply