「Codeforces 755G」PolandBall and Many Other Balls

Codeforces 755G. PolandBall and Many Other Balls

题意

\(n\) 个球编号为 \(1,2,\dotsc,n\),一组可以是一个球 \(\{i\}\) 或者是两个相邻的球 \(\{i,i+1\}\)

对于 \(i=1,2,\dotsc,k\),求把 \(n\) 个球划分成 \(i\) 组的方案数。每个球至多在一个组内,也可以不在任何一个组内

答案对 \(998244353\) 取模

\(1\le n\le 10^9,\ 1\le k<2^{15}\)

做法

显然有 \(\mathcal O(nk)\) 的 dp 做法

设 \(f_{i,j}\) 表示前 \(i\) 个球分成 \(j\) 组的方案数

转移是

\[ f_{i,j}=f_{i-1,j}+f_{i-1,j-1}+f_{i-2,j-1} \]

用生成函数表示第二维,有

\[ f_i(x)=f_{i-1}(x)+x\cdot f_{i-1}(x)+x\cdot f_{i-2}(x) \]

矩阵

这个东西可以矩阵转移

\[ \begin{bmatrix} x+1 & x \\ 1 & 0 \end{bmatrix} \begin{pmatrix} f_i(x) \\ f_{i-1}(x) \end{pmatrix} = \begin{pmatrix} f_{i+1}(x) \\ f_i(x) \end{pmatrix} \]

快速幂一下就好了

复杂度 \(\mathcal O(k\log k\log n)\)

倍增

上述算法写得不好就被卡常了,直接倍增常数小得多

大概要讨论一下最中间的几个球的情况,没写

复杂度还是 \(\mathcal O(k\log k\log n)\)

通项

写出递推关于 \(z\) 的特征方程

\[ z^2=(x+1)z+x \]

其中系数是多项式

解得

\[ \begin{align} z_0&=\frac{x+1-\sqrt{x^2+6x+1}}{2} \\ z_1&=\frac{x+1+\sqrt{x^2+6x+1}}{2} \end{align} \]

因此 \(f_n(x)\) 可以被表示为 \(A z_0^n+B z_1^n\)

根据

\[ \begin{cases} f_0(x) = A+B =1 \\ f_1(x) = A z_0 + B z_1 = x+1 \end{cases} \]

解得

\[ \begin{cases} A=\frac{-x-1+\sqrt{x^2+6x+1}}{2\sqrt{x^2+6x+1}} \\ B=\frac{x+1+\sqrt{x^2+6x+1}}{2\sqrt{x^2+6x+1}} \end{cases} \]

事实上,由于 \(z_0\) 和 \(A\) 的常数项都为 \(0\),\(A z_0^n\) 的最低次项次数至少为 \(n+1\),对答案的 \(0\) 到 \(n\) 次项没有影响;而答案 \(n\) 次以上的系数都为 \(0\)(最多只能分成 \(n\) 组),因此对 \(i>n\) 特判输出 \(0\) 后即可忽略这部分

于是

\[ \begin{align} f_n(x) &\equiv B z_1^n &\pmod{x^{n+1}} \\ &= \frac{z_1^{n+1}}{\sqrt{x^2+6x+1}} & \end{align} \]

若用多项式快速幂直接计算 \(z_1^{n+1}\),复杂度也是 \(\mathcal O(k\log k\log n)\),常数想必更小

由于这里 \(z_1\) 的常数项为 \(1\),可以方便地用 \(\ln\) 和 \(\exp\) 计算,即 \(z_1^{n+1}=\exp\big((n+1)\ln z_1\big)\)

总复杂度 \(\mathcal O(k\log k)\)

代码

矩乘

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<vector>

using namespace std;
// Shorthand for 64-bit intermediates. Must stay a macro rather than a typedef:
// DFT declares `unsigned ll`, which only works through textual substitution.
#define ll long long

