树分治
# 树分治
# 点分治
点分治是一种思想,通常需要配合其他算法来达到目的
# 实现
类似于 cdq 分治思想,我们对于树上的点对问题,把树按照重心分治,先考虑不同子树之间对于答案的影响,然后递归处理其子树

使用一次 dfs 找到树的重心
function<int(int)> get_root = [&] (int u) -> int{
int res = 0;
int rt = u;
function<void(int, int)> get_siz = [&] (int u, int fa) {
siz[u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
get_siz(v, u);
siz[u] += siz[v];
}
};
function<void(int, int)> dfs = [&] (int u, int fa) {
max_son[u] = 0;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
dfs(v, u);
max_son[u] = max(max_son[u], siz[v]);
}
max_son[u] = max(max_son[u], siz[rt] - siz[u]);
if (res == 0 || max_son[u] < max_son[res])
res = u;
};
get_siz(u, 0);
dfs(u, 0);
return res;
};
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
我这里多用了一次 dfs 来获得这个子树的大小(在后面求重心的时候需要用),但实际上这个可以预处理出来
这里我用一道例题来讲解点分治
洛谷 P3806 【模板】点分治 1 (opens new window)
给定一颗有 个点的树,询问树上距离为 的点对是否存在
暴力求解是 的,考虑点分治做法
我们考虑如何求一个以 为根的节点的树内的点对个数,记为 calc(u)
首先找到这颗子树的重心记作 然后有点类似于 cdq 分治的思想,去计算以 为根的子树之间的点对,也就是说这个点对之间的路径是经过 的
那么 solve(root)
怎么写呢,需要构造 表示 节点到 的距离,可以通过一次 dfs 计算出,然后用双指针算法求出 的点对数量
但是这里可以会把在同一棵子树的点对也计算上,根据容斥原理,枚举每个子树,把子树内重复计算的部分减去就好了
计算好通过 的点对个数之后,继续分治树,去计算各个子树内部的答案
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int main() {
freopen ("P3806.in", "r", stdin);
freopen ("P3806.out", "w", stdout);
ios::sync_with_stdio(false);
int n, m; cin >> n >> m;
vector<vector<pair<int, int>>> g(n + 1);
for (int i = 1; i < n; i++) {
int u, v, w; cin >> u >> v >> w;
g[u].push_back({v, w});
g[v].push_back({u, w});
}
vector<int> siz(n + 1, 0), max_son(n + 1, 0), vis(n + 1, 0);
function<int(int)> get_root = [&] (int u) -> int{
int res = 0;
int rt = u;
function<void(int, int)> get_siz = [&] (int u, int fa) {
siz[u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
get_siz(v, u);
siz[u] += siz[v];
}
};
function<void(int, int)> dfs = [&] (int u, int fa) {
max_son[u] = 0;
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
dfs(v, u);
max_son[u] = max(max_son[u], siz[v]);
}
max_son[u] = max(max_son[u], siz[rt] - siz[u]);
if (res == 0 || max_son[u] < max_son[res])
res = u;
};
get_siz(u, 0);
dfs(u, 0);
return res;
};
function<vector<int>(int, int)> get_dis = [&](int u, int add) -> vector<int>{
vector<int> res;
function<void(int, int, int)> dfs = [&] (int u, int fa, int d) {
res.push_back(d);
for (auto [v, w] : g[u]) {
if (v == fa || vis[v]) continue;
dfs(v, u, d + w);
}
};
dfs(u, 0, add);
return res;
};
function<ll(int, int, int)> solve = [&](int u, int k, int add) -> ll {
auto dis = get_dis(u, add);
sort(dis.begin(), dis.end());
vector<pair<int, int>> cnt;
for (auto x : dis) {
if (cnt.empty() || cnt.back().first != x) cnt.push_back({x, 1});
else cnt.back().second++;
}
int res = 0;
for (auto [x, y] : cnt) {
if (k % 2 == 0 && x == k / 2) res += 1ll * y * (y - 1) / 2;
}
for (int i = 0, j = (int)cnt.size() - 1; i < j; i++) {
while (j > i && cnt[i].first + cnt[j].first > k) j--;
if (j > i && cnt[i].first + cnt[j].first == k) res += 1ll * cnt[i].second * cnt[j].second;
}
return res;
};
function<ll(int, int)> dfs = [&](int u, int k) -> ll {
int root = get_root(u);
ll res = solve(root, k, 0);
vis[root] = 1;
for (auto [v, w] : g[root]) {
if (vis[v]) continue;
res -= solve(v, k, w);
}
for (auto [v, w] : g[root]) {
if (vis[v]) continue;
res += dfs(v, k);
}
return res;
};
while (m--) {
vis.assign(n + 1, 0);
int k; cin >> k;
ll ans = dfs(1, k);
// cout << ans << '\n';
cout << (ans ? "AYE" : "NAY") << '\n';
}
cout << (int)clock()/CLOCKS_PER_SEC << "s\n";
return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
PS:代码由于 function
用多了所以常数巨大
上次更新: 2025/04/08, 18:03:31