数据结构 | 归并树
归并树 (Merge Sort Tree): 归并树是线段树和归并排序的合成,它利用线段树将归并排序的每一步都记录下来。
- 查找区间
内的大小范围在 的数的个数(类似条件均可查找) - 查找区间
内第 大的数
归并树的思想基于线段树,因此需要先学习线段树:https://io.zouht.com/117.html
思路
归并排序我们不陌生,核心思想就是用递归完成拆分和合并。观察归并的拆分方式,我们可以发现,它和线段树完全一样。
下图是线段树的分段示意图,如果我们把它看作归并排序的拆分过程,可以发现其实是一模一样的。
线段树中,每个节点储存的是一个数值,这个数值维护着这一段的信息。
而归并树中,每个节点储存的是一个有序数列,这个数列就是归并到该节点时的有序数列状态。其实就是利用线段树将归并排序的每一步都记录下来。
应用
查找区间 内的 的数的个数
思想和线段树的区间查询完全一致:
由长分解到短,对于长度为
- 如果线段树内的区间
完全被 包含,在该节点的有序数列中二分找到 的数的个数,加入答案。 否则,则考察它的左右子区间,令分界点
:如果
的左子区间 与 有交集( )- 递归考察
区间
- 递归考察
如果
的右子区间 与 有交集( )- 递归考察
区间
- 递归考察
- 如果上面的情况都不满足,则说明
区间与我们要求的 区间完全不相交,直接跳过。
对于找
查找区间 内第 大的数
我们在最终的有序数列
如果
则执行:- 中点
如果查询
中 的数的个数- 令
- 令
如果查询
中 的数的个数- 令
- 令
- 中点
- 最终,第
大的数便是 .
其中,最终的有序数列
代码
- 建树:
- 查范围数量:
- 查第
小:
struct MergeSortTree
{
/* ### array index must start from ONE ### */
int n;
vector<vector<int>> tree;
// arr: ori arr, [s, t]: cur seg, x: cur node
void build(const vector<int> &arr, int s, int t, int x)
{
if (s == t)
{
tree[x] = {arr[s]};
return;
}
int m = (s + t) / 2;
build(arr, s, m, 2 * x);
build(arr, m + 1, t, 2 * x + 1);
merge(tree[2 * x].begin(), tree[2 * x].end(),
tree[2 * x + 1].begin(), tree[2 * x + 1].end(),
back_inserter(tree[x]));
}
MergeSortTree(const vector<int> &arr) : n(arr.size())
{
int sz = 1 << (__lg(n) + bool(__builtin_popcount(n) - 1)); // sz = \lceil \log_{2}{n} \rceil
tree.resize(2 * sz);
build(arr, 1, n, 1);
}
// [l, r]: query array interval, [mn, mx]: query value interval, [s, t]: cur seg, x: cur node
int count(int l, int r, int mn, int mx, int s, int t, int x)
{
if (l <= s && t <= r)
return upper_bound(tree[x].begin(), tree[x].end(), mx) - lower_bound(tree[x].begin(), tree[x].end(), mn);
int m = (s + t) / 2, ans = 0;
if (l <= m)
ans += count(l, r, mn, mx, s, m, x * 2);
if (r > m)
ans += count(l, r, mn, mx, m + 1, t, x * 2 + 1);
return ans;
}
// query number of elements in the [l, r] interval that fall within the range [mn, mx]
int count(int l, int r, int mn, int mx)
{
return count(l, r, mn, mx, 1, n, 1);
}
// find the kth smallest number in the [l, r] interval
int count(int l, int r, int k)
{
int pl = 1, pr = n;
while (pl < pr)
{
int mid = (pl + pr) / 2;
if (count(l, r, INT32_MIN, tree[1][mid]) < k)
pl = mid + 1;
else
pr = mid;
}
return tree[1][pl];
}
};
模板题
Luogu P3834: https://www.luogu.com.cn/problem/P3834
(该题应当用复杂度更优的可持久化线段树来完成,用本文的归并树应当只能 AC 五个点)
本文采用 CC BY-SA 4.0 许可,本文 Markdown 源码:Haotian-BiJi