Banner image of the blog
站长头像

YaLi Blog

DarkClever的个人博客

文章 评论 标签
8 0 5
分类列表

线段树维护树的直径

日期:2025-10-11浏览:0 编辑

给你两棵树 $A,B$ 并且告诉你 $A$ 的最远点对为 $(a,b)$,$B$ 的最远点对为 $(c,d)$,假设此时用一条边将树 $A,B$ 相连,形成了一颗新的树,那么此时证明新的树的最远点对 $(e,f)$ 满足 $e,f \in {a,b,c,d}$。

我们不难想到进行分类讨论,分为以下三种情况:

  1. 直径的两个端点都在 $A$ 内。
  2. 直径的两个端点都在 $B$ 内。
  3. 直径的一个端点在 $A$ 内,一个端点在 $B$ 内。

1和2不难证明,考虑到端点都在一棵树内则就是原来的直径,对于情况3,我们不难看出直径一定经过了链接 $A$ 和 $B$ 的那条边,也就是说,经过了连接 $A$ 和 $B$ 的那条边的两个端点 $(g,h)$。

不妨再次将 $A$ 和 $B$ 分开来理解,不难发现分别离 $(g,h)$ 最远的那两个点就是贯穿 $A$、$B$ 的直径的端点,而离树上一点最远的一个点一定是树上的端点,所以 $A$ 和 $B$ 组合后的直径的端点一定是 $A$ 和 $B$ 原来的直径的端点

线段树

我们需要找到一个数据结构来维护这个东西,至少需要支持合并的功能,不难想到线段树是一个很好的选择。

我们先将 $A$ 和 $B$ 按照dfs序打成序列,然后对其建立线段树,对于线段树的叶子节点,我们只需要维护树上的一个点,所以直径端点为自己,直径长度为 $0$,合并时,我们需要分别查询6次距离,总复杂度 $\log^2 n$,对于修改,修改一条边的距离,相当于增加整个子树的深度,只需要拿线段树维护一下深度即可。

代码,如下,常数极大。

const int N = 110000;
#define vnt long long
// #define int long long
int n, s;
struct qwq {
    int u, v;
    vnt w;
} b[N];
struct edge {
    int v;
    vnt w;
};
vector<edge> e[N];
int dfs[N], hson[N], dfn[N], sz[N], top[N], rnk[N], fa[N];
vnt dep[N];
namespace TD {
int dn = 0;
void dfs1(int x, int f, vnt dee = 0) {
    dep[x] = dep[f] + dee;
    fa[x] = f;
    sz[x] = 1;
    for(auto xx : e[x]) {
        int v = xx.v;
        vnt w = xx.w;
        if(v != f) {
            dfs1(v, x, w);
            sz[x] += sz[v];
            if(sz[v] >= sz[hson[x]]) hson[x] = v;
        }
    }
}
void dfs2(int x, int tp) {
    top[x] = tp, dfn[x] = ++dn, rnk[dn] = x;
    if(!hson[x]) return;
    dfs2(hson[x], tp);
    for(auto xx : e[x]) {
        int v = xx.v;
        vnt w = xx.w;
        if(v != fa[x] && v != hson[x]) {
            dfs2(v, v);
        }
    }
}
int lca(int x, int y) {
    while(1) {
        if(top[x] == top[y]) {
            if(dep[x] < dep[y]) swap(x, y);
            return y;
        }
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        x = fa[top[x]];
    }
}
}  // namespace TD
using namespace TD;
namespace TsegT {
struct unit {
    int l, r, x, y;

    // l,r 直径左端点,直径右端点,直径长度
    vnt add, len;
    // deep lzy tag
} t[N * 4];
#define ls (pos << 1)
#define rs (pos << 1 | 1)
#define mid ((l + r) >> 1)
vnt getdepth(int x, int pos = 1, int l = 1, int r = n) {
    // cout << pos << " - " << t[pos].add << endl;
    // cout << "nq" << endl;
    if(l == r && l == x) {
        return t[pos].add;
    }
    if(x < l || x > r) return 0;
    if(x >= l && x <= mid) {
        return getdepth(x, ls, l, mid) + t[pos].add;
    } else if(x >= mid + 1 && x <= r) {
        return getdepth(x, rs, mid + 1, r) + t[pos].add;
    }
    // cout << "z" << endl;
    return 0;
}
vnt getdis(int x, int y) {
    int lc = lca(x, y);
    return dep[x] + getdepth(dfn[x]) + dep[y] + getdepth(dfn[y]) - 2 * dep[lc] - 2 * getdepth(dfn[lc]);
}
unit merge(unit A, unit B, unit base) {
    // cout << "merge" << endl;
    if(!A.l) return B;
    if(!B.l) return A;
    unit C = base;
    C.len = 0;
    C.l = min(A.l, B.l);
    C.r = max(A.r, B.r);
    int ep[6] = { 0, A.x, A.y, B.x, B.y };
    for(int i = 1; i <= 4; i++) {
        for(int j = 1; j < i; j++) {
            vnt d = getdis(ep[i], ep[j]);
            // lca logn
            // 3 deep logn
            // log n
            if(d > C.len) {
                C.len = d;
                C.x = ep[i];
                C.y = ep[j];
            }
        }
    }
    // cout << "merged!" << C.x << " " << C.y << " " << C.len << endl;
    return C;
}
void build(int pos = 1, int l = 1, int r = n) {
    // cout << pos << " " << l << " " << r << endl;
    if(l == r) {
        t[pos].l = l;
        t[pos].r = l;
        t[pos].x = rnk[l];
        t[pos].y = rnk[l];
        t[pos].len = 0;
        return;
    }
    build(ls, l, mid);
    build(rs, mid + 1, r);
    t[pos] = merge(t[ls], t[rs], t[pos]);
    return;
}
unit query(int L, int R, int pos = 1, int l = 1, int r = n) {
    if(L <= l && r <= R) return t[pos];
    if(R <= mid) return query(L, R, ls, l, mid);
    if(L > mid) return query(L, R, rs, mid + 1, r);
    return merge(query(L, R, ls, l, mid), query(L, R, rs, mid + 1, r), {});
}
void pushdown(int pos) {
    if(t[pos].add) {
        t[ls].add += t[pos].add;
        t[rs].add += t[pos].add;
        t[pos].add = 0;
    }
}

void range_add(int L, int R, vnt delta, int pos = 1, int l = 1, int r = n) {
    if(L <= l && r <= R) {
        // cout << l << " " << r << " += " << delta << endl;
        t[pos].add += delta;
        return;
    }
    pushdown(pos);
    if(L <= mid) range_add(L, R, delta, ls, l, mid);
    if(R > mid) range_add(L, R, delta, rs, mid + 1, r);
    t[pos] = merge(t[ls], t[rs], t[pos]);
}

}  // namespace TsegT
using namespace TsegT;