C++ Implementation of Segment Tree


Segment Tree is one of the most important data structures in Computer Science. Like a Binary Indexed Tree (Fenwick tree), a Segment Tree lets us update a single element and query a range (for example, a range sum), each in O(logN) time. A Segment Tree also needs only O(N) time to build. Building a Binary Indexed Tree by inserting the N elements one at a time takes O(NlogN).

A Segment Tree is essentially a balanced binary tree. Each node covers a range, and its two children cover the two halves of that range. The height of a Segment Tree for N elements is O(logN).
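To make the O(logN) height concrete, here is a small sketch that follows the same split rule as the buildTree function later in this post (mid = from + (to - from) / 2) without allocating any nodes. The name treeHeight is hypothetical and is not part of the original code. For N elements it returns ceil(log2 N), for example 3 for N = 8 and 10 for N = 1000.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: height (number of edges on the longest root-to-leaf
// path) of the segment tree over [from, to], using buildTree's midpoint split.
int treeHeight(int from, int to) {
    assert(from <= to);
    if (from == to) {
        return 0;  // a single element is a leaf
    }
    int mid = from + (to - from) / 2;
    // The left half [from, mid] is never smaller than the right half,
    // but take the max so the function does not depend on that detail.
    return 1 + std::max(treeHeight(from, mid), treeHeight(mid + 1, to));
}
```

Because each level halves the range, the height grows by one only when N doubles.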

Definition of a Segment Tree in C++

We can use a struct (or a class) in C++ to define a segment tree node. The tree is essentially a binary tree in which each node has left and right pointers to its subtrees. Each node also stores from and to (start/finish) fields that define the range it covers. A node can be customized to record the sum over its range, or to store the maximum/minimum, etc.

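struct SegmentTree {  
    int from;            // start index of the covered range (inclusive)
    int to;              // end index of the covered range (inclusive)
    int sum;             // aggregate value over [from, to]; here, the sum
    SegmentTree *left;   // child covering [from, mid], or nullptr for a leaf
    SegmentTree *right;  // child covering [mid + 1, to], or nullptr for a leaf
    // optional constructor (members are initialized in declaration order)
    SegmentTree(int from, int to, int sum, SegmentTree* left = nullptr, SegmentTree* right = nullptr): 
        from(from), to(to), sum(sum), left(left), right(right) {}
};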
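As mentioned above, a node can store an aggregate other than the sum. The sketch below assumes we want range minimum queries instead. The node stores minValue, and the build and query functions mirror the sum versions later in this post but combine the two children with std::min. The names MinSegmentTree, buildMinTree and queryMinTree are hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical node variant that stores the minimum of [from, to].
struct MinSegmentTree {
    int from;
    int to;
    int minValue;
    MinSegmentTree *left;
    MinSegmentTree *right;
    MinSegmentTree(int from, int to, int minValue,
                   MinSegmentTree* left = nullptr, MinSegmentTree* right = nullptr):
        from(from), to(to), minValue(minValue), left(left), right(right) {}
};

// Same recursive split as the sum version; only the combine step changes.
MinSegmentTree* buildMinTree(const std::vector<int>& nums, int from, int to) {
    if (from == to) {
        return new MinSegmentTree(from, to, nums[from]);
    }
    int mid = from + (to - from) / 2;
    auto left = buildMinTree(nums, from, mid);
    auto right = buildMinTree(nums, mid + 1, to);
    return new MinSegmentTree(from, to, std::min(left->minValue, right->minValue), left, right);
}

// Minimum of nums[from..to]; requires root->from <= from <= to <= root->to.
int queryMinTree(MinSegmentTree* root, int from, int to) {
    assert(root->from <= from && from <= to && to <= root->to);
    if (from == root->from && to == root->to) {
        return root->minValue;
    }
    int mid = root->from + (root->to - root->from) / 2;
    if (to <= mid) {
        return queryMinTree(root->left, from, to);
    }
    if (from > mid) {
        return queryMinTree(root->right, from, to);
    }
    return std::min(queryMinTree(root->left, from, mid),
                    queryMinTree(root->right, mid + 1, to));
}
```

Switching to range maximum only requires replacing std::min with std::max. Any associative combine operation fits the same structure.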

Recursive Algorithm to Build a Segment Tree in C++

Given a range, we divide it into two halves and recursively build the left and right subtrees, one for each half. The node's sum is then the sum of its two children's sums. The recursion stops when the range is a single point (start is equal to finish), where we create a leaf node that holds that element's value.

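// Builds the segment tree over nums[from..to] (inclusive); nums must be non-empty.
// Assumes #include <vector> and using namespace std.
SegmentTree* buildTree(const vector<int>& nums, int from, int to) {
    if (from == to) {
        // Leaf: the range is a single element.
        return new SegmentTree(from, to, nums[from]);
    }
    // Split [from, to] into [from, mid] and [mid + 1, to] without overflowing.
    int mid = from + (to - from) / 2;
    auto left = buildTree(nums, from, mid);
    auto right = buildTree(nums, mid + 1, to);
    // Internal node: its sum is the sum of both halves.
    return new SegmentTree(from, to, left->sum + right->sum, left, right);
}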

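Building a Segment Tree takes O(N) time. The recurrence is T(N) = 2*T(N/2) + O(1), which solves to O(N). Equivalently, the tree for N elements has exactly 2N - 1 nodes, and each node is created in O(1) time.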
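One way to check the O(N) bound is to count nodes. Every call to buildTree allocates exactly one node, and a tree over N elements has N leaves and N - 1 internal nodes. The sketch below reproduces the post's SegmentTree and buildTree (with nullptr, a const reference and an added assert on the input range). It adds two hypothetical helpers: countNodes, and freeTree, because the original code never releases the nodes it allocates with new.

```cpp
#include <cassert>
#include <vector>

struct SegmentTree {
    int from;
    int to;
    int sum;
    SegmentTree *left;
    SegmentTree *right;
    SegmentTree(int from, int to, int sum, SegmentTree* left = nullptr, SegmentTree* right = nullptr):
        from(from), to(to), sum(sum), left(left), right(right) {}
};

SegmentTree* buildTree(const std::vector<int>& nums, int from, int to) {
    assert(0 <= from && from <= to && to < static_cast<int>(nums.size()));
    if (from == to) {
        return new SegmentTree(from, to, nums[from]);
    }
    int mid = from + (to - from) / 2;
    auto left = buildTree(nums, from, mid);
    auto right = buildTree(nums, mid + 1, to);
    return new SegmentTree(from, to, left->sum + right->sum, left, right);
}

// Hypothetical helper: number of nodes in the subtree rooted at root.
int countNodes(const SegmentTree* root) {
    if (root == nullptr) {
        return 0;
    }
    return 1 + countNodes(root->left) + countNodes(root->right);
}

// Hypothetical helper: releases every node allocated by buildTree.
void freeTree(SegmentTree* root) {
    if (root == nullptr) {
        return;
    }
    freeTree(root->left);
    freeTree(root->right);
    delete root;
}
```

For the 7 elements {3, 1, 4, 1, 5, 9, 2}, countNodes returns 13 (that is, 2 * 7 - 1), and the root's sum is 25.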

Range Query using a Segment Tree

Given a range, we can query its sum (or its max/min value) using a Segment Tree in O(logN) time, where N is the number of elements. The query range is broken into O(logN) segments that are stored in the tree, and the recursion visits only a few nodes on each level.

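// Returns the sum of the elements in [from, to];
// requires root->from <= from <= to <= root->to.
int queryTree(SegmentTree* root, int from, int to) {
    // The query range matches this node's range exactly.
    if (from == root->from && to == root->to) {
        return root->sum;
    }
    int mid = root->from + (root->to - root->from) / 2;
    // The query range lies entirely in the left half.
    if (to <= mid) {
        return queryTree(root->left, from, to);
    } 
    // The query range lies entirely in the right half.
    if (from > mid) {
        return queryTree(root->right, from, to);
    }
    // The query range crosses mid: split it and combine both halves.
    return queryTree(root->left, from, mid) +
        queryTree(root->right, mid + 1, to);
}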

If the query range [from, to] matches the node's range exactly (which is always the case once we reach a leaf), we simply return the stored value. Otherwise we check whether the query range lies entirely in the left subtree or entirely in the right subtree, and recurse into that child. If the range crosses the midpoint, we split it into [from, mid] and [mid + 1, to] and return the combined value of both recursive calls.
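The sketch below checks queryTree against a plain loop. It also counts how many nodes each query touches, which is one way to see the logarithmic cost: the count stays within a small multiple of the tree height (at most 4 nodes per level) rather than growing with the length of the range. The visited parameter and the bruteForceSum helper are hypothetical additions for illustration only.

```cpp
#include <cassert>
#include <vector>

struct SegmentTree {
    int from;
    int to;
    int sum;
    SegmentTree *left;
    SegmentTree *right;
    SegmentTree(int from, int to, int sum, SegmentTree* left = nullptr, SegmentTree* right = nullptr):
        from(from), to(to), sum(sum), left(left), right(right) {}
};

SegmentTree* buildTree(const std::vector<int>& nums, int from, int to) {
    if (from == to) {
        return new SegmentTree(from, to, nums[from]);
    }
    int mid = from + (to - from) / 2;
    auto left = buildTree(nums, from, mid);
    auto right = buildTree(nums, mid + 1, to);
    return new SegmentTree(from, to, left->sum + right->sum, left, right);
}

// Same logic as the post's queryTree; `visited` is a hypothetical extra
// parameter that counts how many nodes the query touches.
int queryTree(SegmentTree* root, int from, int to, int& visited) {
    assert(root->from <= from && from <= to && to <= root->to);
    ++visited;
    if (from == root->from && to == root->to) {
        return root->sum;
    }
    int mid = root->from + (root->to - root->from) / 2;
    if (to <= mid) {
        return queryTree(root->left, from, to, visited);
    }
    if (from > mid) {
        return queryTree(root->right, from, to, visited);
    }
    return queryTree(root->left, from, mid, visited) +
        queryTree(root->right, mid + 1, to, visited);
}

// Reference answer: a plain O(to - from + 1) loop over the array.
int bruteForceSum(const std::vector<int>& nums, int from, int to) {
    int total = 0;
    for (int i = from; i <= to; ++i) {
        total += nums[i];
    }
    return total;
}
```

For 1024 elements the tree height is 10, so no query should visit more than 4 * 10 + 1 = 41 nodes, whatever the length of the range.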

Updating a Value in a Segment Tree in C++

Given an index and a value, we recursively find the leaf node that covers that index and update its value there. Don’t forget to update each parent’s sum/min/max value on the way back up (bottom up).

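// Sets element `index` to `value` and refreshes the sums on the path back to the root.
void updateTree(SegmentTree* root, int index, int value) {
    // Reached the leaf that holds `index`.
    if ((root->from == index) && (root->to == index)) {
        root->sum = value;
        return;
    }
    int mid = root->from + (root->to - root->from) / 2;
    // Recurse into the half that contains `index`.
    if (index <= mid) {
        updateTree(root->left, index, value);
    } else {
        updateTree(root->right, index, value);
    }
    // Recompute this node's sum from its (possibly updated) children.
    root->sum = root->left->sum + root->right->sum;
}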

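Updating a value in the segment tree takes O(logN) time, because the recursion follows a single root-to-leaf path and recomputes one sum per level on the way back up.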
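Putting the three operations together, the sketch below builds a tree, changes one element with updateTree, and checks that range sums reflect the new value. It reuses the post's functions with minor changes (nullptr, a const reference, an assert on the index). freeTree is a hypothetical helper added because the original code never deletes its nodes.

```cpp
#include <cassert>
#include <vector>

struct SegmentTree {
    int from;
    int to;
    int sum;
    SegmentTree *left;
    SegmentTree *right;
    SegmentTree(int from, int to, int sum, SegmentTree* left = nullptr, SegmentTree* right = nullptr):
        from(from), to(to), sum(sum), left(left), right(right) {}
};

SegmentTree* buildTree(const std::vector<int>& nums, int from, int to) {
    if (from == to) {
        return new SegmentTree(from, to, nums[from]);
    }
    int mid = from + (to - from) / 2;
    auto left = buildTree(nums, from, mid);
    auto right = buildTree(nums, mid + 1, to);
    return new SegmentTree(from, to, left->sum + right->sum, left, right);
}

// Sum of the elements in [from, to]; requires root->from <= from <= to <= root->to.
int queryTree(SegmentTree* root, int from, int to) {
    if (from == root->from && to == root->to) {
        return root->sum;
    }
    int mid = root->from + (root->to - root->from) / 2;
    if (to <= mid) {
        return queryTree(root->left, from, to);
    }
    if (from > mid) {
        return queryTree(root->right, from, to);
    }
    return queryTree(root->left, from, mid) + queryTree(root->right, mid + 1, to);
}

// Sets element `index` to `value`, then recomputes sums on the way back up.
void updateTree(SegmentTree* root, int index, int value) {
    assert(root->from <= index && index <= root->to);
    if (root->from == index && root->to == index) {
        root->sum = value;
        return;
    }
    int mid = root->from + (root->to - root->from) / 2;
    if (index <= mid) {
        updateTree(root->left, index, value);
    } else {
        updateTree(root->right, index, value);
    }
    root->sum = root->left->sum + root->right->sum;
}

// Hypothetical helper: releases every node allocated by buildTree.
void freeTree(SegmentTree* root) {
    if (root == nullptr) {
        return;
    }
    freeTree(root->left);
    freeTree(root->right);
    delete root;
}
```

Starting from {1, 2, 3, 4, 5}, setting index 2 to 10 changes the total from 15 to 22 and the sum of [1, 3] from 9 to 16. Note that updateTree assigns the new value rather than adding a delta. To support increments, the leaf would use root->sum += value instead.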

–EOF (The Ultimate Computing & Technology Blog) —
