The Union Find (Disjoint Set) Implementation in Java/C++


2091E060-120E-4C44-BAA3-7E4E0DF7BD55 The Union Find (Disjoint Set) Implementation in Java/C++ algorithms c / c++ implementation java

Java

Union-Find (Disjoint Set) is a commonly used data structure that helps solve problems such as finding a Minimum Spanning Tree (e.g. in Kruskal's algorithm). The following is a Java implementation of a Union-Find class.

disjoint1 The Union Find (Disjoint Set) Implementation in Java/C++ algorithms c / c++ implementation java

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package com.helloacm;
 
/**
 * Disjoint-set (union-find) structure over the integers {@code 0 .. n-1}.
 *
 * Every element starts in its own group. {@link #Union(int, int)} merges two
 * groups and {@link #Find(int)} returns the representative (root) of the
 * group an element belongs to.
 */
public class UnionFind {
    /** parent[i] is the parent of element i; a root is its own parent. */
    private final int[] parent;

    /**
     * Creates {@code n} singleton groups, one per element.
     *
     * @param n number of elements
     */
    public UnionFind(int n) {
        parent = new int[n];
        for (var i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    /**
     * Returns the root of the group containing {@code x}.
     *
     * @param x element index in {@code [0, n)}
     * @return the root element of x's group
     */
    public int Find(int x) {
        if (x == parent[x]) {
            return x;
        }
        // Path compression: point x directly at the root on the way back.
        return parent[x] = Find(parent[x]);
    }

    /**
     * Merges the groups containing {@code x} and {@code y}. If they are
     * distinct, the root of x's group is attached under the root of y's group.
     */
    public void Union(int x, int y) {
        var px = Find(x);
        var py = Find(y);
        if (px != py) {
            parent[px] = py;
        }
    }

    /**
     * Counts the groups by counting the roots (elements that are their own parent).
     *
     * @return number of disjoint groups
     */
    public int size() {
        int ans = 0;
        // Arrays expose their size through the `length` field, not a method.
        for (int i = 0; i < parent.length; ++i) {
            if (i == parent[i]) {
                ans++;
            }
        }
        return ans;
    }
}
package com.helloacm;

/**
 * Disjoint-set (union-find) structure over the integers {@code 0 .. n-1}.
 *
 * Every element starts in its own group. {@link #Union(int, int)} merges two
 * groups and {@link #Find(int)} returns the representative (root) of the
 * group an element belongs to.
 */
public class UnionFind {
    /** parent[i] is the parent of element i; a root is its own parent. */
    private final int[] parent;

    /**
     * Creates {@code n} singleton groups, one per element.
     *
     * @param n number of elements
     */
    public UnionFind(int n) {
        parent = new int[n];
        for (var i = 0; i < n; i++) {
            parent[i] = i;
        }
    }

    /**
     * Returns the root of the group containing {@code x}.
     *
     * @param x element index in {@code [0, n)}
     * @return the root element of x's group
     */
    public int Find(int x) {
        if (x == parent[x]) {
            return x;
        }
        // Path compression: point x directly at the root on the way back.
        return parent[x] = Find(parent[x]);
    }

    /**
     * Merges the groups containing {@code x} and {@code y}. If they are
     * distinct, the root of x's group is attached under the root of y's group.
     */
    public void Union(int x, int y) {
        var px = Find(x);
        var py = Find(y);
        if (px != py) {
            parent[px] = py;
        }
    }

    /**
     * Counts the groups by counting the roots (elements that are their own parent).
     *
     * @return number of disjoint groups
     */
    public int size() {
        int ans = 0;
        // Arrays expose their size through the `length` field, not a method.
        for (int i = 0; i < parent.length; ++i) {
            if (i == parent[i]) {
                ans++;
            }
        }
        return ans;
    }
}

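The above implementation uses O(N) space. The constructor and size() each take O(N) time, while Find and Union are fast in practice thanks to path compression. Example usage: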

1
2
3
4
5
6
7
8
9
10
package com.helloacm;
 
public class Main {
    public static void main(String[] args) {
        var uf = new UnionFind(5);
        System.out.println(uf.Find(3));
        uf.Union(3, 4);
        System.out.println(uf.Find(3)); // after join, 3's parent is 4.
    }
}
package com.helloacm;

public class Main {
    /**
     * Small demo: prints the representative of element 3 before and after
     * merging it with element 4.
     */
    public static void main(String[] args) {
        UnionFind sets = new UnionFind(5);

        int rootBefore = sets.Find(3);
        System.out.println(rootBefore);

        sets.Union(3, 4);

        // Once 3 and 4 are joined, 4 is the root of 3's group.
        int rootAfter = sets.Find(3);
        System.out.println(rootAfter);
    }
}

This Java code prints 3 and 4.

C++ Disjoint Set / Union Find Algorithm Implementation

Similarly, here is the C++ implementation of the Disjoint Set data structure. Because union is a keyword in C++, the method is named Union instead:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
// Disjoint-set (union-find) over the integers 0 .. N-1.
// Each element starts in its own group; Union merges groups and Find
// returns the representative (root) of an element's group.
class UF {
    public:
        // Creates N singleton groups: G = {0, 1, 2, ...}.
        UF(int N) : G(N) {
            std::iota(G.begin(), G.end(), 0);
        }
    
        // Returns the root of the group containing x.
        int Find(int x) {
            if (x == G[x]) {
                return x;
            }
            // Path compression: point x directly at the root on the way back.
            return G[x] = Find(G[x]);
        }
    
        // Merges the groups of x and y; if they differ, the root of x's
        // group is attached under the root of y's group.
        void Union(int x, int y) {
            int px = Find(x);
            int py = Find(y);
            if (px != py) {
                G[px] = py;
            }
        }

        // Returns the number of groups, i.e. the number of roots
        // (elements that are their own parent).
        int size() const {
            int ans = 0;
            // Convert once so the loop compares int with int.
            const int n = static_cast<int>(G.size());
            for (int i = 0; i < n; ++i) {
                if (i == G[i]) ++ans;
            }
            return ans;
        }
    private:
        // G[i] is the parent of element i; roots satisfy G[i] == i.
        std::vector<int> G;
};
// Disjoint-set (union-find) over the integers 0 .. N-1.
// Each element starts in its own group; Union merges groups and Find
// returns the representative (root) of an element's group.
class UF {
    public:
        // Creates N singleton groups: G = {0, 1, 2, ...}.
        UF(int N) : G(N) {
            std::iota(G.begin(), G.end(), 0);
        }
    
        // Returns the root of the group containing x.
        int Find(int x) {
            if (x == G[x]) {
                return x;
            }
            // Path compression: point x directly at the root on the way back.
            return G[x] = Find(G[x]);
        }
    
        // Merges the groups of x and y; if they differ, the root of x's
        // group is attached under the root of y's group.
        void Union(int x, int y) {
            int px = Find(x);
            int py = Find(y);
            if (px != py) {
                G[px] = py;
            }
        }

        // Returns the number of groups, i.e. the number of roots
        // (elements that are their own parent).
        int size() const {
            int ans = 0;
            // Convert once so the loop compares int with int.
            const int n = static_cast<int>(G.size());
            for (int i = 0; i < n; ++i) {
                if (i == G[i]) ++ans;
            }
            return ans;
        }
    private:
        // G[i] is the parent of element i; roots satisfy G[i] == i.
        std::vector<int> G;
};

Here, we use std::iota (declared in the <numeric> header) to assign incrementing values to the initial group vector:

1
2
// Fill G with consecutive integers starting at 0, so that every element
// initially points at itself:
// G = {0, 1, 2, ...};
std::iota(begin(G), end(G), 0);
// Fill G with consecutive integers starting at 0, so that every element
// initially points at itself:
// G = {0, 1, 2, ...};
std::iota(begin(G), end(G), 0);

Compress Paths and Union Rules for Disjoint Set

As shown above - when in Find - we can compress the paths. Also, in the Union, we can either set G[px] = py or G[py] = px.

Choose a smaller group ID

This would be easiest - we compare the px and py value before setting the group:

1
2
3
4
5
6
7
8
void Union(int x, int y) {
  int px = Find(x);
  int py = Find(y);
  if (px != py) {
    if (px < py) swap(px, py); // make py smaller
    G[px] = py;    
  }            
} 
// Merges the groups of x and y, keeping the smaller root ID as the
// representative of the combined group.
void Union(int x, int y) {
  const int rootX = Find(x);
  const int rootY = Find(y);
  if (rootX == rootY) {
    return;  // already in the same group
  }
  // Point the larger root ID at the smaller one.
  if (rootX < rootY) {
    G[rootY] = rootX;
  } else {
    G[rootX] = rootY;
  }
}

Merging by Size

Alternatively, we can allocate an additional array to store the size of each group and always merge by size. The smaller group should be attached under the root of the larger one, which keeps the trees shallow:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
// Disjoint-set (union-find) over the integers 0 .. N-1 with path
// compression in Find and union by size in Union.
class UF {
    public:
        // Creates N singleton groups, each of size 1.
        UF(int N) : G(N), sizes(N, 1) {
            std::iota(G.begin(), G.end(), 0);
        }
    
        // Returns the root of the group containing x.
        int Find(int x) {
            if (x == G[x]) {
                return x;
            }
            // Path compression: point x directly at the root on the way back.
            return G[x] = Find(G[x]);
        }
    
        // Merges the groups of x and y. The smaller group is attached under
        // the root of the larger one so the trees stay shallow; on a tie,
        // x's root is attached under y's root.
        void Union(int x, int y) {
            int px = Find(x);
            int py = Find(y);
            if (px != py) {
                // Make px the root of the smaller group.
                if (sizes[px] > sizes[py]) std::swap(px, py);
                G[px] = py;
                sizes[py] += sizes[px];
            }
        }

        // Returns the number of groups, i.e. the number of roots
        // (elements that are their own parent).
        int size() const {
            int ans = 0;
            // Convert once so the loop compares int with int.
            const int n = static_cast<int>(G.size());
            for (int i = 0; i < n; ++i) {
                if (i == G[i]) ++ans;
            }
            return ans;
        }
    private:
        // G[i] is the parent of element i; roots satisfy G[i] == i.
        std::vector<int> G;
        // sizes[r] is the number of elements in the group rooted at r
        // (only kept up to date for roots).
        std::vector<int> sizes;
};
// Disjoint-set (union-find) over the integers 0 .. N-1 with path
// compression in Find and union by size in Union.
class UF {
    public:
        // Creates N singleton groups, each of size 1.
        UF(int N) : G(N), sizes(N, 1) {
            std::iota(G.begin(), G.end(), 0);
        }
    
        // Returns the root of the group containing x.
        int Find(int x) {
            if (x == G[x]) {
                return x;
            }
            // Path compression: point x directly at the root on the way back.
            return G[x] = Find(G[x]);
        }
    
        // Merges the groups of x and y. The smaller group is attached under
        // the root of the larger one so the trees stay shallow; on a tie,
        // x's root is attached under y's root.
        void Union(int x, int y) {
            int px = Find(x);
            int py = Find(y);
            if (px != py) {
                // Make px the root of the smaller group.
                if (sizes[px] > sizes[py]) std::swap(px, py);
                G[px] = py;
                sizes[py] += sizes[px];
            }
        }

        // Returns the number of groups, i.e. the number of roots
        // (elements that are their own parent).
        int size() const {
            int ans = 0;
            // Convert once so the loop compares int with int.
            const int n = static_cast<int>(G.size());
            for (int i = 0; i < n; ++i) {
                if (i == G[i]) ++ans;
            }
            return ans;
        }
    private:
        // G[i] is the parent of element i; roots satisfy G[i] == i.
        std::vector<int> G;
        // sizes[r] is the number of elements in the group rooted at r
        // (only kept up to date for roots).
        std::vector<int> sizes;
};

One complete C++ Disjoint Set Class that compresses the Path in Find and merges the branch with smaller rank into a larger one:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
// Disjoint-set (union-find) over the integers 0 .. n-1 with path
// compression in find() and union by rank in join().
class DSU {
public:
    // Creates n singleton groups, each with rank 0.
    DSU(int n) : data(n), rank(n, 0) {
        std::iota(data.begin(), data.end(), 0);
    }
    
    // Returns the root of the group containing x, compressing the path.
    int find(int x) {
        if (x != data[x]) {
            data[x] = find(data[x]);
        }
        return data[x];
    }
    
    // Merges the groups of x and y.
    // Returns false if they were already in the same group, true otherwise.
    bool join(int x, int y) {
        int px = find(x);
        int py = find(y);
        if (px == py) return false;
        if (rank[px] > rank[py]) {
            // The lower-rank root goes under the higher-rank root.
            data[py] = px;
        } else if (rank[py] > rank[px]) {
            data[px] = py;
        } else {
            // Equal ranks: attach px under py ...
            data[px] = py;
            // ... and bump the rank of the surviving root.
            ++rank[py];
        }
        return true;
    }
    
    // Returns the number of groups, i.e. the number of roots
    // (elements that are their own parent).
    int size() const {
        int ans = 0;
        // Convert once so the loop compares int with int.
        const int n = static_cast<int>(data.size());
        for (int i = 0; i < n; ++i) {
            if (i == data[i]) ++ans;
        }
        return ans;
    }
private:
    // data[i] is the parent of element i; roots satisfy data[i] == i.
    std::vector<int> data;
    // rank[r] is the rank kept for root r; it only grows when two roots of
    // equal rank are joined.
    std::vector<int> rank;
};
// Disjoint-set (union-find) over the integers 0 .. n-1 with path
// compression in find() and union by rank in join().
class DSU {
public:
    // Creates n singleton groups, each with rank 0.
    DSU(int n) : data(n), rank(n, 0) {
        std::iota(data.begin(), data.end(), 0);
    }
    
    // Returns the root of the group containing x, compressing the path.
    int find(int x) {
        if (x != data[x]) {
            data[x] = find(data[x]);
        }
        return data[x];
    }
    
    // Merges the groups of x and y.
    // Returns false if they were already in the same group, true otherwise.
    bool join(int x, int y) {
        int px = find(x);
        int py = find(y);
        if (px == py) return false;
        if (rank[px] > rank[py]) {
            // The lower-rank root goes under the higher-rank root.
            data[py] = px;
        } else if (rank[py] > rank[px]) {
            data[px] = py;
        } else {
            // Equal ranks: attach px under py ...
            data[px] = py;
            // ... and bump the rank of the surviving root.
            ++rank[py];
        }
        return true;
    }
    
    // Returns the number of groups, i.e. the number of roots
    // (elements that are their own parent).
    int size() const {
        int ans = 0;
        // Convert once so the loop compares int with int.
        const int n = static_cast<int>(data.size());
        for (int i = 0; i < n; ++i) {
            if (i == data[i]) ++ans;
        }
        return ans;
    }
private:
    // data[i] is the parent of element i; roots satisfy data[i] == i.
    std::vector<int> data;
    // rank[r] is the rank kept for root r; it only grows when two roots of
    // equal rank are joined.
    std::vector<int> rank;
};

Number of Connected Groups for a Disjoint Set

To get the number of connected groups, we go through the parent/group IDs and count the entries whose ID equals their own index (i.e. the roots, which still hold their initial value):

1
2
3
4
5
6
7
    int size() {
        int ans = 0;
        for (int i = 0; i < data.size(); ++ i) {
            if (i == data[i]) ans ++;
        }
        return ans;
    }  
    // Returns the number of groups, i.e. the number of roots
    // (indices i for which data[i] == i).
    int size() {
        int ans = 0;
        // Convert once so the loop compares int with int rather than
        // int with the container's unsigned size type.
        const int n = static_cast<int>(data.size());
        for (int i = 0; i < n; ++i) {
            if (i == data[i]) ++ans;
        }
        return ans;
    }  

Alternatively, we can use a variable to count the number of connected groups. At the beginning it is set to N, and every time we join two distinct components we decrement it. It will look something like this:

1
2
3
4
5
6
7
8
9
10
11
// Running count of disjoint groups; starts at n and drops by one per successful join.
// NOTE(review): this declaration is schematic — presumably a private member
// initialised from the constructor's n; confirm when integrating into the class.
private int numberOfGroups = n;
// Joins the groups of x and y.
// Returns false when both already share a root, true after a merge.
bool join(int x, int y) {
  int px = find(x);
  int py = find(y);
  // Same root: nothing to merge, so the counter is left untouched.
  if (px == py) return false;
  // Two distinct groups are about to become one.
  numberOfGroups --;
  // now merge x and y
  // ....
  // ... see above code
  return true;
}
// Running count of disjoint groups; starts at n and drops by one per successful join.
// NOTE(review): this declaration is schematic — presumably a private member
// initialised from the constructor's n; confirm when integrating into the class.
private int numberOfGroups = n;
// Joins the groups of x and y.
// Returns false when both already share a root, true after a merge.
bool join(int x, int y) {
  int px = find(x);
  int py = find(y);
  // Same root: nothing to merge, so the counter is left untouched.
  if (px == py) return false;
  // Two distinct groups are about to become one.
  numberOfGroups --;
  // now merge x and y
  // ....
  // ... see above code
  return true;
}

For the Python Class of Disjoint Set: Algorithm to Remove a Redundant Connection from a Undirected Graph to Make a Valid Tree using Union-Find (Disjoint Set)

--EOF (The Ultimate Computing & Technology Blog) --

GD Star Rating
loading...
1042 words
Last Post: How to use the Leetcode's Mock Interview Overview to Nail Your Interview?
Next Post: How to Remove Duplicate Elements from Vector in C++ using std::unique?

The Permanent URL is: The Union Find (Disjoint Set) Implementation in Java/C++

Leave a Reply