线段树维护树的直径
日期:2025-10-11浏览:0
编辑
引
给你两棵树 $A,B$ 并且告诉你 $A$ 的最远点对为 $(a,b)$,$B$ 的最远点对为 $(c,d)$,假设此时用一条边将树 $A,B$ 相连,形成了一颗新的树,那么此时证明新的树的最远点对 $(e,f)$ 满足 $e,f \in {a,b,c,d}$。
我们不难想到进行分类讨论,分为以下三种情况:
- 直径的两个端点都在 $A$ 内。
- 直径的两个端点都在 $B$ 内。
- 直径的一个端点在 $A$ 内,一个端点在 $B$ 内。
1和2不难证明,考虑到端点都在一棵树内则就是原来的直径,对于情况3,我们不难看出直径一定经过了链接 $A$ 和 $B$ 的那条边,也就是说,经过了连接 $A$ 和 $B$ 的那条边的两个端点 $(g,h)$。
不妨再次将 $A$ 和 $B$ 分开来理解,不难发现分别离 $(g,h)$ 最远的那两个点就是贯穿 $A$、$B$ 的直径的端点,而离树上一点最远的一个点一定是树上的端点,所以 $A$ 和 $B$ 组合后的直径的端点一定是 $A$ 和 $B$ 原来的直径的端点
线段树
我们需要找到一个数据结构来维护这个东西,至少需要支持合并的功能,不难想到线段树是一个很好的选择。
我们先将 $A$ 和 $B$ 按照dfs序打成序列,然后对其建立线段树,对于线段树的叶子节点,我们只需要维护树上的一个点,所以直径端点为自己,直径长度为 $0$,合并时,我们需要分别查询6次距离,总复杂度 $\log^2 n$,对于修改,修改一条边的距离,相当于增加整个子树的深度,只需要拿线段树维护一下深度即可。
代码,如下,常数极大。
const int N = 110000;
#define vnt long long
// #define int long long
int n, s;
struct qwq {
int u, v;
vnt w;
} b[N];
struct edge {
int v;
vnt w;
};
vector<edge> e[N];
int dfs[N], hson[N], dfn[N], sz[N], top[N], rnk[N], fa[N];
vnt dep[N];
namespace TD {
int dn = 0;
void dfs1(int x, int f, vnt dee = 0) {
dep[x] = dep[f] + dee;
fa[x] = f;
sz[x] = 1;
for(auto xx : e[x]) {
int v = xx.v;
vnt w = xx.w;
if(v != f) {
dfs1(v, x, w);
sz[x] += sz[v];
if(sz[v] >= sz[hson[x]]) hson[x] = v;
}
}
}
void dfs2(int x, int tp) {
top[x] = tp, dfn[x] = ++dn, rnk[dn] = x;
if(!hson[x]) return;
dfs2(hson[x], tp);
for(auto xx : e[x]) {
int v = xx.v;
vnt w = xx.w;
if(v != fa[x] && v != hson[x]) {
dfs2(v, v);
}
}
}
int lca(int x, int y) {
while(1) {
if(top[x] == top[y]) {
if(dep[x] < dep[y]) swap(x, y);
return y;
}
if(dep[top[x]] < dep[top[y]]) swap(x, y);
x = fa[top[x]];
}
}
} // namespace TD
using namespace TD;
namespace TsegT {
struct unit {
int l, r, x, y;
// l,r 直径左端点,直径右端点,直径长度
vnt add, len;
// deep lzy tag
} t[N * 4];
#define ls (pos << 1)
#define rs (pos << 1 | 1)
#define mid ((l + r) >> 1)
vnt getdepth(int x, int pos = 1, int l = 1, int r = n) {
// cout << pos << " - " << t[pos].add << endl;
// cout << "nq" << endl;
if(l == r && l == x) {
return t[pos].add;
}
if(x < l || x > r) return 0;
if(x >= l && x <= mid) {
return getdepth(x, ls, l, mid) + t[pos].add;
} else if(x >= mid + 1 && x <= r) {
return getdepth(x, rs, mid + 1, r) + t[pos].add;
}
// cout << "z" << endl;
return 0;
}
vnt getdis(int x, int y) {
int lc = lca(x, y);
return dep[x] + getdepth(dfn[x]) + dep[y] + getdepth(dfn[y]) - 2 * dep[lc] - 2 * getdepth(dfn[lc]);
}
unit merge(unit A, unit B, unit base) {
// cout << "merge" << endl;
if(!A.l) return B;
if(!B.l) return A;
unit C = base;
C.len = 0;
C.l = min(A.l, B.l);
C.r = max(A.r, B.r);
int ep[6] = { 0, A.x, A.y, B.x, B.y };
for(int i = 1; i <= 4; i++) {
for(int j = 1; j < i; j++) {
vnt d = getdis(ep[i], ep[j]);
// lca logn
// 3 deep logn
// log n
if(d > C.len) {
C.len = d;
C.x = ep[i];
C.y = ep[j];
}
}
}
// cout << "merged!" << C.x << " " << C.y << " " << C.len << endl;
return C;
}
void build(int pos = 1, int l = 1, int r = n) {
// cout << pos << " " << l << " " << r << endl;
if(l == r) {
t[pos].l = l;
t[pos].r = l;
t[pos].x = rnk[l];
t[pos].y = rnk[l];
t[pos].len = 0;
return;
}
build(ls, l, mid);
build(rs, mid + 1, r);
t[pos] = merge(t[ls], t[rs], t[pos]);
return;
}
unit query(int L, int R, int pos = 1, int l = 1, int r = n) {
if(L <= l && r <= R) return t[pos];
if(R <= mid) return query(L, R, ls, l, mid);
if(L > mid) return query(L, R, rs, mid + 1, r);
return merge(query(L, R, ls, l, mid), query(L, R, rs, mid + 1, r), {});
}
void pushdown(int pos) {
if(t[pos].add) {
t[ls].add += t[pos].add;
t[rs].add += t[pos].add;
t[pos].add = 0;
}
}
void range_add(int L, int R, vnt delta, int pos = 1, int l = 1, int r = n) {
if(L <= l && r <= R) {
// cout << l << " " << r << " += " << delta << endl;
t[pos].add += delta;
return;
}
pushdown(pos);
if(L <= mid) range_add(L, R, delta, ls, l, mid);
if(R > mid) range_add(L, R, delta, rs, mid + 1, r);
t[pos] = merge(t[ls], t[rs], t[pos]);
}
} // namespace TsegT
using namespace TsegT;