How to Find the Kth Smallest Element in a BST Tree Using Java/C++?


Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

You may assume k is always valid, 1 ≤ k ≤ BST’s total elements.

Example 1:
Input: root = [3,1,4,null,2], k = 1

   3
  / \
 1   4
  \
   2

Output: 1

Example 2:
Input: root = [5,3,6,2,4,null,null,1], k = 3

       5
      / \
     3   6
    / \
   2   4
  /
 1

Output: 3

In short, your task is to find the k-th smallest number in a Binary Search Tree (BST), where all values in a node’s left subtree are smaller than the node’s value, and all values in its right subtree are larger.

Compute Kth Smallest Element via DFS In-order Traversal (Recursion)

An in-order traversal (left subtree, node, right subtree) of a BST visits the nodes in sorted order, and it is easy to implement with recursion. Let’s use a list to store the numbers during the DFS (Depth First Search); the k-th smallest element is then simply the element at index k − 1.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        vector<int> list;
        dfs(root, list);
        return list[k - 1];
    }
    
private:
    void dfs(TreeNode* root, vector<int> &list) {
        if (root == NULL) return;
        dfs(root->left, list);
        list.push_back(root->val);
        dfs(root->right, list);
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode* root, int k) {
        vector<int> list;
        dfs(root, list);
        return list[k - 1];
    }
    
private:
    void dfs(TreeNode* root, vector<int> &list) {
        if (root == NULL) return;
        dfs(root->left, list);
        list.push_back(root->val);
        dfs(root->right, list);
    }
};

The equivalent Java implementation is:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> list = new ArrayList<>();
        dfs(root, list);
        return list.get(k - 1);
    }
    
    private void dfs(TreeNode root, List<Integer> list) {
        if (root == null) return;
        dfs(root.left, list);
        list.add(root.val);
        dfs(root.right, list);
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Returns the k-th smallest value (1-based) by collecting the in-order
     * sequence of the tree (sorted ascending for a BST) and indexing into it.
     */
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> values = new ArrayList<>();
        collectInOrder(root, values);
        return values.get(k - 1);
    }

    /** Appends the subtree's values to {@code out} in left, node, right order. */
    private void collectInOrder(TreeNode node, List<Integer> out) {
        if (node != null) {
            collectInOrder(node.left, out);
            out.add(node.val);
            collectInOrder(node.right, out);
        }
    }
}

The runtime complexity is O(N), as each node is visited exactly once. The space complexity is also O(N): the list holds all N values, and the recursion stack adds O(H), where H is the height of the tree.

Kth Smallest Element via Iterative DFS In-order traversal

We can manage the stack ourselves, which leads to an iterative approach. We also don’t need to store or copy the elements as the traversal visits each node, because we are only interested in the k-th smallest one. Instead, we decrement k at each visited node, and the node at which k reaches zero is the answer.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
    public:
        int kthSmallest(TreeNode *root, int k) {
            stack<TreeNode*> stack;
            while (root || !stack.empty()) {
                while (root) {
                    stack.push(root);
                    root = root->left;
                }
                root = stack.top();
                stack.pop();
                if (--k == 0) return root->val;
                root = root->right;
            }
        }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
    public:
        int kthSmallest(TreeNode *root, int k) {
            stack<TreeNode*> stack;
            while (root || !stack.empty()) {
                while (root) {
                    stack.push(root);
                    root = root->left;
                }
                root = stack.top();
                stack.pop();
                if (--k == 0) return root->val;
                root = root->right;
            }
        }
};

The corresponding Java implementation is similar as follows:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        Stack<TreeNode> stack = new Stack<>();
        while (root != null || !stack.isEmpty()) {
            while (root != null) {
                stack.push(root);
                root = root.left;
            }
            root = stack.pop();
            if (--k == 0) return root.val;
            root = root.right;
        }
        return -1;
    }    
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Returns the k-th smallest value (1-based) via an iterative in-order
     * traversal that stops at the k-th visited node, or -1 if k is outside
     * [1, number of nodes].
     */
    public int kthSmallest(TreeNode root, int k) {
        // ArrayDeque is the JDK-recommended LIFO stack; java.util.Stack is a
        // synchronized legacy class. Fully qualified so no extra import is needed.
        java.util.Deque<TreeNode> pending = new java.util.ArrayDeque<>();
        TreeNode node = root;
        while (node != null || !pending.isEmpty()) {
            // Descend to the leftmost unvisited node, remembering the path.
            while (node != null) {
                pending.push(node);
                node = node.left;
            }
            node = pending.pop();
            if (--k == 0) return node.val;  // this is the k-th visited node
            node = node.right;
        }
        return -1;
    }
}

The stack holds at most H nodes, where H is the height of the tree, so the space complexity is O(H). In the worst case the BST degenerates into a linked list and H = N, giving O(N) space. Likewise, the time complexity is O(H + k). That is O(N) in the worst case and O(log N + k) when the binary search tree is balanced.

Binary Search using Recursion

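Let’s define a count function that returns the number of nodes in a (sub)tree. Since the in-order sequence of a BST is sorted, we can binary-search for the k-th smallest number by looking at the size of the left subtree:
- If the left subtree has exactly k − 1 nodes, the current node is the answer.
- If it has k or more nodes, the answer is in the left subtree.
- Otherwise, the answer is the (k − 1 − count(left))-th smallest element of the right subtree.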

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode *root, int k) {
        int cnt = count(root->left);
        if (cnt + 1 == k) return root->val;
        // on the left sub tree
        if (k <= cnt) {
            return kthSmallest(root->left, k);
        }
        // on the left right tree
        return kthSmallest(root->right, k - 1 - cnt);
    }
    
private:
    // count the number of nodes in the tree
    int count(TreeNode *root) {
        if (root == NULL) return 0;
        return count(root->left) + count(root->right) + 1;
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode *root, int k) {
        int cnt = count(root->left);
        if (cnt + 1 == k) return root->val;
        // on the left sub tree
        if (k <= cnt) {
            return kthSmallest(root->left, k);
        }
        // on the left right tree
        return kthSmallest(root->right, k - 1 - cnt);
    }
    
private:
    // count the number of nodes in the tree
    int count(TreeNode *root) {
        if (root == NULL) return 0;
        return count(root->left) + count(root->right) + 1;
    }
};

The corresponding Java solution is:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        int cnt = count(root.left);
        if (cnt + 1 == k) return root.val;
        // on the left sub tree
        if (k <= cnt) return kthSmallest(root.left, k);
        // on the right sub tree
        return kthSmallest(root.right, k - cnt - 1);
    }
    
    // count the number of nodes in the tree
    private int count(TreeNode root) {
        if (root == null) return 0;
        return 1 + count(root.left) + count(root.right);
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Returns the k-th smallest value (1-based) by descending the tree and
     * using subtree sizes to pick the left subtree, the current node, or the
     * right subtree. Returns -1 if k is outside [1, number of nodes]
     * (including an empty tree) instead of throwing a NullPointerException.
     * count() is recomputed at every level: O(N) on a balanced tree, but
     * O(N^2) in the worst case (a degenerate, list-shaped tree).
     */
    public int kthSmallest(TreeNode root, int k) {
        TreeNode node = root;
        while (node != null) {
            int leftSize = count(node.left);
            if (k == leftSize + 1) return node.val;
            if (k <= leftSize) {
                // on the left sub tree
                node = node.left;
            } else {
                // on the right sub tree: skip the left subtree and this node
                k -= leftSize + 1;
                node = node.right;
            }
        }
        return -1;
    }

    /** Returns the number of nodes in the subtree rooted at {@code root} (0 if null). */
    private int count(TreeNode root) {
        if (root == null) return 0;
        return 1 + count(root.left) + count(root.right);
    }
}

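Because count() is recomputed at every level, the runtime complexity is O(N^2) in the worst case. For example, if the BST degenerates into a linked list with no right subtrees and k = 1, the counts at successive levels cost (N − 1) + (N − 2) + … + 1. On a balanced BST the counted subtrees shrink geometrically (N/2 + N/4 + …), so the runtime is O(N), where N is the total number of nodes. If every node cached the size of its subtree, each query would take only O(H), i.e. O(log N) on a balanced tree.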

Using Priority Queue

We can traverse the BST in any order and push every value into a priority queue. Then we pop k − 1 elements from the queue, and the next one is the k-th smallest element. Make sure the pop order is from smallest to largest, because in C++ std::priority_queue pops the largest element by default.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode *root, int k) {
        priority_queue<int, vector<int>, std::greater<int>> q;
        dfs(root, q);
        for (int i = 0; i < k - 1; ++ i) {
            q.pop();
        }
        return q.top();
    }
    
private:
    void dfs(TreeNode *root, priority_queue<int, vector<int>, std::greater<int>> &q) {
        if (root == NULL) return;
        q.push(root->val);
        dfs(root->left, q);
        dfs(root->right, q);
    }
};
/**
 * Definition for a binary tree node.
 * struct TreeNode {
 *     int val;
 *     TreeNode *left;
 *     TreeNode *right;
 *     TreeNode(int x) : val(x), left(NULL), right(NULL) {}
 * };
 */
class Solution {
public:
    int kthSmallest(TreeNode *root, int k) {
        priority_queue<int, vector<int>, std::greater<int>> q;
        dfs(root, q);
        for (int i = 0; i < k - 1; ++ i) {
            q.pop();
        }
        return q.top();
    }
    
private:
    void dfs(TreeNode *root, priority_queue<int, vector<int>, std::greater<int>> &q) {
        if (root == NULL) return;
        q.push(root->val);
        dfs(root->left, q);
        dfs(root->right, q);
    }
};

Passing std::greater<int> as the comparator reverses the priority queue’s order, turning it into a min-heap. You might want to introduce a typedef (or a using alias), since the fully spelled-out priority queue type with its template arguments is quite verbose.

In Java, it is a bit similar:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    public int kthSmallest(TreeNode root, int k) {
        PriorityQueue<Integer> q = new PriorityQueue<>();
        dfs(root, q);
        for (int i = 0; i < k - 1; ++ i) {
            q.poll();
        }
        return q.peek();
    }
    
    private void dfs(TreeNode root, PriorityQueue<Integer> q) {
        if (root == null) return;
        dfs(root.right, q);
        q.add(root.val);
        dfs(root.left, q);
    }
}
/**
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode(int x) { val = x; }
 * }
 */
class Solution {
    /**
     * Returns the k-th smallest value (1-based): all values are placed in a
     * min-heap (natural ordering), the k-1 smallest are polled away, and the
     * remaining minimum is returned.
     */
    public int kthSmallest(TreeNode root, int k) {
        PriorityQueue<Integer> heap = new PriorityQueue<>();
        pushAll(root, heap);
        for (int discarded = 1; discarded < k; discarded++) {
            heap.poll();
        }
        return heap.peek();
    }

    /** Adds every value of the subtree to {@code heap} (order is irrelevant to the heap). */
    private void pushAll(TreeNode node, PriorityQueue<Integer> heap) {
        if (node == null) return;
        pushAll(node.right, heap);
        heap.add(node.val);
        pushAll(node.left, heap);
    }
}

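The time complexity is O(N log N). Every node is visited exactly once, but each of the N insertions into the heap costs O(log N), as does each of the k − 1 removals. The space complexity is O(N) for the heap. (Keeping a max-heap capped at k elements would reduce this to O(N log k) time and O(k) space.) See also: Teaching Kids Programming – Kth Smallest Element in a BST via Iterative Inorder Traversal Algorithm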

k-th Smallest Element in the Binary Search Tree

–EOF (The Ultimate Computing & Technology Blog) —

GD Star Rating
loading...
1566 words
Last Post: Fixing Profile Query Command due to API Change in Steem Blockchain
Next Post: How to Mirror a Binary Tree?

The Permanent URL is: How to Find the Kth Smallest Element in a BST Tree Using Java/C++?

Leave a Reply