3.8 KiB
题目描述
\text{Mila}
种了一棵苹果树,在魔法的浇灌下苹果树茁壮成长,很快就结出了苹果。
但是天有不测风云,一个夜晚,天打五雷轰,苹果树倒了。
在苹果树倒之前,\text{Mila}
忘记记录哪些节点上有苹果了,但是 \text{Mila}
记得,这棵苹果树的根节点是 1
,并且每个节点上最多有 1
个苹果。在树倒之前 \text{Mila}
记录下了一些数字,a_i
表示 i
号节点的子树内的苹果数量不超过 a_i
。
现在 \text{Mila}
想要知道,这棵苹果树有多少种可能?由于 \text{Mila}
数学不太好,所以请你帮帮她。
输入格式
第一行一个整数 n
,表示苹果树的节点数量。
第二行 n
个整数,第 i
个整数为 a_i
。
下面 n-1
行每行两个整数 x,y
,表示 x,y
节点之间有一条边。
输出格式
输出一行一个整数,表示苹果树可能的方案数,答案对 998244353
取模。
数据范围
对于 20\%
的数据,满足 n\leq 20
。
对于 40\%
的数据,满足 n\leq 100
。
对于另外 10\%
的数据,保证 a_i\leq 1
。
对于另外 10\%
的数据,保证 x=1
。
对于 100\%
的数据,满足 0\leq a_i\leq 2\times 10^3, 1\leq n \leq 2\times 10^3
。
题解
不难想到树形 DP,那么就能够自然地定义出状态:f_{i,j}
表示 i
的子树内放了 j
个苹果的方案数。
转移的时候,考虑 x
的一棵子树 y
,将 y
的信息和 x
之前的儿子的信息合并起来,可以得到:$f'_{x,i}=\sum_{j=0}^{sz_y}f_{y,j}f_{x,i-j}
其中,$sz_y
表示 y
的子树大小。
统计完所有儿子后,将自己的 a_x+1
以上的部分清 0
,因为这部分是不合法的,即让 f_{x,[a_x+1,n]}=0
。
咋一看这个 DP 是 O(n^3)
的,其实不然,这是 \text{dp}
中的经典套路,仔细思考可以发现这其实是 O(n^2)
的。
证明也不难,合并子树时,f_{x,i}
和 f_{y,j}
乘起来贡献 f'_{x,i+j}
,不妨换个角度来看,可以看成 x
之前的子树中第 i
个点和 y
的子树中第 j
个点乘起来做贡献。那么仔细观察,其实每个点只会和其他点乘起来贡献一次,贡献的对象就是他们的 \text{lca}
。所以总共的贡献次数只有 n^2
次。
参考代码
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
using namespace std;
const int NMAX = 2100;
const int MOD = 998244353;
vector<int> E[NMAX];
int limit[NMAX], siz[NMAX];
int dp[NMAX][NMAX];
void solve(int x, int fa) {
siz[x] = 1;
dp[x][0] = 1;
if (limit[x] >= 1) {
dp[x][1] = 1;
}
for (auto y: E[x]) {
if (y == fa) {
continue;
}
solve(y, x);
for (int i = siz[x] + siz[y]; i >= 0; i -= 1) {
if (i > limit[x]) {
dp[x][i] = 0;
continue;
}
int j_init = max(i - siz[x], 1);
int j_lim = min(siz[y], i);
for (int j = j_init; j <= j_lim; j += 1) {
(dp[x][i] += (long long )dp[y][j] * dp[x][i - j] % MOD) %= MOD;
}
}
siz[x] += siz[y];
}
}
int main() {
freopen("appletree.in", "r", stdin);
freopen("appletree.out", "w", stdout);
int n;
scanf("%d", &n);
for (int i = 1; i <= n; i += 1) {
scanf("%d", &limit[i]);
}
for (int i = 1; i < n; i += 1) {
int x, y;
scanf("%d %d", &x, &y);
E[x].push_back(y);
E[y].push_back(x);
}
solve(1, -1);
int ans = 0;
for (int i = 0; i <= limit[1] && i <= n; i += 1) {
(ans += dp[1][i]) %= MOD;
}
printf("%d\n", ans);
exit(0);
}