Code:
// Counts, for every total s, the number of ways to choose one, two or three
// items (unordered, at distinct positions) whose values sum to s.
//
// Input  (stdin):  n, followed by n non-negative integer values.
// Output (stdout): one line "s count" for every s with a non-zero count, in
//                  increasing order of s.
//
// With the generating functions
//   F(x) = sum_i x^{a_i},  G(x) = sum_i x^{2 a_i},  H(x) = sum_i x^{3 a_i}
// inclusion-exclusion over repeated positions gives
//   one item    : F
//   two items   : (F^2 - G) / 2
//   three items : (F^3 - 3 F G + 2 H) / 6
// The products F^2, F^3 and F*G are evaluated with a complex FFT.
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

namespace {

using Complex = std::complex<double>;

// In-place iterative radix-2 FFT; a.size() must be a power of two.
// With invert == true the inverse transform is computed, including the 1/n
// scaling, so a forward transform followed by an inverse one restores `a`
// up to floating-point rounding.
void fft(std::vector<Complex>& a, bool invert) {
    const std::size_t n = a.size();
    if (n <= 1) return;

    // Reorder elements into bit-reversed index order.
    for (std::size_t i = 1, j = 0; i < n; ++i) {
        std::size_t bit = n >> 1;
        for (; j & bit; bit >>= 1) j ^= bit;
        j ^= bit;
        if (i < j) std::swap(a[i], a[j]);
    }

    // Twiddle factors come from a table computed directly with cos/sin rather
    // than by repeated multiplication, which accumulates rounding error over
    // long stages and can corrupt the rounded integer results.
    const double pi = std::acos(-1.0);
    std::vector<Complex> roots(n / 2);
    for (std::size_t k = 0; k < n / 2; ++k) {
        const double angle = 2.0 * pi * static_cast<double>(k) / static_cast<double>(n);
        roots[k] = Complex(std::cos(angle), invert ? -std::sin(angle) : std::sin(angle));
    }

    for (std::size_t len = 2; len <= n; len <<= 1) {
        const std::size_t half = len / 2;
        const std::size_t step = n / len;  // roots[j * step] == exp(+-2*pi*i*j/len)
        for (std::size_t start = 0; start < n; start += len) {
            for (std::size_t j = 0; j < half; ++j) {
                const Complex u = a[start + j];
                const Complex v = a[start + j + half] * roots[j * step];
                a[start + j] = u + v;
                a[start + j + half] = u - v;
            }
        }
    }

    if (invert) {
        for (Complex& x : a) x /= static_cast<double>(n);
    }
}

// Inverse-transforms `spectrum` and rounds every coefficient to the nearest
// integer (the exact products are integers; FFT noise is rounded away).
std::vector<long long> to_integers(std::vector<Complex> spectrum) {
    fft(spectrum, true);
    std::vector<long long> out(spectrum.size());
    for (std::size_t i = 0; i < spectrum.size(); ++i) {
        out[i] = std::llround(spectrum[i].real());
    }
    return out;
}

}  // namespace

int main() {
    int n = 0;
    if (std::scanf("%d", &n) != 1 || n < 0) {
        std::fprintf(stderr, "invalid input: expected a non-negative item count\n");
        return 1;
    }

    std::vector<int> values(static_cast<std::size_t>(n));
    int max_value = 0;
    for (int& v : values) {
        if (std::scanf("%d", &v) != 1 || v < 0) {
            std::fprintf(stderr, "invalid input: expected %d non-negative integers\n", n);
            return 1;
        }
        max_value = std::max(max_value, v);
    }

    // Every sum of at most three items lies in [0, max_sum]; size the
    // transform so that all products up to degree max_sum fit without wrap-around.
    const std::size_t max_sum = 3 * static_cast<std::size_t>(max_value);
    std::size_t size = 1;
    while (size <= max_sum) size <<= 1;

    // f, g, h hold the integer coefficients of F, G, H.
    std::vector<long long> f(max_sum + 1), g(max_sum + 1), h(max_sum + 1);
    for (int v : values) {
        const std::size_t a = static_cast<std::size_t>(v);
        ++f[a];
        ++g[2 * a];
        ++h[3 * a];
    }

    std::vector<Complex> f_hat(size), g_hat(size);
    for (std::size_t s = 0; s <= max_sum; ++s) {
        f_hat[s] = static_cast<double>(f[s]);
        g_hat[s] = static_cast<double>(g[s]);
    }
    fft(f_hat, false);
    fft(g_hat, false);

    // Pointwise products in the frequency domain give F^2, F^3 and F*G.
    std::vector<Complex> square_hat(size), cube_hat(size), fg_hat(size);
    for (std::size_t i = 0; i < size; ++i) {
        square_hat[i] = f_hat[i] * f_hat[i];
        cube_hat[i] = square_hat[i] * f_hat[i];
        fg_hat[i] = f_hat[i] * g_hat[i];
    }
    const std::vector<long long> f2 = to_integers(std::move(square_hat));
    const std::vector<long long> f3 = to_integers(std::move(cube_hat));
    const std::vector<long long> fg = to_integers(std::move(fg_hat));

    for (std::size_t s = 0; s <= max_sum; ++s) {
        const long long one = f[s];
        const long long two = (f2[s] - g[s]) / 2;
        const long long three = (f3[s] - 3 * fg[s] + 2 * h[s]) / 6;
        const long long total = one + two + three;
        if (total != 0) std::printf("%zu %lld\n", s, total);
    }
    return 0;
}
Original post: https://www.cnblogs.com/guangheli/p/11168867.html