9.2 KiB
AcWing
1171
. 距离
一、题目描述
给出 n
个点的一棵树,多次询问 两点之间的 最短距离。
注意:
- 边是 无向 的。
- 所有节点的编号是
1,2,…,n
。
输入格式
第一行为两个整数 n
和 m
。n
表示点数,m
表示询问次数;
下来 n−1
行,每行三个整数 x,y,k
,表示点 x
和点 y
之间存在一条边长度为 k
;
再接下来 m
行,每行两个整数 x,y
,表示询问点 x
到点 y
的最短距离。
树中结点编号从 1
到 n
。
输出格式
共 m
行,对于每次询问,输出一行询问结果。
数据范围
2≤n≤10^4
1≤m≤2×10^4
0<k≤100
1≤x,y≤n
输入样例1:
2 2
1 2 100
1 2
2 1
输出样例1:
100
100
输入样例2:
3 2
1 2 10
3 1 15
1 2
3 2
输出样例2:
10
25
二、解题思路
此题就是模板基础上的简单扩展,x,y
到 lca(x,y)=z
的最短距离,可以转化为源点(任意点均可)到两个节点的距离和,再减去2
倍的到LCA(A,B)
的距离,如图:
\large dist[x]+dist[y]-2*dist[z]
Code
倍增
#include <bits/stdc++.h>
using namespace std;
const int N = 20010, M = 40010;
int n, m;
int f[N][16], depth[N];
int dist[N]; // 距离1号点的距离
// 邻接表
int e[M], h[N], idx, w[M], ne[M];
void add(int a, int b, int c) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
void bfs() {
// 1号点是源点
depth[1] = 1;
queue<int> q;
q.push(1);
while (q.size()) {
int u = q.front();
q.pop();
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (!depth[v]) {
q.push(v);
depth[v] = depth[u] + 1;
dist[v] = dist[u] + w[i];
f[v][0] = u; // 父亲大人
for (int k = 1; k <= 15; k++) // 记录倍增数组
f[v][k] = f[f[v][k - 1]][k - 1];
}
}
}
}
// 最近公共祖先
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
// 对齐
for (int k = 15; k >= 0; k--)
if (depth[f[a][k]] >= depth[b])
a = f[a][k];
if (a == b) return a;
// 齐步走
for (int k = 15; k >= 0; k--)
if (f[a][k] != f[b][k])
a = f[a][k], b = f[b][k];
// 返回父亲
return f[a][0];
}
int main() {
memset(h, -1, sizeof h);
scanf("%d %d", &n, &m);
int a, b, c;
// n-1条边
for (int i = 1; i < n; i++) {
scanf("%d %d %d", &a, &b, &c);
add(a, b, c), add(b, a, c);
}
bfs();
while (m--) {
scanf("%d %d", &a, &b);
int t = lca(a, b);
int ans = dist[a] + dist[b] - dist[t] * 2;
printf("%d\n", ans);
}
return 0;
}
三、Tarjan
算法计算LCA
本题考察LCA
的Tarjan
算法。Tarjan
算法是一个 离线算法,一次性读入,计算后再一次性输出,算法的时间复杂度是O(n + m)
。
算法原理
设x
和y
的LCA
是r
(暂时假设r、x、y
是三个不同的节点),则x
和y
一定处于以r
为根的不同子树中,并且可以得出:处在r
不同子树中的任意两个节点的LCA
都是r
,所以在遍历以r
为根的子树时,只要能够判断出两个节点分处于r
的不同子树中,就可以将这两个节点的LCA
标记为r
。
如何判断x
和y
是否处在r
的不同子树中呢?在对以r
为根的子树做dfs
的过程中,如果y
所在的子树已经遍历完了,之后又遍历到x
时,就可以说明x
和y
不在同一棵子树了。
对树的节点进行状态划分:
0
:还未遍历到的节点1
:该节点已经遍历到了,但是其子树还没有完成遍历回溯完2
:该节点以及其子树均已遍历回溯完
注:
2
这个状态在代码实现中被省略,没用上
在dfs
过程中,第一次遍历到r
时,r
的状态转化为1
,并且,r
的祖先节点的状态也都是1
。当y
所在的子树全部遍历回溯完后,y
到r
的路径中,除了r
以外的其他节点的状态均是2
。
换言之,x
和y
的LCA
就是y
向上回溯到第一个状态为1
的节点。
dfs
遍历完y
所在的子树并且遍历完x
及其子树时各节点的状态如上图所示。此时,x
的子树刚刚全部遍历回溯完成,然后发现y
的状态是2
,于是y
向上回溯,发现了第一个标记为状态1
的r
节点,也就是x
和y
的LCA
节点。原理也就是之前所说的,y
所在的子树遍历完了,但是LCA
节点r
状态肯定还是1
,因为r
还有其他子树没有遍历完,后面再遍历到x
所在的子树时,一方面就说明了x
和y
在r
的不同子树中,另一方面也定位到了x
和y
分属不同子树的根节点r
。
为了提高回溯查找LCA
的效率,可以 使用并查集优化,即一个节点状态转化为2
时,就可以将其合并到其父节点所在的集合中,这样一来,当y
所在的子树全部变为状态2
时,他们也都被合并到r
所在的集合了,就有了y
所在的并查集的根结点就是r
,也就是x
和y
的LCA
节点。
特殊情况:r
和x
重合,即x
与y
的LCA
就是x
,此时在遍历完x
的所有子树后,x
的状态即将转化为2
时,y
也被合并到以x
为根的并查集中了,此时x
就是LCA
节点。所以我们可以在x
的子树均已遍历回溯完成之际,对x
与状态为2
的y
节点求LCA
。
综上所述,lca(x,y)=find(y)
,其中find
函数就是并查集的查找当前集合根节点的函数。并且如果要求x
与y
之间的距离:
\large res[id] = dist[u] + dist[y] - 2 * dist[r]
注意:并查集的合并操作一定要在当前节点的所有子树都已经遍历回溯完成的情况下,所以要写在
tarjan
函数调用的后面,否则像r
节点还没有遍历回溯完就被合并到了r
的父节点所在的集合,后面再对y
求并查集的根节点时就不会返回r
节点了,就会引起错误。
#include <bits/stdc++.h>
using namespace std;
const int N = 10010, M = N << 1;
typedef pair<int, int> PII;
// 查询数组,first:对端节点号,second:问题序号
// 比如:q[2]={5,10} 表示10号问题,计算2和5之间的最短距离
vector<PII> query[N];
int dist[N]; // dist[u]记录从出发点S到u的距离
int res[M]; // 结果数组,有多少个问题就有多少个res[i]
// 链式前向星
int e[M], h[N], idx, w[M], ne[M];
void add(int a, int b, int c = 0) {
e[idx] = b, ne[idx] = h[a], w[idx] = c, h[a] = idx++;
}
// 并查集
int p[N];
int find(int x) {
if (x != p[x]) p[x] = find(p[x]);
return p[x];
}
int st[N]; // 0:未入栈, 1:在栈中, 2:已出栈
void tarjan(int u) {
// ① 标识u已访问
st[u] = 1;
// ② 枚举与u临边相连并且没有访问过的点
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (!st[v]) {
// 扩展:更新距离
dist[v] = dist[u] + w[i];
// 深搜
tarjan(v);
// ③ v加入u家族
p[v] = u;
}
}
// ④ 枚举已完成访问的点,记录lca或题目要求的结果
for (auto q : query[u]) {
int v = q.first, id = q.second;
if (st[v]) res[id] = dist[u] + dist[v] - 2 * dist[find(v)];
}
}
int main() {
int n, m; // n个结点,m次询问
scanf("%d%d", &n, &m);
memset(h, -1, sizeof h); // 初始化链式前向星
for (int i = 1; i <= n; i++) p[i] = i; // 并查集初始化
for (int i = 1; i < n; i++) { // 树有n-1条边
int a, b, c;
scanf("%d%d%d", &a, &b, &c);
add(a, b, c), add(b, a, c); // 无向图
}
// Tarjan算法是离线算法,一次性读入所有的问题,最终一并回答
for (int i = 0; i < m; i++) { // m个询问
int a, b;
scanf("%d%d", &a, &b); // 表示询问点 a 到点 b 的最短距离
query[a].push_back({b, i}), query[b].push_back({a, i}); // 不知道谁先被遍历 所以正反都记一下着
}
// tarjan算法求LCA
tarjan(1);
// 回答m个问题
for (int i = 0; i < m; i++) printf("%d\n", res[i]);
return 0;
}