前のページ

树链剖分笔记

树剖剖所有~

何为树链剖分

树作为一种性质丰富的图,在 OI 中有非常广泛的命题空间。我们有这不少对树上操作的手段,但是并没有一种直接的数据结构可以处理树上的信息。树链剖分通过通过一种特殊的链剖分技巧,将对树上任意一条路径的操作转化为对数段序列上的操作,对于序列,我们就有线段树、树状数组等数据结构来处理了。

来看一个经典问题:给定一棵 $n$ 个节点的树,节点带权值,需要支持两种操作:

  1. 修改 $u$ 到 $v$ 的路径上所有节点的权值
  2. 查询 $u$ 到 $v$ 的路径上所有节点的权值和

如果这棵树是一条链,那直接可以丢到序列上来维护了,但很可惜这棵树的形态任意。我们可以用暴力的手段:找到 $lca(u, v)$,分别从 $u, v$ 向上走,复杂度是 $O(n)$ 的。

为了处理这种问题,树链剖分应运而生。通俗的说,树链剖分的思想是 化树为链,将树上任意的路径问题转化为序列上的区间问题。

树链剖分怎么做

我们常说的树剖,指的是重链剖分,在开始实现树剖之前我们首先引入几个概念:

  • 重儿子:对于任意一个非叶子节点,其子节点中子树最大(子树结点最多)的儿子。若有多个,任取其一皆可;若没有子节点,则无重儿子
  • 轻儿子:一个节点除重儿子之外的子节点
  • 重边:连接一个节点与其重儿子的边
  • 轻边:连接一个节点与其轻儿子的边
  • 重链:由若干条首尾相接的重边连接而成的路径。特别地,若从一个轻儿子节点出发,则这个轻儿子也算一条重边的起点

偷一张 OI-Wiki 被刨到包浆的图:

补充:落单的节点自身视为一条重链。

通过这些定义,我们可以将树中所有的节点划分到唯一的重链中。

树剖的实现依赖于两个 DFS 和后续的数据结构维护,其中第一个 DFS 处理出每个节点的 父节点、深度、子树大小和重儿子。OI-Wiki 中的 void dfs1(int u) 会导致 dfs2 时 stack-overflow,避坑。

// fa(u): u 的父亲;dep(u): u 在树上的深度;siz(u): u 的子树的节点个数;son(u): 表示节点 u 的重儿子
void dfs1(int u) {
    son[u] = -1, siz[u] = 1;
    for (auto v : G[u]) {
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (son[u] == -1 || siz[v] > siz[son[u]])
            son[u] = v;
    }
}

第二个 DFS 处理出 每个节点所在重链的顶点,以及每个节点的 DFS 序。注意在遍历时需要优先遍历重儿子,保证同一重链上的节点的 DFS 序是连续的。

// top(u): u 所在重链的顶部那个深度最小的节点;dfn(u): u 的 DFS 序,即在线段树中的编号;rnk(u): DFS 序所对应的节点编号,rnk(dfn(u)) = u
void dfs2(int u, int t) {
    top[u] = t;
    tot++;
    dfn[u] = tot;
    rnk[tot] = u;
    if (son[u] == -1) return;
    dfs(son[u], t); // 重儿子优先
    for (auto v : G[u]) {
        if (v != son[u] && v != fa[u])
            dfs2(v, v);
    }
}

操作树剖结果

回到我们最初提出的问题,对于路径 $(u, v)$ 的修改和查询,我们可以利用处理出的 $top$ 像倍增 LCA 过程一样向上跳。对于路径 $(u, v)$:

  1. 选择 $u, v$ 中 $top$ 不相同且深度较深的节点,比如 $u$
  2. $u$ 所在的重链 $(top(u), \dots, u)$。由于这条链上的 $dfn$ 是连续的,所以在线段树上直接操作 $[dfn(top(u)), dfn(u)]$ 整个区间
  3. 之后将 $u$ 跳到 $fa(top(u))$,这样子 $u$ 就跳出了原本所在的重链而来到链顶的父亲节点
  4. 重复这个过程,直到 $u, v$ 处在同一条重链上,此时 $top(u) = top(v)$
  5. 在同一条链上的 $u, v$ 之间的路径也对应 $dfn$ 中的一个连续区间 $[dfn(u), dfn(v)]$,再平推一次线段树的操作

如果是对于整个子树进行操作,由于 $dfn$ 的定义,不难得到所有节点的 $dfn$ 恰好构成一个连续的区间 $[dfn(u), dfn(u) + siz(u) - 1]$,因此也可以转换成一次线段树的区间操作。

