可持久化数据结构:总是可以保留每一个历史版本的数据结构。

可持久化线段树:可以保存每一次操作的历史版本的线段树

可持久化权值线段树 (主席树):可以保存每一次操作的历史版本的权值线段树

前置内容:线段树 https://io.zouht.com/117.html

1 可持久化线段树

1.1 问题引入

我们都知道线段树可以维护序列的区间信息,支持区间修改、区间查询,但我们现在要求保留历史版本。

对于一个序列 $a_1,a_2,\dots,a_n$,给出操作版本号 $x$ 和区间 $[l,r]$,可以进行修改和查询操作:

  • 对于第 $x$ 个版本,修改 $a_l,a_{l+1},\dots,a_r$ 的值,作为一个新版本。
  • 对于第 $x$ 个版本,查询 $a_l,a_{l+1},\dots,a_r$ 的和。

1.1 原理

既然持久化的目标是保留每一个历史版本,那最朴素的思想就是每次操作都开一棵新的线段树。但是不用算都知道,这样操作的话空间复杂度绝对是不可接受的。

我们先观察一下线段树的一次操作造成的影响,以下图这棵树为例,节点上的数字代表 ID,节点具体维护的值我们暂不考虑:

如果我们要对 $5$ 号节点进行更新,那么根据线段树的性质,会影响的节点仅仅是 $1,2,5$,其他节点的值不会被影响:

推广来说,对一个节点进行修改,会影响的节点仅仅是从根节点到该节点的一条链,也就是会影响 $\log_2 n$ 个节点。我们最开始的朴素思想是把整个线段树复制一遍存下来,这就会导致很多空间被浪费,因为除了那条链上的节点,其他节点其实是完全没变的。

那么,我们就需要找到一个策略来节省空间,仅保存被更新的节点。


这个策略便是我们不实际修改线段树的节点,所有更新都以额外加点的形式来实现。

如果一次修改造成了一条链的更新,那么我们就额外新增这条链修改后的节点,对于没变的节点,指针直接指向原版本的节点即可。通俗来说,就是我们的更新就像打补丁一样,旧的仍然在哪里,被修改的内容就是那个补丁。

需要注意的是,这样做的话就不能用堆式存储法(对于节点 $x$,左儿子为 $2x$,右儿子为 $2x+1$),我们得动态开点,并保存左右儿子的编号,相当于左儿子指针和右儿子指针。

另外,每次更新会产生一个新的根节点,我们需要一个 root 数组来保存每个版本的根节点编号。

例如,对于上文的示例,我们进行连续的两个操作:

  • 版本 $0$ $\to$ 修改节点 $5$ 的值 $\to$ 版本 $1$
  • 版本 $1$ $\to$ 修改节点 $3$ 的值 $\to$ 版本 $2$

那么,可持久化线段树形成的结构将会是下图:

1.2 实现

原理了解了,接下来的问题就是如何用代码实现了。

首先需要注意的是节点空间要多开点,一般可以开题目规模的 $32$ 倍(MAXN << 5)。因为每次修改将会产生 $\log n$ 个新点,如果进行 $q$ 次修改操作,最坏情况下需要的空间约为 $2n+q\log n$. 如果 $n,q$ 均为 $10^6$,则需要的空间约为 $21$ 倍的 $n$.

其次就是不能用堆式存储法直接计算编号,而是需要用 $lson,rson$ 两个数组储存左右节点的编号。对于节点 $i$,它的左儿子为 $lson_i$,右儿子为 $rson_i$.

具体实现就直接参考下文代码即可,如果学习过线段树应该比较好理解,因为其实核心思想是一样的,可持久化只是额外的功能。

该模板仅支持单点修改,下面我们将会讲解区间修改。

对应题目:https://www.luogu.com.cn/problem/P3919