// N = 2^15 matches the problem's limit on k; M = 2N is the capacity of DFT's
// scratch buffer (so the largest transform length) and the size of the root
// table w. P = 998244353 is the NTT modulus.
const int N = 1<<15, M = N<<1, P = 998244353;
// n, k: the two input integers. w is filled in main: for each power of two i,
// w[i + j] = g^j (0 <= j < i) with g = 3^((P-1)/(2i)); DFT reads w[i .. 2i-1].
int n, k, w[M];
// Coefficients of f_n(x), computed in main and printed for x^1..x^k.
vector<int> ans;
// Modular exponentiation: returns x^y mod P. With the default exponent P-2 this
// is the modular inverse of x (Fermat's little theorem, P prime).
inline int Pow(ll x, int y=P-2){
    ll result = 1;
    while(y){
        if(y & 1) result = result * x % P;
        x = x * x % P;
        y >>= 1;
    }
    return result;
}
inline void DFT(vector<int> &f, int n){
static unsigned ll F[M];
for(int i=0, j=0; i<n; ++i){
F[i]=f[j];
for(int k=n>>1; (j^=k)<k; k>>=1);
}
for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1) for(int k=j; k<j+i; ++k){
int t=F[k+i]*w[i+k-j]%P;
F[k+i]=F[k]+P-t, F[k]+=t;
}
for(int i=0; i<n; ++i) f[i]=F[i]%P;
}
inline void IDFT(vector<int> &f, int n){
reverse(f.begin()+1, f.end()), DFT(f, n);
for(int i=0, I=Pow(n); i<n; ++i) f[i]=(ll)f[i]*I%P;
}
// Returns the smallest power of two strictly greater than x (1 when x < 1).
inline int Calc(int x){
    int len = 1;
    for(; len <= x; len <<= 1) {}
    return len;
}
// Polynomial product via NTT, truncated to at most k+1 coefficients (only
// x^0..x^k are ever printed). An empty operand yields the zero polynomial {0}.
inline vector<int> operator *(const vector<int> &x, const vector<int> &y){
    if(x.empty() || y.empty()) return vector<int>(1, 0);
    const int fullLen = (int)x.size() + (int)y.size() - 1;
    const int len = Calc(fullLen - 1);
    vector<int> lhs(x), rhs(y);
    lhs.resize(len);
    rhs.resize(len);
    DFT(lhs, len);
    DFT(rhs, len);
    for(int i = 0; i < len; ++i) lhs[i] = (ll)lhs[i] * rhs[i] % P;
    IDFT(lhs, len);
    lhs.resize(min(fullLen, k + 1));
    return lhs;
}
// Replaces x with the (truncated) product x * y.
inline void operator *=(vector<int> &x, const vector<int> &y){
    vector<int> product = x * y;
    x.swap(product);
}
// Coefficient-wise sum modulo P; the result is as long as the longer operand.
inline vector<int> operator +(const vector<int> &x, const vector<int> &y){
    vector<int> sum(max(x.size(), y.size()), 0);
    copy(x.begin(), x.end(), sum.begin());
    for(size_t i = 0; i < y.size(); ++i) sum[i] = (sum[i] + y[i]) % P;
    return sum;
}
// Replaces x with the coefficient-wise sum x + y.
inline void operator +=(vector<int> &x, const vector<int> &y){
    vector<int> sum = x + y;
    x.swap(sum);
}
// 2x2 matrix whose entries are polynomials stored as coefficient vectors; an
// empty vector behaves as the zero polynomial. A is the global transition matrix.
struct matrix{
    vector<int> a[2][2];
    // Matrix product; every entry is accumulated as sum over mid = 0, 1.
    inline matrix operator *(const matrix &rhs)const{
        matrix prod;
        for(int row = 0; row < 2; ++row)
            for(int col = 0; col < 2; ++col)
                for(int mid = 0; mid < 2; ++mid)
                    prod.a[row][col] += a[row][mid] * rhs.a[mid][col];
        return prod;
    }
} A;
// Returns x^y by binary exponentiation.
// For y <= 0 the 2x2 identity matrix (x^0) is returned. Without this guard,
// `--y` would make y negative and the shift loop below would never reach 0.
inline matrix Pow(matrix x, int y){
    if(y <= 0){
        matrix id;           // off-diagonal entries stay empty, i.e. zero
        id.a[0][0] = {1};
        id.a[1][1] = {1};
        return id;
    }
    // Start from x itself (consuming one factor) so no product with the identity is needed.
    matrix ans = x;
    --y;
    for(; y; y >>= 1, x = x * x) if(y & 1) ans = ans * x;
    return ans;
}
// Reads n and k, raises the transition matrix [[1+x, x], [1, 0]] to the n-th
// power and applies it to (f_1, f_0) = (1+x, 1) to obtain f_n(x); prints the
// coefficients of x^1..x^k, with 0 for every degree above n.
int main() {
    // Root table: for each power-of-two half-length len, w[len + j] = g^j with
    // g = 3^((P-1)/(2*len)).
    for(int len = 1; len < M; len <<= 1){
        w[len] = 1;
        w[len + 1] = Pow(3, (P - 1) / len / 2);
        for(int j = 2; j < len; ++j) w[len + j] = (ll)w[len + j - 1] * w[len + 1] % P;
    }
    scanf("%d%d", &n, &k);
    A.a[0][0] = {1, 1};
    A.a[0][1] = {0, 1};
    A.a[1][0] = {1};
    A = Pow(A, n);
    // Second row of A^n times (f_1, f_0) gives f_n.
    ans = A.a[1][0] * vector<int>{1, 1} + A.a[1][1];
    for(int i = 1; i <= k; ++i){
        const int value = (i > n) ? 0 : ans[i];
        printf("%d ", value);
    }
    return 0;
}