补充一下树剖的额外用法:求 LCA。不断向上跳重链,当跳到同一条重链上时,深度较小的节点即为 LCA。向上跳重链时首先要先跳重链顶端深度较大的一个。

int lca(int u, int v) {
    while (top[u] != top[v]) {
        if (dep[top[u]] > dep[top[v]])
            u = fa[top[u]];
        else
            v = fa[top[v]];
    }
    return (dep[u] > dep[v] ? v : u);
}

练题

树剖离不开练。

LG P3384 【模板】重链剖分/树链剖分

年轻人的第一份树剖代码,裸的代码实现一下上面的需求罢了。

#include <bits/stdc++.h>

#define int long long

constexpr int N = 1e5 + 7;

int n, m, r, P;
int tot;
int w[N];
int fa[N], dep[N], siz[N], son[N];
int top[N], dfn[N], rnk[N];

std::vector<int> G[N];

struct SegTree {
    struct node {
        int val, add;
    } tr[N << 2];

    #define ls(o) (o << 1)
    #define rs(o) (o << 1 | 1)

    void maketag(int o, int l, int r, int v) {
        tr[o].add = (tr[o].add + v) % P;
        tr[o].val = (tr[o].val + (r - l + 1) * v) % P;
    }

    void pushup(int o) {
        tr[o].val = (tr[ls(o)].val + tr[rs(o)].val) % P;
    }

    void pushdown(int o, int l, int r) {
        int mid = (l + r) >> 1;
        maketag(ls(o), l, mid, tr[o].add);
        maketag(rs(o), mid + 1, r, tr[o].add);
        tr[o].add = 0;
    }

    void build(int o, int l, int r) {
        if (l == r) {
            tr[o] = {w[rnk[l]], 0};
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(o), l, mid);
        build(rs(o), mid + 1, r);
        pushup(o);
    }

    void update(int o, int l, int r, int ql, int qr, int v) {
        if (ql <= l && r <= qr) {
            maketag(o, l, r, v);
            return;
        }
        int mid = (l + r) >> 1;
        pushdown(o, l, r);
        if (ql <= mid)
            update(ls(o), l, mid, ql, qr, v);
        if (qr > mid)
            update(rs(o), mid + 1, r, ql, qr, v);
        pushup(o);
    }

    int query(int o, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr)
            return tr[o].val % P;
        int mid = (l + r) >> 1, res = 0;
        pushdown(o, l, r);
        if (ql <= mid)
            res = (res + query(ls(o), l, mid, ql, qr)) % P;
        if (qr > mid)
            res = (res + query(rs(o), mid + 1, r, ql, qr)) % P;
        return res % P;
    }
} seg;

void dfs1(int u) {
    son[u] = -1, siz[u] = 1;
    for (auto v : G[u]) {
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (son[u] == -1 || siz[v] > siz[son[u]])
            son[u] = v;
    }
}

void dfs2(int u, int t) {
    tot++;
    top[u] = t, dfn[u] = tot, rnk[tot] = u;
    if (son[u] == -1) return;
    dfs2(son[u], t);
    for (auto v : G[u]) {
        if (v != son[u] && v != fa[u])
            dfs2(v, v);
    }
}

void solve1(int x, int y, int z) {
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        seg.update(1, 1, n, dfn[top[x]], dfn[x], z);
        x = fa[top[x]];
    }
    if (dep[x] > dep[y])
        std::swap(x, y);
    seg.update(1, 1, n, dfn[x], dfn[y], z);
}

int solve2(int x, int y) {
    int res = 0;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        res = (res + seg.query(1, 1, n, dfn[top[x]], dfn[x])) % P;
        x = fa[top[x]];
    }
    if (dep[x] > dep[y])
        std::swap(x, y);
    res = (res + seg.query(1, 1, n, dfn[x], dfn[y])) % P;
    return res;
}

void solve3(int x, int z) {
    seg.update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, z);
}

