树形DP的转移是一个卷积的转移形式
可以先链剖,一个点的轻儿子先合并,然后一条重链用分治FFT合并
复制代码
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
using namespace std;
typedef vector<int> poly;
const int N=800010,P=998244353;
inline char nc(){
static char buf[100000],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++;
}
inline void read(int &x){
char c=nc(); x=0;
for(;c>'9'||c<'0';c=nc());for(;c>='0'&&c<='9';x=x*10+c-'0',c=nc());
}
int n,m,cnt,a[N],G[N],son[N],size[N],p[N],fa[N],t;
struct edge{
int t,nx;
}E[N<<2];
poly f[N],g[N];
int num,w[2][N],rev[N];
inline void addedge(int x,int y){
E[++cnt].t=y; E[cnt].nx=G[x]; G[x]=cnt;
E[++cnt].t=x; E[cnt].nx=G[y]; G[y]=cnt;
}
void pfs(int x,int f){
size[x]=1; p[++t]=x; fa[x]=f;
for(int i=G[x];i;i=E[i].nx)
if(E[i].t!=f){
pfs(E[i].t,x);
size[x]+=size[E[i].t];
if(size[E[i].t]>size[son[x]]) son[x]=E[i].t;
}
}
inline int Pow(int x,int y){
int ret=1;
for(;y;y>>=1,x=1LL*x*x%P) if(y&1) ret=1LL*x*ret%P;
return ret;
}
inline void Pre(const int &n){
num=n; int g=Pow(3,(P-1)/n);
w[0][0]=w[1][0]=1;
for(int i=1;i<n;i++) w[1][i]=1LL*w[1][i-1]*g%P;
for(int i=1;i<n;i++) w[0][i]=w[1][n-i];
}
inline void NTT(int *a,int n,int r){
for(int i=1;i<n;i++) if(rev[i]>i) swap(a[i],a[rev[i]]);
for(int i=1;i<n;i<<=1)
for(int j=0;j<n;j+=(i<<1))
for(int k=0;k<i;k++){
int x=a[j+k],y=1LL*a[j+k+i]*w[r][num/(i<<1)*k]%P;
a[j+k]=(x+y)%P; a[j+k+i]=(x-y+P)%P;
}
if(!r) for(int i=0,inv=Pow(n,P-2);i<n;i++) a[i]=1LL*a[i]*inv%P;
}
poly operator *(poly a,poly b){
if(!a.size() || !b.size()) return a.size()?b:a;
poly ret;
if(a.size()+b.size()<500){
ret.resize(a.size()+b.size()-1);
for(int i=0;i<a.size();i++)
for(int j=0;j<b.size();j++)
ret[i+j]=(ret[i+j]+1LL*a[i]*b[j])%P;
return ret;
}
int n,L=0;
for(n=1;n<=a.size()+b.size();n<<=1,L++); L--;
for(int i=1;i<n;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<L);
static int tmpa[N],tmpb[N];
for(int i=0;i<a.size();i++) tmpa[i]=a[i];
for(int i=0;i<b.size();i++) tmpb[i]=b[i];
NTT(tmpa,n,1); NTT(tmpb,n,1);
for(int i=0;i<n;i++) tmpa[i]=1LL*tmpa[i]*tmpb[i]%P;
NTT(tmpa,n,0); ret.resize(a.size()+b.size()-1);
for(int i=0;i<ret.size();i++) ret[i]=tmpa[i];
for(int i=0;i<n;i++) tmpa[i]=tmpb[i]=0;
return ret;
}
poly operator +(poly a,poly b){
poly ret; ret.resize(max(a.size(),b.size()));
for(int i=0;i<a.size();i++) ret[i]=a[i];
for(int i=0;i<b.size();i++) ret[i]=(ret[i]+b[i])%P;
return ret;
}
struct polyc{
poly a00,a01,a10,a11;
polyc(){}
polyc(poly a,poly b):a00(a),a11(b){}
int size(){ return max(max(a00.size(),a01.size()),max(a10.size(),a11.size())); }
friend polyc operator *(polyc a,polyc b){
polyc ret;
ret.a00=a.a00*b.a00+a.a01*b.a00+a.a00*b.a10;
ret.a01=a.a00*b.a01+a.a00*b.a11+a.a01*b.a01;
ret.a10=a.a10*b.a00+a.a10*b.a10+a.a11*b.a00;
ret.a11=a.a10*b.a01+a.a10*b.a11+a.a11*b.a01;
return ret;
}
friend bool operator <(polyc a,polyc b){
return a.size()>b.size();
}
};
struct polypair{
poly a,b;
polypair(){}
polypair(poly _a,poly _b):a(_a),b(_b){}
friend polypair operator *(polypair a,polypair b){
return polypair(a.a*b.a,a.b*b.b);
}
friend bool operator <(polypair a,polypair b){
return a.a.size()>b.a.size();
}
};
namespace HuffmanFFT{
priority_queue<polypair> a;
void Push(poly _a,poly _b){
a.push(polypair(_a,_b));
}
polypair work(){
while(a.size()>1){
polypair A=a.top(); a.pop();
polypair B=a.top(); a.pop();
a.push(A*B);
}
polypair ret=a.top(); a.pop();
return ret;
}
}
namespace DivAndConq{
vector<polyc> a;
void Push(poly _a,poly _b){ a.push_back(polyc(_b,_a)); }
void Clear(){ a.clear(); }
polyc solve(int l=0,int r=a.size()-1){
if(l==r) return a[l];
int mid=l+r>>1;
return solve(l,mid)*solve(mid+1,r);
}
}
inline void solve(int x){
DivAndConq::Clear();
for(int u=x;u;u=son[u]){
for(int i=G[u];i;i=E[i].nx)
if(E[i].t!=fa[u] && E[i].t!=son[u])
HuffmanFFT::Push(f[E[i].t]+g[E[i].t],g[E[i].t]);
polypair cur; if(HuffmanFFT::a.size()) cur=HuffmanFFT::work();
poly U; U.push_back(0); U.push_back(a[u]);
if(cur.b.size()) cur.b=cur.b*U; else cur.b=U;
if(!cur.a.size()) cur.a.push_back(1);
DivAndConq::Push(cur.b,cur.a);
}
polyc cur=DivAndConq::solve();
f[x]=cur.a10+cur.a11; g[x]=cur.a00+cur.a01;
}
int main(){
read(n); read(m);
int _m; for(_m=1;_m<=n;_m<<=1); Pre(_m);
for(int i=1;i<=n;i++) read(a[i]);
for(int i=1,x,y;i<n;i++)
read(x),read(y),addedge(x,y);
pfs(1,0);
for(int i=t;i;i--)
if(son[fa[p[i]]]!=p[i]) solve(p[i]);
poly ans=f[1]+g[1];
if(ans.size()>m) printf("%dn",ans[m]);
else puts("0");
return 0;
}
最后
以上就是愉快樱桃最近收集整理的关于[链剖 FFT] LOJ#6289. 花朵的全部内容,更多相关[链剖内容请搜索靠谱客的其他文章。
本图文内容来源于网友提供,作为学习参考使用,或来自网络收集整理,版权属于原作者所有。
发表评论 取消回复