给你一个含有 n 个整数的数组 a , 对于每个 a_{i} 添加一个元素集 \{a_{i}\}

定义一次操作由以下两步构成:

  • 选择两个集合 S, T, 满足 S \cap T = \emptyset

  • 删除集合 S, T,添加集合 S \cup T

之后, 我们构造一个可重集合 M 表示当前剩下的所有集合的大小。

举个例子, 若当前剩下的集合为 \{5\}, \{8\}, \{2,5,12,4\} , 则 M\{1,1,4\} 。现在让你求出所有不同的 M 的数量, 对 998244353 取模.

n \leq 2000,a_i\leq n

尝试以人话写出官方题解的意思。

首先我们定义 cnt_iia 中的出现次数,考虑把所有可能的 M 集合都通过补 0 补成大小为 n 的多重集。我们称一个多重集为好的当且仅当:

  1. \sum\limits_{i = 1}^k M_i = n
  2. \forall 1\leq k\leq n,\sum\limits_{i = 1}^k M_i \leq \sum\limits_{i = 1}^n \min\{k,cnt_i\}

那么我们要做的就是统计”好的“的集合的个数。

我们考虑 DP。首先定义状态 f_{i,j,k} 表示考虑了前 i 个元素,和为 j,最后一个填的为 k 的方案数,可以得到转移:

f_{i,j,k} = \sum\limits_{x \geq k,j \leq \sum\limits_{p = 1}^i \min\{k,cnt_p\}} f_{i-1,j - x,x}

容易用滚动数组以及前缀和分别把空间复杂度和时间复杂度优化到 \mathcal{O}(n^2),\mathcal{O}(n^3),我们考虑继续挖掘性质。

发现 k 实际上有范围,显然由上文的 2 得 n \geq k \times i\rightarrow k\leq \frac{n}{i}。那么 k 的枚举量就显然是一个调和级数了,所以我们得到新的时间复杂度是 \mathcal{O}(n^2\log n),可以通过本题。

//Goodbye goodbye goodbye. You were bigger than the whole sky
//You were more than just a short time.
//And I've got a lot to pine about,I've got a lot to live without.
//I'm never gonna meet.What could've been would've been.
//What should've been you.
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdio>
#include <cctype>
#include <vector>
#include <cmath>
#include <queue>
using namespace std;
#define ll long long
#define FO(x) {freopen(#x".in","r",stdin);freopen(#x".out","w",stdout);}
#define pii pair<int,int>
#define mp make_pair

char buf[1 << 20], *p1, *p2;
#define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, 1 << 20, stdin), p1 == p2)?EOF: *p1++)
template <typename T> inline void read(T &t) {
    int v = getchar();T f = 1;t = 0;
    while (!isdigit(v)) {if (v == '-')f = -1;v = getchar();}
    while (isdigit(v)) {t = t * 10 + v - 48;v = getchar();}
    t *= f;
}
template <typename T,typename... Args> inline void read(T &t,Args&... args) {
    read(t);read(args...);
}

const ll mod = 998244353;
const double eps = 1e-10;
const int N = 2e3 + 10;
int n,a[N],cnt[N],r[N];
int f[2][N][N];

signed main() {
    read(n);
    for (int i = 1;i <= n;++i) read(a[i]),++cnt[a[i]];
    for (int i = 1;i <= n;++i) {
        for (int j = 1;j <= n;++j) {
            r[i] += min(cnt[j],i);
        }
    }
    int now = 0,lst = 1;
    for (int i = 0;i <= n;++i) f[lst][0][i] = 1;
    for (int i = 1;i <= n;++i) {
        //swap(now,lst);
        for (int j = 0;j <= r[i];++j) {
            for (int k = 0;k <= min(r[i] / i,j);++k) {
                f[now][j][k] = f[lst][j - k][k];
            }
        }
        for (int j = 0;j <= r[i];++j) {
            f[lst][j][r[i] / i] = f[now][j][r[i] / i];
            for (int k = (r[i]/i) - 1;k >= 0;--k) {
                f[lst][j][k] = (f[lst][j][k + 1] + f[now][j][k]) % mod;
            }
        }
    }
    int ans = (f[now][n][0] + f[now][n][1]) % mod;
    printf("%d\n",ans);
}