通项

纯手写

刚刚卡了下常数还是没最快= =

数组迭代实现好快= =

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
#include<cstdio>
#include<algorithm>
#include<cctype>
#include<string.h>
#include<cmath>
#include<vector>

using namespace std;
// Shorthand for 64-bit intermediates. Must stay a macro rather than a typedef:
// DFT declares `unsigned ll`, which only works through textual substitution.
#define ll long long

// Buffered writer for stdout: characters accumulate in obuf and are written
// with fwrite when the buffer fills up or when flush() is called.
const int OUT_LEN = 1000000;
char obuf[OUT_LEN], *ooh=obuf;
// Appends one character, spilling the full buffer to stdout first if needed.
inline void print(char c) {
    if (ooh==obuf+OUT_LEN) fwrite(obuf, 1, OUT_LEN, stdout), ooh=obuf;
    *ooh++=c;
}
// Appends the decimal representation of integer x.
// Digits are peeled off without negating x, so the most negative value of a
// signed T (whose negation would overflow, which is undefined behavior) is
// printed correctly. Built-in % and / truncate toward zero, so for negative x
// each remainder lies in [-9, 0].
template<class T>
inline void print(T x) {
    char digits[48];
    int len=0;
    const bool negative = x<0;
    do {
        const int d = static_cast<int>(x%10);
        digits[len++] = static_cast<char>('0' + (d<0 ? -d : d));
        x /= 10;
    } while (x!=0);
    if (negative) print('-');
    while (len) print(digits[--len]);
}
// Writes out everything buffered so far and resets the buffer, so calling
// flush() more than once never emits the same bytes twice.
inline void flush() { fwrite(obuf, 1, ooh - obuf, stdout); ooh=obuf; }