namespace pst
{
    /* ### array index must start from ONE ### */
    constexpr int MAXN = 1e6;
    int n;
    int root[MAXN];
    int val[(MAXN << 5) + 10], lson[(MAXN << 5) + 10], rson[(MAXN << 5) + 10], cur_idx = 0;

    int build(const vector<int> &arr, int s, int t)
    {
        int now = ++cur_idx;
        if (s == t)
        {
            val[now] = arr[s];
            return now;
        }
        int m = (s + t) / 2;
        lson[now] = build(arr, s, m);
        rson[now] = build(arr, m + 1, t);
        return now;
    }

    int clone_node(int orig)
    {
        ++cur_idx;
        val[cur_idx] = val[orig];
        lson[cur_idx] = lson[orig];
        rson[cur_idx] = rson[orig];
        return cur_idx;
    }

    int update(int x, int c, int s, int t, int p)
    {
        int now = clone_node(p);
        if (s == t)
        {
            val[now] = c;
            return now;
        }
        int m = (s + t) / 2;
        if (x <= m)
            lson[now] = update(x, c, s, m, lson[now]);
        else
            rson[now] = update(x, c, m + 1, t, rson[now]);
        return now;
    }

    int query(int x, int s, int t, int p)
    {
        if (s == t)
            return val[p];
        int m = (s + t) / 2;
        if (x <= m)
            return query(x, s, m, lson[p]);
        else
            return query(x, m + 1, t, rson[p]);
    }
};

1.3 区间修改与标记永久化

对于普通线段树,我们采用的区间修改策略是懒惰标记法,需要把标记上下传也就是 push_downpush_up. 但在可持久化线段树中,如果也这么做那将会导致严重的问题。

考虑上面的例子:

如果我们有一个标记在 $9$ 节点,如果对标记进行 push_down 下传,可以发现它会被传到 $7$ 节点,甚至再往下会传到 $4$ 节点,那这样的话对版本 $2$ 的修改就把版本 $1$ 和 $0$ 全破坏了。要避免这个问题只能把受影响节点全部新建,会浪费许多空间。

因此,使用懒惰标记上下传的方式无法解决这个问题,我们用懒惰标记永久化来解决。

例如上面这个线段树,储存了 $9$ 个元素 $1,2,\dots,9$,我们对 $[4,6]$ 这个区间的元素统一 $+1$,其中绿色和蓝色代表会更新的节点,也就是新版本会额外创建的新节点。

对于这些节点,我们会将它实际更新,更新方式就是计算区间交叉长度来获取增长值。其中,蓝色的节点比较特殊,它们会被标上懒惰标记。这些步骤和标记上下传是一致的。

不一样的地方就是,节点拿到标记后不再下传。也就是即使之后询问到了标蓝的节点,它们也不会将标记 push_down 下去。那该怎么计算询问结果呢?答案是在 query 函数里传入节点的标记值。

例如更新后,我们对 $[3,4]$ 这个区间进行询问,对于 $4$ 号位置的节点,它实际上是被上一次 $+1$ 了,但是懒惰标记在它的上一层且不会传给它。此时如果请求来到拿着懒惰标记的节点,它就会将标记值传入向下一层的 query 函数,相当于告诉下一层它身上现在有多少个懒惰标记,但是不真正把标记传给它。

下一层的节点受到 query 请求后,就会计算得到结果,同时将会把传给它的标记值根据区间长度加入结果再返回给上层,完成计算。

文字表述比较抽象,建议对照代码:

namespace pst
{
    /* ### array index must start from ONE ### */
    constexpr int MAXN = 1e6;
    int n;
    signed root[MAXN];
    int cur_idx = 0;
    int val[(MAXN << 5) + 10], mark[(MAXN << 5) + 10];
    signed lson[(MAXN << 5) + 10], rson[(MAXN << 5) + 10];

