You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

96 lines
2.6 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <bits/stdc++.h>
using namespace std;
const int N = 100010, M = 200010;
int depth[N], f[N][25];
int n, m;
int d[N]; // 差分数组
int ans; // 存答案
const int T = 17;
// 邻接表
int e[M], h[N], idx, ne[M];
void add(int a, int b) {
e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}
// 树上倍增
void bfs() {
queue<int> q;
q.push(1);
depth[1] = 1;
while (q.size()) {
int u = q.front();
q.pop();
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (!depth[v]) {
depth[v] = depth[u] + 1;
q.push(v);
f[v][0] = u;
for (int k = 1; k <= T; k++) f[v][k] = f[f[v][k - 1]][k - 1];
}
}
}
}
// 标准lca
int lca(int a, int b) {
if (depth[a] < depth[b]) swap(a, b);
for (int i = T; i >= 0; i--)
if (depth[f[a][i]] >= depth[b]) a = f[a][i];
if (a == b) return a;
for (int i = T; i >= 0; i--)
if (f[a][i] != f[b][i])
a = f[a][i], b = f[b][i];
return f[a][0];
}
// 差分数组还原
void dfs(int u, int fa) {
for (int i = h[u]; ~i; i = ne[i]) {
int v = e[i];
if (v == fa) continue;
dfs(v, u);
d[u] += d[v];
}
}
int main() {
int a, b;
scanf("%d %d", &n, &m);
memset(h, -1, sizeof h);
for (int i = 1; i < n; i++) { // n-1条边
scanf("%d %d", &a, &b);
add(a, b), add(b, a);
}
// lca的准备动作
bfs();
// 读入附加边
for (int i = 0; i < m; i++) {
scanf("%d %d", &a, &b);
// 树上差分
// d[a]的含义从a->fa这边条多了一个环
// d[b]的含义从b->fb这边条多了一个环
d[a]++, d[b]++;
int p = lca(a, b);
/*
Q:lca(a,b)为什么要减2
A:边差分,每条边是下放到下面的那个点上,用点来表示这个边的。
其实每个点表示的是它向上那条边被覆盖的次数对于lca(a,b)而言由于dfs统计进行前缀和汇总时
是左子树+右子树这样的形式进行汇总的也按同样逻辑处理就会多出2个需要扣除掉。
*/
d[p] -= 2;
}
// 差分数组求前缀和
dfs(1, 0);
// Q:为什么要从2开始
// A:因为1是根1是没有边的边是向上的从2开始才有边
for (int i = 2; i <= n; i++) {
if (d[i] == 0) ans += m;
if (d[i] == 1) ans += 1;
}
// 输出
printf("%d\n", ans);
return 0;
}