Finding Two Numbers of Given Sum in Binary Search Tree


Given a binary search tree and a target, find out whether there are any two numbers in the BST that sum up to this target.

Given a Binary Search Tree and a target number, return true if there exist two elements in the BST such that their sum is equal to the given target.

Example 1:

Input:

 
    5
   / \
  3   6
 / \   \
2   4   7

Target = 9

Output: True
Example 2:

Input:

    5
   / \
  3   6
 / \   \
2   4   7

Target = 28

Output: False

Depth First Search with Hash Set

We can traverse the tree in O(N) time using Depth First Search (DFS): in the worst case every node is visited exactly once. When we visit a node, we first check whether its complement (the target minus the node's value) is already in the hash set, and only then add the node's value to the set. Because the check happens before the insertion, a node is never paired with itself. The space complexity is also O(N): the set holds at most N values (one per node of the BST), plus the recursion stack, which is as deep as the tree is tall. DFS is usually implemented recursively (shorter code, with the call stack doing the bookkeeping for us).

The C++ implementation of this DFS with Hashset approach is (40ms):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    bool findTarget(TreeNode* root, int k) {
        unordered_set<int> set;
        return find(root, k, set);
    }
    
private:
    bool find(TreeNode* root, int k, unordered_set<int> &set) {
        if (root == NULL) return false;
        if (set.count(k - root->val)) return true;
        set.insert(root->val);
        return find(root->left, k, set) || find(root->right, k, set);
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
// DFS + hash set: findTarget returns true if two distinct nodes of the tree
// have values summing to k. O(N) time; O(N) extra space for the set plus the
// recursion stack (as deep as the tree is tall).
class Solution {
public:
    bool findTarget(TreeNode* root, int k) {
        // Values are stored as long long so that the complement k - val can be
        // computed without signed int overflow (undefined behaviour in C++).
        std::unordered_set<long long> seen;
        return find(root, k, seen);
    }

private:
    // Pre-order walk. A node's complement is looked up BEFORE its own value
    // is recorded, so a node can never be paired with itself.
    bool find(TreeNode* root, int k, std::unordered_set<long long>& seen) {
        if (root == nullptr) return false;
        const long long complement = static_cast<long long>(k) - root->val;
        if (seen.count(complement)) return true;
        seen.insert(root->val);
        return find(root->left, k, seen) || find(root->right, k, seen);
    }
};

The corresponding Java implementation is quite similar (19ms):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public boolean findTarget(TreeNode root, int k) {
        Set<Integer> set = new HashSet<integer>();
        return find(root, k, set);
    }
    
    private boolean find(TreeNode root, int k, Set<Integer> set) {
        if (root == null) return false;
        if (set.contains(k - root.val)) return true;
        set.add(root.val);
        return find(root.left, k, set) || find(root.right, k, set);
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * DFS + hash set: returns true if two distinct nodes of the tree have
     * values summing to k. O(N) time, O(N) extra space.
     */
    public boolean findTarget(TreeNode root, int k) {
        // Long values: k - val computed in int would silently wrap around and
        // could produce a false match for extreme inputs.
        Set<Long> seen = new HashSet<Long>();
        return find(root, k, seen);
    }

    // Pre-order walk. The complement is looked up before the node's own value
    // is added, so a node is never paired with itself.
    private boolean find(TreeNode root, int k, Set<Long> seen) {
        if (root == null) return false;
        // The cast binds to k, so the subtraction is done in long arithmetic
        // and the argument boxes to Long (an Integer would never match).
        if (seen.contains((long) k - root.val)) return true;
        seen.add((long) root.val);
        return find(root.left, k, seen) || find(root.right, k, seen);
    }
}

Breadth First Search with Hash Set

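Similarly, we can implement a BFS (Breadth First Search) using a queue. The tree is traversed level by level, and everything else mirrors the DFS approach above. The only variation is that the set stores the complements (target minus each value), and each newly dequeued value is checked against it, which is equivalent. The time and space complexity are also O(N). BFS is usually implemented iteratively rather than recursively.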

C++ implementation that runs 32ms:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    bool findTarget(TreeNode* root, int k) {
        if (root == NULL) return false;
        unordered_set<int> data;
        queue<treenode *> q;
        q.push(root);
        while (!q.empty()) {
            auto p = q.front();
            q.pop();
            if (data.count(p->val)) return true;
            data.insert(k - p->val);
            if (p->left) q.push(p->left);
            if (p->right) q.push(p->right);
        }
        return false;
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    bool findTarget(TreeNode* root, int k) {
        if (root == NULL) return false;
        unordered_set<int> data;
        queue<treenode *> q;
        q.push(root);
        while (!q.empty()) {
            auto p = q.front();
            q.pop();
            if (data.count(p->val)) return true;
            data.insert(k - p->val);
            if (p->left) q.push(p->left);
            if (p->right) q.push(p->right);
        }
        return false;
    }
};

Converting to Java takes around 15ms:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public boolean findTarget(TreeNode root, int k) {
        if (root == null) return false;
        Queue<TreeNode> q = new LinkedList<TreeNode>();
        HashSet<Integer> data = new HashSet<Integer>();
        q.add(root);
        while(q.size() > 0) {
            TreeNode p = q.poll();
            if (data.contains(p.val)) return true;
            data.add(k - p.val);
            if (p.left != null) q.add(p.left);
            if (p.right != null) q.add(p.right);
        }
        return false;
    }    
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * BFS + hash set: returns true if two distinct nodes of the tree have
     * values summing to k. Nodes are visited level by level via a queue.
     * O(N) time, O(N) extra space.
     */
    public boolean findTarget(TreeNode root, int k) {
        if (root == null) return false;
        Queue<TreeNode> q = new LinkedList<TreeNode>();
        // Complements (k - val) of every node dequeued so far, kept as Long:
        // int subtraction would wrap around and could match a bogus value.
        HashSet<Long> data = new HashSet<Long>();
        q.add(root);
        while (!q.isEmpty()) {
            TreeNode p = q.poll();
            // Box as Long so the lookup can match the stored Long complements.
            // Checking before inserting means a node is never paired with itself.
            if (data.contains((long) p.val)) return true;
            data.add((long) k - p.val);
            if (p.left != null) q.add(p.left);
            if (p.right != null) q.add(p.right);
        }
        return false;
    }
}

Please note that Java's Queue uses poll(), which retrieves and removes the head in a single call. C++'s std::queue needs front() followed by pop().

Converting to Sorted Array with Inorder Traversal

So far, we have not made use of the fact that the tree is a BST. In a BST, every value in a node's left subtree is smaller than the node's value, and every value in its right subtree is larger. An in-order traversal therefore visits the values in ascending order and produces a sorted array.

Then we place two pointers at the two ends of the array and compare the sum of the two values with the target. If the sum is too large, we move the right pointer to the left; if it is too small, we move the left pointer to the right. We stop when we either find the target sum or the pointers meet. The overall time complexity is still O(N): the in-order traversal takes O(N) time to build the sorted array, and the two-pointer scan takes O(N) in the worst case. The array itself also takes O(N) extra space.

The C++ takes around 24ms.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:   
    bool findTarget(TreeNode* root, int k) {
        vector<int> arr;
        dfs(root, arr);
        int i = 0, j = arr.size() - 1;
        while (i < j) {
            if (arr[i] + arr[j] == k) {
                return true;
            }
            if (arr[i] + arr[j] > k) {
                j --;
            } else {
                i ++;
            }
        }
        return false;
    }
    
private:
    void dfs(TreeNode* root, vector<int> &arr) {
        if (root == NULL) return;        
        dfs(root->left, arr);    
        arr.push_back(root->val);
        dfs(root->right, arr);    
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
// In-order traversal + two pointers: findTarget returns true if two distinct
// nodes of the tree have values summing to k. Relies on the BST property so
// that the in-order sequence is sorted. O(N) time, O(N) extra space.
class Solution {
public:
    bool findTarget(TreeNode* root, int k) {
        std::vector<int> arr;
        dfs(root, arr);  // for a valid BST, arr is now in ascending order
        // Explicit cast: for an empty tree j becomes -1 and the loop is skipped,
        // instead of relying on an unsigned-to-int conversion of size() - 1.
        int i = 0;
        int j = static_cast<int>(arr.size()) - 1;
        while (i < j) {
            // Widen before adding: arr[i] + arr[j] in int may overflow (UB).
            const long long sum = static_cast<long long>(arr[i]) + arr[j];
            if (sum == k) return true;
            if (sum > k) {
                --j;  // sum too large: drop the largest candidate
            } else {
                ++i;  // sum too small: drop the smallest candidate
            }
        }
        return false;
    }

private:
    // Appends the tree's values to arr in in-order (left, node, right).
    void dfs(TreeNode* root, std::vector<int>& arr) {
        if (root == nullptr) return;
        dfs(root->left, arr);
        arr.push_back(root->val);
        dfs(root->right, arr);
    }
};

and equivalent Java solution is 15ms.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public boolean findTarget(TreeNode root, int k) {
        ArrayList<Integer> arr = new ArrayList<Integer>();
        dfs(root, arr);
        int i = 0, j = arr.size() - 1;
        while (i < j) {
            if (arr.get(i) + arr.get(j) == k) return true;
            if (arr.get(i) + arr.get(j) > k) {
                j --;
            } else {
                i ++;
            }
        }
        return false;
    }
    
    private void dfs(TreeNode root, ArrayList<Integer> arr) {
        if (root == null) return;
        dfs(root.left, arr);
        arr.add(root.val);
        dfs(root.right, arr);
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * In-order traversal + two pointers: returns true if two distinct nodes of
     * the tree have values summing to k. Relies on the BST property so that
     * the in-order sequence is sorted. O(N) time, O(N) extra space.
     */
    public boolean findTarget(TreeNode root, int k) {
        ArrayList<Integer> arr = new ArrayList<Integer>();
        dfs(root, arr);  // for a valid BST, arr is now in ascending order
        int i = 0, j = arr.size() - 1;
        while (i < j) {
            int lo = arr.get(i), hi = arr.get(j);
            // Sum in long: int addition would wrap and misdirect the pointers.
            long sum = (long) lo + hi;
            if (sum == k) return true;
            if (sum > k) {
                j--;  // sum too large: drop the largest candidate
            } else {
                i++;  // sum too small: drop the smallest candidate
            }
        }
        return false;
    }

    // Appends the tree's values to arr in in-order (left, node, right).
    private void dfs(TreeNode root, ArrayList<Integer> arr) {
        if (root == null) return;
        dfs(root.left, arr);
        arr.add(root.val);
        dfs(root.right, arr);
    }
}

Surprisingly, the Java solutions are slightly faster than the C++ solutions on leetcode online judges.

–EOF (The Ultimate Computing & Technology Blog) —

GD Star Rating
loading...
1144 words
Last Post: You need a Numeric Keypad to Start with Your HHKB!
Next Post: The Javascript Function to Compare Version Number Strings

The Permanent URL is: Finding Two Numbers of Given Sum in Binary Search Tree

Leave a Reply