You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

3.8 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

题目描述

\text{Mila} 种了一棵苹果树,在魔法的浇灌下苹果树茁壮成长,很快就结出了苹果。

但是天有不测风云,一个夜晚,天打五雷轰,苹果树倒了。

在苹果树倒之前,\text{Mila} 忘记记录哪些节点上有苹果了,但是 \text{Mila} 记得,这棵苹果树的根节点是 1,并且每个节点上最多有 1 个苹果。在树倒之前 \text{Mila} 记录下了一些数字,a_i 表示 i 号节点的子树内的苹果数量不超过 a_i

现在 \text{Mila} 想要知道,这棵苹果树有多少种可能?由于 \text{Mila} 数学不太好,所以请你帮帮她。

输入格式

第一行一个整数 n,表示苹果树的节点数量。

第二行 n 个整数,第 i 个整数为 a_i

下面 n-1 行每行两个整数 x,y,表示 x,y 节点之间有一条边。

输出格式

输出一行一个整数,表示苹果树可能的方案数,答案对 998244353 取模。

数据范围

对于 20\% 的数据,满足 n\leq 20

对于 40\% 的数据,满足 n\leq 100

对于另外 10\% 的数据,保证 a_i\leq 1

对于另外 10\% 的数据,保证 x=1

对于 100\% 的数据,满足 0\leq a_i\leq 2\times 10^3, 1\leq n \leq 2\times 10^3

题解

不难想到树形 DP那么就能够自然地定义出状态f_{i,j} 表示 i 的子树内放了 j 个苹果的方案数。

转移的时候,考虑 x 的一棵子树 y,将 y 的信息和 x 之前的儿子的信息合并起来,可以得到:$f'_{x,i}=\sum_{j=0}^{sz_y}f_{y,j}f_{x,i-j}其中,$sz_y 表示 y 的子树大小。

统计完所有儿子后,将自己的 a_x+1 以上的部分清 0,因为这部分是不合法的,即让 f_{x,[a_x+1,n]}=0

咋一看这个 DP 是 O(n^3) 的,其实不然,这是 \text{dp} 中的经典套路,仔细思考可以发现这其实是 O(n^2) 的。

证明也不难,合并子树时,f_{x,i}f_{y,j} 乘起来贡献 f'_{x,i+j},不妨换个角度来看,可以看成 x 之前的子树中第 i 个点和 y 的子树中第 j 个点乘起来做贡献。那么仔细观察,其实每个点只会和其他点乘起来贡献一次,贡献的对象就是他们的 \text{lca}。所以总共的贡献次数只有 n^2 次。

参考代码

#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <vector>
using namespace std;
const int NMAX = 2100;
const int MOD = 998244353;
vector<int> E[NMAX];
int limit[NMAX], siz[NMAX];
int dp[NMAX][NMAX];
void solve(int x, int fa) {
    siz[x] = 1;
    dp[x][0] = 1;
    if (limit[x] >= 1) {
        dp[x][1] = 1;
    }
    for (auto y: E[x]) {
        if (y == fa) {
            continue;
        }
        solve(y, x);
        for (int i = siz[x] + siz[y]; i >= 0; i -= 1) {
            if (i > limit[x]) {
                dp[x][i] = 0;
                continue;
            }
            int j_init = max(i - siz[x], 1);
            int j_lim = min(siz[y], i);
            for (int j = j_init; j <= j_lim; j += 1) {
                (dp[x][i] += (long long )dp[y][j] * dp[x][i - j] % MOD) %= MOD;
            }
        }
        siz[x] += siz[y];
    }
}
int main() {
    freopen("appletree.in", "r", stdin);
    freopen("appletree.out", "w", stdout);
    int n;
    scanf("%d", &n);
    for (int i = 1; i <= n; i += 1) {
        scanf("%d", &limit[i]);
    }
    for (int i = 1; i < n; i += 1) {
        int x, y;
        scanf("%d %d", &x, &y);
        E[x].push_back(y);
        E[y].push_back(x);
    }
    solve(1, -1);
    int ans = 0;
    for (int i = 0; i <= limit[1] && i <= n; i += 1) {
        (ans += dp[1][i]) %= MOD;
    }
    printf("%d\n", ans);
    exit(0);
}