int solve4(int x) {
    return seg.query(1, 1, n, dfn[x], dfn[x] + siz[x] - 1);
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n >> m >> r >> P;
    for (int i = 1; i <= n; i++) {
        std::cin >> w[i];
    }
    for (int i = 1, x, y; i < n; i++) {
        std::cin >> x >> y;
        G[x].push_back(y);
        G[y].push_back(x);
    }
    dfs1(r);
    dfs2(r, r);
    seg.build(1, 1, n);
    for (int i = 1, t, x, y, z; i <= m; i++) {
        std::cin >> t;
        if (t == 1) { std::cin >> x >> y >> z; solve1(x, y, z); }
        if (t == 2) { std::cin >> x >> y; std::cout << solve2(x, y) << "\n"; }
        if (t == 3) { std::cin >> x >> z; solve3(x, z); }
        if (t == 4) { std::cin >> x; std::cout << solve4(x) << "\n"; }
    }
    return 0;
}

CF343D Water Tree

模板题,但是注意这个线段树的 lazy_tag 初始,因为打的标记就是 $0/1$,所以起始建树为 $-1$。还是不够熟练,记得在线段树上修改要用 $dfn(u)$,两次操作非常简单,没什么好说的。

#include <bits/stdc++.h>

using i64 = long long;

constexpr int N = 5e5 + 7;

int n, q;
int tot;
int fa[N], dep[N], siz[N], son[N];
int top[N], dfn[N], rnk[N];

std::vector<int> G[N];

struct SegTree {
    struct node {
        int val, lzy;
    } tr[N << 2];

    #define ls(o) (o << 1)
    #define rs(o) (o << 1 | 1)

    void build(int o, int l, int r) {
        if (l == r) {
            tr[o] = {0, -1};
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(o), l, mid);
        build(rs(o), mid + 1, r);
    }

    void pushdown(int o) {
        if (tr[o].lzy != -1) {
            tr[ls(o)] = {tr[o].lzy, tr[o].lzy};
            tr[rs(o)] = {tr[o].lzy, tr[o].lzy};
            tr[o].lzy = -1;
        }
    }

    void update(int o, int l, int r, int ql, int qr, int v) {
        if (ql <= l && r <= qr) {
            tr[o] = {v, v};
            return;
        }
        int mid = (l + r) >> 1;
        pushdown(o);
        if (ql <= mid)
            update(ls(o), l, mid, ql, qr, v);
        if (qr > mid)
            update(rs(o), mid + 1, r, ql, qr, v);
    }

    int query(int o, int l, int r, int x) {
        if (l == r)
            return tr[o].val;
        pushdown(o);
        int mid = (l + r) >> 1;
        if (x <= mid)
            return query(ls(o), l, mid, x);
        else
            return query(rs(o), mid + 1, r, x);
    }
} seg;

void dfs1(int u) {
    son[u] = -1, siz[u] = 1;
    for (auto v : G[u]) {
        if (v == fa[u]) continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (son[u] == -1 || siz[v] > siz[son[u]])
            son[u] = v;
    }
}

void dfs2(int u, int t) {
    tot++;
    top[u] = t, dfn[u] = tot, rnk[tot] = u;
    if (son[u] == -1) return;
    dfs2(son[u], t);
    for (auto v : G[u]) {
        if (v != son[u] && v != fa[u])
            dfs2(v, v);
    }
}

void solve1(int x) {
    seg.update(1, 1, n, dfn[x], dfn[x] + siz[x] - 1, 1);
}

void solve2(int x) {
    while (x) {
        seg.update(1, 1, n, dfn[top[x]], dfn[x], 0);
        x = fa[top[x]];
    }
}

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n;
    for (int i = 1, u, v; i < n; i++) {
        std::cin >> u >> v;
        G[u].push_back(v);
        G[v].push_back(u);
    }
    dfs1(1);
    dfs2(1, 1);
    seg.build(1, 1, n);
    std::cin >> q;
    for (int i = 1, c, v; i <= q; i++) {
        std::cin >> c >> v;
        if (c == 1) { solve1(v); }
        if (c == 2) { solve2(v); }
        if (c == 3) { std::cout << seg.query(1, 1, n, dfn[v]) << "\n"; }
    }
    return 0;
}

LG P2590 [ZJOI2008] 树的统计 & LG P3178 [HAOI2015] 树上操作

简单题,比模板还简单。

LG P4114 Qtree1

原题来自 SPOJ,边权转点权,注意到在树上一个点只与其父亲有且仅有一条边相连,考虑把一条边的权值丢到其一条边的儿子节点上。对于根,由于没有相连的父亲节点边,所以在查询 $sum$ 时将点权设为 $0$;在查询 $\min/\max$ 时将点权设为 $inf, -inf$ 即可。同时注意到在修改和查询的过程中不能碰到节点的 $LCA$,所以在处理的最后其区间范围应为 $[dfn(u + 1), dfn(v)]$。