// N = 2^15 matches the problem's limit on k; M = 2N is the capacity of DFT's
// scratch buffer and the size of the w / inv tables; P is the NTT modulus.
constexpr int N = 1 << 15;
constexpr int M = N << 1;
constexpr int P = 998244353;
// The two input integers.
int n, k;
// Filled in main: w[i + j] = g^j (0 <= j < i) with g = 3^((P-1)/(2i)).
int w[M];
// Filled in main: inv[i] is the inverse of i modulo P for 1 <= i < M.
int inv[M];
// Computes x^y mod P by square-and-multiply; the default exponent P-2 yields
// the modular inverse of x (P is prime).
inline int Pow(ll x, int y=P-2){
    ll acc = 1;
    for(ll base = x; y != 0; y >>= 1){
        if(y & 1) acc = acc * base % P;
        base = base * base % P;
    }
    return (int)acc;
}
inline void DFT(vector<int> &f, int n){
static unsigned ll F[M];
for(int i=0, j=0; i<n; ++i){
F[i]=f[j];
for(int k=n>>1; (j^=k)<k; k>>=1);
}
for(int i=1; i<n; i<<=1) for(int j=0; j<n; j+=i<<1){
int *W=w+i;
unsigned ll *F0=F+j, *F1=F+j+i;
for(int k=j; k<j+i; ++k, ++W, ++F0, ++F1){
int t=*F1**W%P;
*F1=*F0+P-t, *F0+=t;
}
}
for(int i=0; i<n; ++i) f[i]=F[i]%P;
}
inline void IDFT(vector<int> &f, int n){
reverse(f.begin()+1, f.end()), DFT(f, n);
for(int i=0, I=Pow(n); i<n; ++i) f[i]=(ll)f[i]*I%P;
}
// Returns the least power of two strictly greater than x (1 if x < 1).
inline int Calc(int x){
    int p = 1;
    while(!(p > x)) p *= 2;
    return p;
}
// Product of two coefficient vectors modulo P, truncated to at most k+1
// coefficients (only x^0..x^k are ever used). An empty operand yields {0}.
// Small products use the schoolbook method and larger ones an NTT. Both paths
// apply the same truncation, so the result length does not depend on which
// path was taken.
inline vector<int> operator *(const vector<int> &x, const vector<int> &y){
    if(!x.size() || !y.size()) return {0};
    const int full = (int)x.size()+(int)y.size()-1;
    const int len = min(full, k+1);
    if((unsigned ll)x.size()*y.size()<=1<<8){
        vector<int> ans(len);
        for(unsigned i=0; i<x.size() && (int)i<len; ++i)
            for(unsigned j=0; j<y.size() && (int)(i+j)<len; ++j)
                ans[i+j]=(ans[i+j]+(ll)x[i]*y[j])%P;
        return ans;
    }
    vector<int> a=x, b=y;
    int n=Calc(full-1);
    a.resize(n), b.resize(n), DFT(a, n), DFT(b, n);
    for(int i=0; i<n; ++i) a[i]=(ll)a[i]*b[i]%P;
    IDFT(a, n);
    return a.resize(len), a;
}
// Replaces x with the product x * y.
inline void operator *=(vector<int> &x, const vector<int> &y){
    vector<int> result(x * y);
    x.swap(result);
}
// Coefficient-wise sum modulo P, as long as the longer operand.
inline vector<int> operator +(const vector<int> &x, const vector<int> &y){
    vector<int> sum(x);
    sum.resize(max(sum.size(), y.size()));
    for(size_t i = 0; i < y.size(); ++i) sum[i] = (sum[i] + y[i]) % P;
    return sum;
}
// Replaces x with the coefficient-wise sum x + y.
inline void operator +=(vector<int> &x, const vector<int> &y){
    vector<int> result(x + y);
    x.swap(result);
}
// Coefficient-wise difference x - y modulo P, padded to the longer operand.
// Expects entries already reduced to [0, P).
inline vector<int> operator -(const vector<int> &x, const vector<int> &y){
    vector<int> diff(x);
    if(diff.size() < y.size()) diff.resize(y.size());
    for(size_t i = 0; i < y.size(); ++i) diff[i] = (diff[i] + (P - y[i])) % P;
    return diff;
}
// Halves every coefficient modulo P, i.e. multiplies by the inverse of 2.
// Odd values get P (odd) added first so that the right shift is exact.
inline vector<int> PolyDiv2(const vector<int> &x){
    vector<int> half;
    half.reserve(x.size());
    for(int v : x) half.push_back(((v & 1) ? v + P : v) >> 1);
    return half;
}
// Returns the first n coefficients of a, zero-padding when a is shorter.
inline vector<int> Ext(const vector<int> &a, int n){
    if((int)a.size() < n){
        vector<int> padded(a);
        padded.resize(n);
        return padded;
    }
    return vector<int>(a.begin(), a.begin() + n);
}
vector<int> PolyInv(const vector<int> &a, int n=-1){
if(n==-1) n=a.size();
if(n==1) return {Pow(a[0])};
vector<int> ans=PolyInv(a, (n+1)/2), tmp=Ext(a, n);
int m=Calc(n*2-1);
ans.resize(m), tmp.resize(m), DFT(ans, m), DFT(tmp, m);
for(int i=0; i<m; ++i) ans[i]=(2+(ll)(P-tmp[i])*ans[i])%P*ans[i]%P;
IDFT(ans, m);
return ans.resize(n), ans;
}
vector<int> PolySqrt(const vector<int> &a, int n=-1){
if(n==-1) n=a.size();
if(n==1) return {1};
vector<int> ans=PolySqrt(a, (n+1)/2);
return PolyDiv2(Ext(ans+Ext(a, n)*PolyInv(Ext(ans, n)), n));
}
// Formal derivative: returns the coefficients of a'(x), one shorter than a.
// An empty input yields an empty vector; without the guard, a.size()-1 would
// wrap around in unsigned arithmetic and request an enormous allocation.
inline vector<int> D(const vector<int> &a){
    if(a.empty()) return {};
    vector<int> ans(a.size()-1);
    for(unsigned i=1; i<a.size(); ++i) ans[i-1]=(ll)a[i]*i%P;
    return ans;
}
// Formal integral with zero constant term: result[i] = a[i-1] / i mod P, using
// the precomputed inverse table inv (so a.size() must stay below M).
inline vector<int> Int(const vector<int> &a){
    vector<int> prim(a.size() + 1, 0);
    for(size_t i = 1; i <= a.size(); ++i) prim[i] = (ll)a[i - 1] * inv[i] % P;
    return prim;
}
// Power-series logarithm ln(a) = integral(a' / a), truncated to a.size()
// coefficients. The result's constant term is always 0, so a[0] is expected to be 1.
inline vector<int> PolyLn(const vector<int> &a){
    const vector<int> quotient = D(a) * PolyInv(a);
    return Int(Ext(quotient, a.size() - 1));
}
vector<int> PolyExp(const vector<int> &a, int n=-1){
if(n==-1) n=a.size();
if(n==1) return {1};
vector<int> ans=PolyExp(a, (n+1)/2);
return Ext(ans*(Ext(a, n)-PolyLn(Ext(ans, n))+vector<int>{1}), n);
}
// Computes f_n(x) = z_1^{n+1} / sqrt(x^2+6x+1), where z_1 = (1+x+sqrt(x^2+6x+1))/2,
// modulo x^{k+1}, and prints the coefficients of x^1..x^k (0 above degree n).
int main() {
    // Root table: w[len + j] = g^j with g = 3^((P-1)/(2*len)) for each level.
    for(int len = 1; len < M; len <<= 1){
        w[len] = 1;
        w[len + 1] = Pow(3, (P - 1) / len / 2);
        for(int j = 2; j < len; ++j) w[len + j] = (ll)w[len + j - 1] * w[len + 1] % P;
    }
    // Modular inverses via inv[i] = -(P / i) * inv[P % i] (mod P).
    inv[1] = 1;
    for(int i = 2; i < M; ++i) inv[i] = (ll)(P - P / i) * inv[P % i] % P;
    scanf("%d%d", &n, &k);
    vector<int> s = {1, 6, 1};
    s.resize(k + 1);
    s = PolySqrt(s);                                          // sqrt(1 + 6x + x^2)
    vector<int> c = PolyLn(PolyDiv2(vector<int>{1, 1} + s)); // ln z_1
    for(int &coef : c) coef = (ll)coef * (n + 1) % P;        // (n+1) * ln z_1
    c = PolyExp(c) * PolyInv(s);
    for(int i = 1; i <= k; ++i){
        if(i > n) print('0');
        else print(c[i]);
        print(' ');
    }
    flush();
    return 0;
}

「Codeforces 755G」PolandBall and Many Other Balls

https://cekavis.github.io/codeforces-775g/

Author

Cekavis

Posted on

2019-02-26

Updated on

2022-06-16

Licensed under

Comments