    int build(const vector<int> &arr, int s, int t)
    {
        int now = ++cur_idx;
        if (s == t)
        {
            val[now] = arr[s];
            return now;
        }
        int m = (s + t) / 2;
        lson[now] = build(arr, s, m);
        rson[now] = build(arr, m + 1, t);
        val[now] = val[lson[now]] + val[rson[now]];
        return now;
    }

    int clone_node(int orig)
    {
        ++cur_idx;
        val[cur_idx] = val[orig];
        mark[cur_idx] = mark[orig];
        lson[cur_idx] = lson[orig];
        rson[cur_idx] = rson[orig];
        return cur_idx;
    }

    int update(int l, int r, int c, int s, int t, int p)
    {
        int now = clone_node(p);
        val[now] += (min(r, t) - max(l, s) + 1) * c;
        if (l <= s && t <= r)
        {
            mark[now] += c;
            return now;
        }
        int m = (s + t) / 2;
        if (l <= m)
            lson[now] = update(l, r, c, s, m, lson[now]);
        if (r > m)
            rson[now] = update(l, r, c, m + 1, t, rson[now]);
        return now;
    }

    int query(int l, int r, int s, int t, int p, int mk = 0)
    {
        if (l <= s && t <= r)
            return val[p] + mk * (t - s + 1);
        int m = (s + t) / 2, ans = 0;
        if (l <= m)
            ans += query(l, r, s, m, lson[p], mk + mark[p]);
        if (r > m)
            ans += query(l, r, m + 1, t, rson[p], mk + mark[p]);
        return ans;
    }
};

模板题:https://acm.hdu.edu.cn/showproblem.php?pid=4348

该题数据较大需要开 long long,同时该题内存限制比较极限(64MB),需要注意节省内存。

2 可持久化权值线段树(主席树)

2.1 普通线段树和权值线段树

首先,对权值线段树这个术语做一下解释。

对于普通线段树,我们都知道维护的是数组区间信息,例如区间和、区间最大 / 小值,维护的内容是数据本身。

而对于权值线段树,维护的是数组区间内数的个数信息,例如 $1$ 的个数、$2$ 的个数,维护的内容相当于许多桶。

因此,对比总结一下:

  • 普通线段树:维护信息,按个数开空间,维护具体信息。
  • 权值线段树:维护桶,按值域(可离散化处理),维护个数。

2.2 问题引入

给定 $n$ 个整数构成的序列 $a$,将对于指定的闭区间 $[l,r]$ 查询其区间内的第 $k$ 小值。

该问题应该是主席树最经典的应用了。

2.3 原理

对于一个权值线段树,储存的是 $1,2,\dots,n$ 的个数。每个节点储存的都是一个值域内数的个数,对于一个节点,如果它维护的值域范围是 $[low,high]$,令中点为 $mid=(low+high)/2$,那么它的左儿子维护的范围为 $[low,mid]$,右儿子维护的范围为 $(mid,high]$.

我们先不想具体实现,假设我们已经有了一个可持久化权值线段树,如何解决区间第 $k$ 小问题。可以按照以下方式完成:

对于一个序列 $a_1,a_2,\dots,a_n$,我们可以按从左到右的顺序,依次将其加入可持久化权值线段树。那么区间 $[l,r]$ 的数插入时,对应的历史版本便是版本 $l$ 到版本 $r$.

如果我们要查询区间 $[l,r]$ 内第 $k$ 大的数,那么我们从根节点开始,根节点代表处于整个值域 $[low,high]$ 的数的个数。

  • 先找根节点的左儿子的 $r$ 版本的值,即代表 $r$ 版本时处于 $[low,mid]$ 的数的个数,记为 $sum_r$.
  • 再找根节点的左儿子的 $l - 1$ 版本的值,即代表 $l-1$ 版本时处于 $[low,mid]$ 的数的个数,记为 $sum_{l-1}$.

