CommonLounge Archive

Merge Sort Trees

April 22, 2017

Prerequisites: Segment Trees and Merge-Sort.

Motivation: You are given an array of N integers and have to answer queries of the form (l, r, k). For each query, print the number of integers less than k in the subarray array[l … r].

You should be able to answer each query in O(log^2 n) time at least. How? You can use a Merge Sort Tree.
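For comparison, the straightforward approach scans the subarray for every query. That costs O(n) per query and O(nq) for q queries. Below is a minimal sketch, assuming 0-based indices. The helper name count_less_naive is hypothetical. It is also handy for checking a faster solution.

```cpp
#include <vector>

// Brute-force reference (hypothetical helper): count elements of a[l..r]
// (inclusive, 0-based) that are strictly less than k. O(r - l + 1) per query.
int count_less_naive(const std::vector<int>& a, int l, int r, int k) {
    int cnt = 0;
    for (int i = l; i <= r; ++i)
        if (a[i] < k) ++cnt;
    return cnt;
}
```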

Now, what is a Merge Sort Tree? It is a Segment Tree in which each node stores a vector. If a node covers the range [l, r], its vector contains the elements of array[l…r] in sorted order.

To solve the problem above, we first build a Merge Sort Tree. For each query, we visit the relevant segments and binary search in the vector stored at each of those nodes to count how many numbers are less than k. This gives O(log^2 n) per query.
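The per-node step is a single binary search. This is a minimal sketch, assuming the node's vector is already sorted; the name count_less_sorted is hypothetical. std::lower_bound returns the first element that is not less than k, so its offset from begin() equals the number of elements less than k.

```cpp
#include <algorithm>
#include <vector>

// Number of elements strictly less than k in a sorted vector, in O(log size).
// lower_bound points at the first element >= k; everything before it is < k.
int count_less_sorted(const std::vector<int>& sorted, int k) {
    return static_cast<int>(
        std::lower_bound(sorted.begin(), sorted.end(), k) - sorted.begin());
}
```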

Building the Merge Sort Tree

Now, how do we build a Merge Sort Tree? Try it yourself first :)

One way: go to each node of the segment tree, push the elements that belong to its range into the vector, then sort them. This works, but it is not efficient. Every level of the tree sorts n elements in total, so the build costs O(n log^2 n).
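For illustration, the naive approach might look like the sketch below. It is a hypothetical helper, not code from the post. It assumes the tree is passed in as a vector of vectors, indices are 0-based, and node 1 is the root. Each node copies its own range and sorts it.

```cpp
#include <algorithm>
#include <vector>

// Naive build (hypothetical helper): node `node` covers a[l..r]; copy that
// range into t[node] and sort it, then recurse. O(n log^2 n) overall.
void naive_build(std::vector<std::vector<int>>& t, const std::vector<int>& a,
                 int node, int l, int r) {
    t[node].assign(a.begin() + l, a.begin() + r + 1);
    std::sort(t[node].begin(), t[node].end());
    if (l == r) return;
    int mid = (l + r) / 2;
    naive_build(t, a, 2 * node, l, mid);
    naive_build(t, a, 2 * node + 1, mid + 1, r);
}
```

Take the array {5, 1, 4, 2} as an example. The root (node 1) ends up holding {1, 2, 4, 5}. Node 2 (range [0, 1]) holds {1, 5}, and node 3 (range [2, 3]) holds {2, 4}.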

Now, an observation: the vector in a node contains exactly the elements of its left child and its right child! Assume we have already built the vectors of the left and right children. They are already sorted!

So for a parent node we can just merge the left and right children's vectors. As in the merge step of merge sort, this takes time linear in the total size of the two vectors, using two pointers.
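Here is a minimal sketch of the two-pointer merge; the name merge_two is hypothetical. std::merge does the same job.

```cpp
#include <cstddef>
#include <vector>

// Two-pointer merge of two sorted vectors (hypothetical helper), O(|a| + |b|).
std::vector<int> merge_two(const std::vector<int>& a, const std::vector<int>& b) {
    std::vector<int> out;
    out.reserve(a.size() + b.size());
    std::size_t i = 0, j = 0;
    while (i < a.size() && j < b.size()) {
        if (b[j] < a[i]) out.push_back(b[j++]);
        else out.push_back(a[i++]);   // on ties take from a, keeping order stable
    }
    while (i < a.size()) out.push_back(a[i++]);  // copy leftovers of a
    while (j < b.size()) out.push_back(b[j++]);  // copy leftovers of b
    return out;
}
```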

So we push array elements only into the leaf nodes and build the rest of the merge sort tree recursively, bottom-up, by merging. :)

This will give you O(n log n) complexity to build the merge sort tree :)
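To see where the bound comes from, assume for simplicity that n is a power of two. Level d of the tree has 2^d nodes, and each covers n/2^d elements. Building a node by merging costs time linear in its size, so summing over the log_2 n + 1 levels gives:

```latex
T_{\text{build}}(n) = \sum_{d=0}^{\log_2 n} 2^d \cdot O\!\left(\frac{n}{2^d}\right) = \sum_{d=0}^{\log_2 n} O(n) = O(n \log n)
```

The same sum bounds the total size of all node vectors, so the tree also uses O(n log n) memory.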

Query in Merge Sort Tree

Now the query part, which is also easy. Go to each relevant segment and binary search in its vector to find how many elements are less than k. Then add the counts up recursively :) DONE!!! A query runs in O(log^2 n).
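The O(log^2 n) bound follows from a standard segment tree fact. A query range [l, r] splits into at most two fully covered nodes per level, so O(log n) nodes in total. Each of those nodes costs one binary search over at most n elements:

```latex
T_{\text{query}}(n) = \underbrace{O(\log n)}_{\text{fully covered nodes}} \cdot \underbrace{O(\log n)}_{\text{binary search per node}} = O(\log^2 n)
```

Partially overlapping nodes (at most two per level) only recurse into their children and do no binary search, so they add just O(log n).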

Implementation Details:

Build Function: The build function will look something like this:

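#include <algorithm>   // merge, lower_bound
#include <iterator>    // back_inserter
#include <vector>
using namespace std;

const int maxn = 1e5+10;
vector<int> tree[4*maxn];   // tree[node]: sorted elements of node's range
int n, m, arr[maxn];
#define all(v) v.begin(), v.end()

// Builds the node covering arr[l..r]. Call as build(1, 0, n-1),
// or build(1, 1, n) if arr is 1-indexed.
void build(int node, int l, int r) {
  if(l == r) {                      // leaf: a single element
    tree[node].push_back(arr[l]);
    return;
  }
  int mid = (l + r) >> 1,
      left = node << 1, right = left|1;
  build(left, l, mid);
  build(right, mid+1, r);
  // merge the two sorted child vectors in linear time
  merge(all(tree[left]), all(tree[right]),
          back_inserter(tree[node]));
}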

When l == r, the node is a leaf, so we just push that element into the node's vector. For other nodes, we recursively build the left and right children and then merge their vectors. The last line does this merge in time linear in the combined size of the two child vectors. You can also write it yourself with two pointers.

[Note: std::back_inserter(v) returns a std::back_insert_iterator, which appends each assigned element to the end of the container v by calling push_back.]
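Here is a small illustration of the note using only the standard library; the function name is hypothetical. std::merge writes through its output iterator, and std::back_inserter turns each write into a push_back. That is why the destination vector can start empty, just as tree[node] does in build.

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Merge two sorted vectors into an initially empty one (hypothetical helper).
// back_inserter(out) makes every write performed by std::merge an out.push_back.
std::vector<int> merge_with_back_inserter(const std::vector<int>& a,
                                          const std::vector<int>& b) {
    std::vector<int> out;  // no pre-sizing needed
    std::merge(a.begin(), a.end(), b.begin(), b.end(), std::back_inserter(out));
    return out;
}
```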

Query: The query part is quite easy. Here is a simple implementation:

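// Counts elements < k in arr[i..j]; this node covers arr[l..r].
int query(int node, int l, int r, int i, int j, int k) {
  if(i > r || l > j) return 0;            // no overlap
  if(i <= l && r <= j) {                  // fully covered: binary search
    return lower_bound(all(tree[node]), k)
               - tree[node].begin();
  }
  // partial overlap: split the query between the two children
  int mid = (l + r) >> 1,
    left = node << 1, right = left|1;
  return query(left, l, mid, i, j, k) +
       query(right, mid+1, r, i, j, k);
}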
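Putting the two functions together, here is a minimal self-contained sketch. It wraps the same build and query logic in a struct, with 0-based indices and no global arrays or macros. The names MergeSortTree and count_less are chosen for illustration only.

```cpp
#include <algorithm>
#include <iterator>
#include <vector>

// Merge Sort Tree sketch (illustrative names), 0-based indices.
struct MergeSortTree {
    int n;
    std::vector<std::vector<int>> tree;  // tree[node]: sorted values of node's range

    explicit MergeSortTree(const std::vector<int>& arr)
        : n(static_cast<int>(arr.size())), tree(4 * std::max(1, n)) {
        if (n > 0) build(arr, 1, 0, n - 1);
    }

    void build(const std::vector<int>& arr, int node, int l, int r) {
        if (l == r) { tree[node].push_back(arr[l]); return; }  // leaf
        int mid = (l + r) / 2, left = 2 * node, right = left + 1;
        build(arr, left, l, mid);
        build(arr, right, mid + 1, r);
        std::merge(tree[left].begin(), tree[left].end(),   // linear-time merge
                   tree[right].begin(), tree[right].end(),
                   std::back_inserter(tree[node]));
    }

    // Count elements < k in arr[i..j] (inclusive); node covers [l, r].
    int query(int node, int l, int r, int i, int j, int k) const {
        if (i > r || l > j) return 0;                      // no overlap
        if (i <= l && r <= j)                              // fully covered
            return static_cast<int>(
                std::lower_bound(tree[node].begin(), tree[node].end(), k) -
                tree[node].begin());
        int mid = (l + r) / 2;                             // partial overlap
        return query(2 * node, l, mid, i, j, k) +
               query(2 * node + 1, mid + 1, r, i, j, k);
    }

    int count_less(int i, int j, int k) const { return query(1, 0, n - 1, i, j, k); }
};
```

For example, with arr = {5, 1, 4, 2, 3}, count_less(1, 3, 4) returns 2 (the elements 1 and 2).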

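At each node that lies fully inside the query range, lower_bound counts how many of its numbers are less than k. Nodes that only partly overlap pass the query on to their children, and the counts are summed.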

Here is a practice problem -

Problem MKTHNUM

After these explanations it should be quite approachable :) Try it for at least 30 minutes first. If you get stuck, hints are in the replies :)

If you know more problems that can be solved using this technique, please leave a reply :)

Update in Merge Sort Trees

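We can also do point updates in Merge Sort Trees! For this we need a policy-based data structure (a GNU C++ extension) at each node instead of a sorted vector. To update an index, go to each node whose range contains it, erase the old element at that index and insert the new one!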

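We still need to count the elements less than k at each node. std::set does have lower_bound(), but it cannot tell us how many elements come before the returned iterator in O(log n), since std::distance on set iterators is linear. That's why we need the policy-based tree, whose order_of_key() returns that count directly.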

The build then takes O(n log^2 n) and each update takes O(log^2 n). An update touches the O(log n) nodes on the root-to-leaf path, and each erase/insert costs O(log n). Queries still take O(log^2 n).

If you are not familiar with policy-based data structures, you can read about them here: C++ STL: Policy based data structures

You can use this directly if the elements of the array are distinct! If they are not distinct, store pair<int,int> elements of the form {value, index} so that equal values remain distinct keys.
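Here is a minimal sketch of the update idea, assuming GCC's libstdc++, since the `__gnu_pbds` tree is a GCC extension and not standard C++. Each node stores an ordered set of {value, index} pairs, and order_of_key plays the role of lower_bound. The struct name DynamicMergeSortTree and its methods are hypothetical.

```cpp
#include <algorithm>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
#include <functional>
#include <utility>
#include <vector>

using Key = std::pair<int, int>;  // {value, index}: keeps duplicate values distinct
using ordered_set = __gnu_pbds::tree<Key, __gnu_pbds::null_type, std::less<Key>,
                                     __gnu_pbds::rb_tree_tag,
                                     __gnu_pbds::tree_order_statistics_node_update>;

// Hypothetical sketch of a Merge Sort Tree with point updates, 0-based indices.
struct DynamicMergeSortTree {
    int n;
    std::vector<int> arr;
    std::vector<ordered_set> tree;  // tree[node]: {value, index} pairs of its range

    explicit DynamicMergeSortTree(const std::vector<int>& a)
        : n(static_cast<int>(a.size())), arr(a), tree(4 * std::max(1, n)) {
        if (n > 0) build(1, 0, n - 1);
    }

    // Each element lands in O(log n) nodes at O(log n) per insert: O(n log^2 n).
    void build(int node, int l, int r) {
        for (int p = l; p <= r; ++p) tree[node].insert(Key(arr[p], p));
        if (l == r) return;
        int mid = (l + r) / 2;
        build(2 * node, l, mid);
        build(2 * node + 1, mid + 1, r);
    }

    // arr[pos] = val: fix every node on the root-to-leaf path, O(log^2 n).
    void update(int pos, int val) {
        int node = 1, l = 0, r = n - 1;
        while (true) {
            tree[node].erase(Key(arr[pos], pos));
            tree[node].insert(Key(val, pos));
            if (l == r) break;
            int mid = (l + r) / 2;
            if (pos <= mid) { node = 2 * node; r = mid; }
            else { node = 2 * node + 1; l = mid + 1; }
        }
        arr[pos] = val;
    }

    // Count values < k in arr[i..j]. Every stored index is >= 0, so
    // order_of_key({k, -1}) counts exactly the pairs whose value is < k.
    int query(int node, int l, int r, int i, int j, int k) const {
        if (i > r || l > j) return 0;
        if (i <= l && r <= j) return static_cast<int>(tree[node].order_of_key(Key(k, -1)));
        int mid = (l + r) / 2;
        return query(2 * node, l, mid, i, j, k) + query(2 * node + 1, mid + 1, r, i, j, k);
    }

    int count_less(int i, int j, int k) const { return query(1, 0, n - 1, i, j, k); }
};
```

The red-black tree nodes carry much more overhead than vector elements. In practice this version uses noticeably more memory and time than the static tree.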

I’ll leave the full implementation to you :) You can try this problem with a merge sort tree with point updates: Codeforces Round #404 Div2E


© 2016-2022. All rights reserved.