边权转点权的代码操作主要在 void dfs(int u, int pa, int k) 和建树后的 for 循环里。

#include <bits/stdc++.h>

#define int long long

constexpr int N = 3e5 + 7;
constexpr int inf = 1e9;

int n, ncnt;
std::string s;
int w[N];
int fa[N], dep[N], siz[N], son[N];
int top[N], dfn[N], rnk[N];

std::vector<std::pair<int, int>> G[N];

struct edge {
    int u, v, w, x;
} E[N << 1];

struct SegTree {
    struct node {
        int val, mxn;
    } tr[N << 2];

    #define ls(o) (o << 1)
    #define rs(o) (o << 1 | 1)

    void pushup(int o) {
        tr[o].mxn = std::max(tr[ls(o)].mxn, tr[rs(o)].mxn);
    }

    void build(int o, int l, int r) {
        if (l == r) {
            tr[o] = {w[rnk[l]], w[rnk[l]]};
            return;
        }
        int mid = (l + r) >> 1;
        build(ls(o), l, mid);
        build(rs(o), mid + 1, r);
        pushup(o);
    }

    void update(int o, int l, int r, int p, int v) {
        if (l == r) {
            tr[o] = {v, v};
            return;
        }
        int mid = (l + r) >> 1;
        if (p <= mid)
            update(ls(o), l, mid, p, v);
        else
            update(rs(o), mid + 1, r, p, v);
        pushup(o);
    }

    int query(int o, int l, int r, int ql, int qr) {
        if (ql <= l && r <= qr)
            return tr[o].mxn;
        int mid = (l + r) >> 1,
            res = -inf;
        if (ql <= mid)
            res = std::max(res, query(ls(o), l, mid, ql, qr));
        if (qr > mid)
            res = std::max(res, query(rs(o), mid + 1, r, ql, qr));
        return res;
    }
} seg;

void dfs0(int u, int pa, int k) {
    w[u] = k;
    for (auto i : G[u]) {
        auto [v, w] = i;
        if (v == pa)
            continue;
        dfs0(v, u, w);
    }
}

void dfs1(int u) {
    son[u] = -1, siz[u] = 1;
    for (auto i : G[u]) {
        auto [v, w] = i;
        if (v == fa[u])
            continue;
        fa[v] = u;
        dep[v] = dep[u] + 1;
        dfs1(v);
        siz[u] += siz[v];
        if (son[u] == -1 || siz[v] > siz[son[u]])
            son[u] = v;
    }
}

void dfs2(int u, int t) {
    ncnt++;
    top[u] = t, dfn[u] = ncnt, rnk[ncnt] = u;
    if (son[u] == -1) return;
    dfs2(son[u], t);
    for (auto i : G[u]) {
        auto [v, w] = i;
        if (v != son[u] && v != fa[u])
            dfs2(v, v);
    }
}

int solve(int x, int y) {
    int res = -inf;
    while (top[x] != top[y]) {
        if (dep[top[x]] < dep[top[y]])
            std::swap(x, y);
        res = std::max(res, seg.query(1, 1, n, dfn[top[x]], dfn[x]));
        x = fa[top[x]];
    }
    if (dep[x] > dep[y])
        std::swap(x, y);
    res = std::max(res, seg.query(1, 1, n, dfn[x] + 1, dfn[y]));
    return res;
}

signed main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);

    std::cin >> n;
    for (int i = 1; i < n; i++) {
        std::cin >> E[i].u >> E[i].v >> E[i].w;
        G[E[i].u].push_back({E[i].v, E[i].w});
        G[E[i].v].push_back({E[i].u, E[i].w});
    }
    dfs0(1, 0, -inf);
    dfs1(1);
    dfs2(1, 1);
    seg.build(1, 1, n);
    for (int i = 1; i < n; i++) {
        if (fa[E[i].u] == E[i].v)
            E[i].x = E[i].u;
        else
            E[i].x = E[i].v;
    }
    while (std::cin >> s) {
        int a, b;
        if (s == "DONE")
            break;
        if (s == "CHANGE") {
            std::cin >> a >> b;
            seg.update(1, 1, n, dfn[E[a].x], b);
        }
        if (s == "QUERY") {
            std::cin >> a >> b;
            if (a == b)
                std::cout << "0\n";
            else
                std::cout << solve(a, b) << "\n";
        }
    }
    return 0;
}
「馬鹿なわたしは歌うだけ」
Built with Hugo
Theme exStack modified by Qixyi