题意:有一个数列$f$,对$\forall i\geq2,f_i=2f_{i-1}+3f_{i-2}$,给定$f_0,f_1$,再给定一个集合$S=\{a_{1\cdots n}\}$和$k$,求$\begin{align*}\sum\limits_{\substack{S'\subset S\\|S'|=k}}f\left(\sum\limits_{x\in S'}x\right)\end{align*}$
先看这个数列,它的特征方程为$\lambda^2-2\lambda-3=0$,两个特征根为$\lambda_1=-1,\lambda_2=3$,所以它的通项公式为$f_n=c_1(-1)^n+c_23^n$,由$\begin{cases}c_1+c_2=f_0\\-c_1+3c_2=f_1\end{cases}$我们得到$\begin{cases}c_1=\dfrac{3f_0-f_1}4\\c_2=\dfrac{f_0+f_1}4\end{cases}$
所以我们可以对题目给出的式子进行一番操作:
$\begin{align*}\sum\limits_{\substack{S'\subset S\\|S'|=k}}f\left(\sum\limits_{x\in S'}x\right)&=\sum\limits_{\substack{S'\subset S\\|S'|=k}}c_1(-1)^{\sum\limits_{x\in S'}x}+c_23^{\sum\limits_{x\in S'}x}\\&=c_1\sum\limits_{\substack{S'\subset S\\|S'|=k}}\prod\limits_{x\in S'}(-1)^x+c_2\sum\limits_{\substack{S'\subset S\\|S'|=k}}\prod\limits_{x\in S'}3^x\end{align*}$
这种先抽取定量元素再求乘积的方式很像多项式乘法,事实上,对上式的第一个sigma,它等于$\begin{align*}[x^k]\prod\limits_{i=1}^n\left((-1)^{a_i}x+1\right)\end{align*}$,第二个sigma同理
这个多项式的乘积直接用分治+FFT计算即可,总时间复杂度$O(k\log_2k\log_2n)$
模数比较鬼畜,要用FFT,太久没写我都不知道FFT怎么卡精度了==($n$单位根的$0\cdots n-1$次幂全部预处理出来)
#include#include #include typedef double du;typedef long long ll;const int mod=99991,inv4=24998;int min(int a,int b){return a void swap(C&a,C&b){ C c=a; a=b; b=c;}int pow(int a,int b){ int s=1; while(b){ if(b&1)s=mul(s,a); a=mul(a,a); b>>=1; } return s;}struct complex{ du x,y; complex(du a=0,du b=0){x=a;y=b;}};complex operator+(complex a,complex b){return complex(a.x+b.x,a.y+b.y);}complex operator-(complex a,complex b){return complex(a.x-b.x,a.y-b.y);}complex operator*(complex a,complex b){return complex(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}int rev[262144],N,iN;complex w[18][262144];void pre(int n){ int i,j,k; for(N=1,k=0;N <<=1)k++; for(i=0;i >1]>>1)|((i&1)<<(k-1)); k=0; for(i=1;i<=N;i<<=1){ for(j=0;j >1;k++){ wi=w[c][k]; if(on==-1)wi.y=-wi.y; t=wi*a[i/2+j+k]; a[i/2+j+k]=a[j+k]-t; a[j+k]=a[j+k]+t; } } c++; } if(on==-1){ for(i=0;i >1; mul(solve(l,mid),solve(mid+1,r),f,min(mid-l+1,k),min(r-mid,k)); } return f;}int a[100010];int main(){ int n,i,f0,f1,c1,c2,ans; scanf("%d%d",&n,&k); for(i=1;i<=n;i++)scanf("%d",a+i); scanf("%d%d",&f0,&f1); c1=mul(3*f0-f1,inv4); c2=mul(f0+f1,inv4); ans=0; for(i=1;i<=n;i++)b[i]=pow(-1,a[i]); ans=(ans+mul(c1,solve(1,n)[k]))%mod; for(i=1;i<=n;i++)b[i]=pow(3,a[i]); ans=(ans+mul(c2,solve(1,n)[k]))%mod; printf("%d",(ans+mod)%mod);}