Martian148's blog Martian148's blog
首页
  • ICPC 算法笔记
  • ICPC 算法题解
  • 体系结构
  • 高等数学
  • 线性代数
  • 概率论与数理统计
  • 具体数学
  • Martian148的奇思妙想
  • 游记
  • 通识课笔记
关于
  • useful 网站
  • 友情链接
  • 分类
  • 归档

Martian148

一只热爱文科的理科生
首页
  • ICPC 算法笔记
  • ICPC 算法题解
  • 体系结构
  • 高等数学
  • 线性代数
  • 概率论与数理统计
  • 具体数学
  • Martian148的奇思妙想
  • 游记
  • 通识课笔记
关于
  • useful 网站
  • 友情链接
  • 分类
  • 归档
  • 线上赛板子(实时更新)
  • 数据结构

    • 堆
    • 并查集
    • 树状数组
    • 分块
    • 树相关
    • 线段树
    • 平衡树
    • 树链剖分
    • LCT
    • 最近公共祖先
    • 虚树
    • 树分治
      • 点分治
        • 实现
    • K-D Tree
    • 笛卡尔树
    • 珂朵莉树
  • 数学

  • 计算几何

  • 动态规划

  • 图论

  • 字符串

  • 杂项

  • 算法笔记
  • 数据结构
martian148
2024-09-03
目录

树分治

# 树分治

# 点分治

点分治是一种思想,通常需要配合其他算法来达到目的

# 实现

类似于 cdq 分治思想,我们对于树上的点对问题,把树按照重心分治,先考虑不同子树之间对于答案的影响,然后递归处理其子树

image-20240913103835082

使用一次 dfs 找到树的重心

function<int(int)> get_root = [&] (int u) -> int{
        int res = 0;
        int rt = u;
        function<void(int, int)> get_siz = [&] (int u, int fa) {
            siz[u] = 1;
            for (auto [v, w] : g[u]) {
                if (v == fa || vis[v]) continue;
                get_siz(v, u);
                siz[u] += siz[v];
            }
        };
        function<void(int, int)> dfs = [&] (int u, int fa) {
            max_son[u] = 0;
            for (auto [v, w] : g[u]) {
                if (v == fa || vis[v]) continue;
                dfs(v, u);
                max_son[u] = max(max_son[u], siz[v]);
            }
            max_son[u] = max(max_son[u], siz[rt] - siz[u]);
            if (res == 0 || max_son[u] < max_son[res])
                res = u;
        };
        get_siz(u, 0);
        dfs(u, 0);
        return res;
    };
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26

我这里多用了一次 dfs 来获得这个子树的大小(在后面求重心的时候需要用),但实际上这个可以预处理出来

这里我用一道例题来讲解点分治

洛谷 P3806 【模板】点分治 1 (opens new window)

给定一颗有 nnn 个点的树,询问树上距离为 kkk 的点对是否存在

暴力求解是 O(n2)O(n^2)O(n2) 的,考虑点分治做法

我们考虑如何求一个以 uuu 为根的节点的树内的点对个数,记为 calc(u)

首先找到这颗子树的重心记作 rootrootroot 然后有点类似于 cdq 分治的思想,去计算以 rootrootroot 为根的子树之间的点对,也就是说这个点对之间的路径是经过 rootrootroot 的

那么 solve(root) 怎么写呢,需要构造 dis[x]dis[x]dis[x] 表示 xxx 节点到 rootrootroot 的距离,可以通过一次 dfs 计算出,然后用双指针算法求出 dis[i]+dis[j]=kdis[i]+dis[j]=kdis[i]+dis[j]=k 的点对数量

但是这里可以会把在同一棵子树的点对也计算上,根据容斥原理,枚举每个子树,把子树内重复计算的部分减去就好了

计算好通过 rootrootroot 的点对个数之后,继续分治树,去计算各个子树内部的答案

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

int main() {
    freopen ("P3806.in", "r", stdin);
    freopen ("P3806.out", "w", stdout);
    ios::sync_with_stdio(false);
    int n, m; cin >> n >> m;
    vector<vector<pair<int, int>>> g(n + 1);
    for (int i = 1; i < n; i++) {
        int u, v, w; cin >> u >> v >> w;
        g[u].push_back({v, w});
        g[v].push_back({u, w});
    }
    
    vector<int> siz(n + 1, 0), max_son(n + 1, 0), vis(n + 1, 0);

    function<int(int)> get_root = [&] (int u) -> int{
        int res = 0;
        int rt = u;
        function<void(int, int)> get_siz = [&] (int u, int fa) {
            siz[u] = 1;
            for (auto [v, w] : g[u]) {
                if (v == fa || vis[v]) continue;
                get_siz(v, u);
                siz[u] += siz[v];
            }
        };
        function<void(int, int)> dfs = [&] (int u, int fa) {
            max_son[u] = 0;
            for (auto [v, w] : g[u]) {
                if (v == fa || vis[v]) continue;
                dfs(v, u);
                max_son[u] = max(max_son[u], siz[v]);
            }
            max_son[u] = max(max_son[u], siz[rt] - siz[u]);
            if (res == 0 || max_son[u] < max_son[res])
                res = u;
        };
        get_siz(u, 0);
        dfs(u, 0);
        return res;
    };

    function<vector<int>(int, int)> get_dis = [&](int u, int add) -> vector<int>{
        vector<int> res;
        function<void(int, int, int)> dfs = [&] (int u, int fa, int d) {
            res.push_back(d);
            for (auto [v, w] : g[u]) {
                if (v == fa || vis[v]) continue;
                dfs(v, u, d + w);
            }
        };
        dfs(u, 0, add);  
        return res;
    };
    
    function<ll(int, int, int)> solve = [&](int u, int k, int add) -> ll {
        auto dis = get_dis(u, add);
        sort(dis.begin(), dis.end());
        vector<pair<int, int>> cnt;
        for (auto x : dis) {
            if (cnt.empty() || cnt.back().first != x) cnt.push_back({x, 1});
            else cnt.back().second++;
        }
        int res = 0;
        for (auto [x, y] : cnt) {
            if (k % 2 == 0 && x == k / 2) res += 1ll * y * (y - 1) / 2;
        }
        for (int i = 0, j = (int)cnt.size() - 1; i < j; i++) {
            while (j > i && cnt[i].first + cnt[j].first > k) j--;
            if (j > i && cnt[i].first + cnt[j].first == k) res += 1ll * cnt[i].second * cnt[j].second;
        }
        return res;
    };

    function<ll(int, int)> dfs = [&](int u, int k) -> ll {
        int root = get_root(u);
        ll res = solve(root, k, 0);
        vis[root] = 1;
        for (auto [v, w] : g[root]) {
            if (vis[v]) continue;
            res -= solve(v, k, w);
        }
        for (auto [v, w] : g[root]) {
            if (vis[v]) continue;
            res += dfs(v, k);
        }
        return res;
    };

    while (m--) {
        vis.assign(n + 1, 0);
        int k; cin >> k;
        ll ans = dfs(1, k);
        // cout << ans << '\n';
        cout << (ans ? "AYE" : "NAY") << '\n';
    }
    cout << (int)clock()/CLOCKS_PER_SEC << "s\n";
    return 0;
}
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

PS:代码由于 function 用多了所以常数巨大

上次更新: 2025/04/08, 18:03:31
虚树
K-D Tree

← 虚树 K-D Tree→

最近更新
01
Java基础语法
05-26
02
开发环境配置
05-26
03
pink 老师 JavaScript 学习笔记
05-26
更多文章>
Theme by Vdoing | Copyright © 2024-2025 Martian148 | MIT License
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式