那么 $sum_r-sum_{l-1}$ 的意义便是,区间 $[l,r]$ 内大小处于 $[low,mid]$ 的数的个数,记为 $x$.

  • 如果 $x\geq k$,说明 $[low,mid]$ 范围包含的范围比我们要找的范围要大,因此我们递归进入左子节点更细致地查找。
  • 如果 $x<k$,说明 $[low,mid]$ 范围包含了 $[l,r]$ 内前 $x$ 小的数,我们要找的是前 $k$ 小的数,那么我们应该再递归进入右子节点找前 $k-x$ 小的数,就相当于原区间的第 $k$ 小的数。

按上面这种方案递归下去,最终一定会到 $low=high$ 的状态,这个状态便是找到答案,第 $k$ 小的数就是 $low$,递归终止。

2.4 实现

可持久化权值线段树的实现与可持久化线段树几乎一致,唯一区别就是建树方式和更新方式。

建树时不传入初始数组,值全部初始化为 $0$,另外线段树大小不是数组长度,而是值域大小。

更新代表的是加入数 $x$,需要将所有值域包含它的节点 $+1$.

还有至关重要的一点便是离散化,int 的值域在 $10^9$ 等级,如果我们要开一个大小为 $10^9$ 的可持久化权值线段树,那空间也是要爆的。但是考虑到一般题目的数的个数都较少,在 $10^6$ 左右,那么我们可以用离散化,仅储存出现过的数,不考虑没出现过的数。离散化的方式可以对原数据排序、去重后二分完成,具体可以参考代码实现。

namespace hjt
{
    /* ### array index must start from ONE ### */

    constexpr int MAXN = 1e6;
    int n;
    int sum[(MAXN << 5) + 10], lson[(MAXN << 5) + 10], rson[(MAXN << 5) + 10], cur_idx = 0;
    int root[MAXN], cur_ver = 0;

    int build(int s, int t)
    {
        int now = ++cur_idx;
        if (s == t)
        {
            sum[now] = 0;
            return now;
        }
        int m = (s + t) / 2;
        lson[now] = build(s, m);
        rson[now] = build(m + 1, t);
        return now;
    }

    int clone_node(int orig)
    {
        ++cur_idx;
        sum[cur_idx] = sum[orig] + 1;
        lson[cur_idx] = lson[orig];
        rson[cur_idx] = rson[orig];
        return cur_idx;
    }

    int update(int x, int s, int t, int p)
    {
        int now = clone_node(p);
        if (s == t)
            return now;
        int m = (s + t) / 2;
        if (x <= m)
            lson[now] = update(x, s, m, lson[now]);
        else
            rson[now] = update(x, m + 1, t, rson[now]);
        return now;
    }

    int query(int x, int l, int r, int s, int t)
    {
        if (s == t)
            return s;
        int delt = sum[lson[r]] - sum[lson[l]];
        int m = (s + t) / 2;
        if (x <= delt)
            return query(x, lson[l], lson[r], s, m);
        else
            return query(x - delt, rson[l], rson[r], m + 1, t);
    }
};

solve() 函数也附上,防止不知道怎么用,对应的是模板题:https://www.luogu.com.cn/problem/P3834

void solve()
{
    int n, m;
    cin >> n >> m;
    vector<int> a(n + 10);
    for (int i = 1; i <= n; i++)
        cin >> a[i];
    auto b = a;
    sort(b.begin() + 1, b.begin() + 1 + n);
    int uni = unique(b.begin() + 1, b.begin() + 1 + n) - b.begin() - 1;
    hjt::root[0] = hjt::build(1, uni);
    for (int i = 1; i <= n; i++)
    {
        int t = lower_bound(b.begin() + 1, b.begin() + 1 + uni, a[i]) - b.begin();
        hjt::root[i] = hjt::update(t, 1, m, hjt::root[i - 1]);
    }
    for (int i = 1; i <= m; i++)
    {
        int l, r, k;
        cin >> l >> r >> k;
        int t = hjt::query(k, hjt::root[l - 1], hjt::root[r], 1, m);
        cout << b[t] << endl;
    }
}
文章目录