快速傅里叶变换(Fast Fourier Transform),是一种极为著名的算法,以其巧妙、高效著称。在OI领域的应用主要是多项式乘法(特殊形式是高精度乘法)。多项式乘法似乎还与组合数学有很大的联系,所以对于OI选手,是一种非常有用的必备法宝。
为什么FFT做多项式乘法更快呢?FFT原是一种信号处理算法,可以将时域数据转换为频域数据。时域乘积,频域卷积;时域卷积,频域乘积。两者互相对称。多项式乘法是两个序列的卷积,如果将其转换到频域下做乘法,再转换回来,是不是可以更快一点呢?幸运的是,这确实可以。FFT包括DFT(离散傅里叶变换)和IDFT(逆离散傅里叶变换)两部分,两者均可以用的时间复杂度实现。
首先给出DFT的公式:
不要被复杂的数学描述吓到,我们来一步步分解。令,在N确定时,它是一个常数。那么DFT的公式可以改写为
。观察一下,DFT实质上是一个
的系数矩阵与长度为
的向量
的矩阵乘法,得到一个长度为
的向量
。那么根据公式就得到最朴素的
算法。
公式中还有许多神奇的性质。首先普及一个常用的公式:,这个公式解决了虚数指数幂的运算障碍。其次根据公式可以推导得
。下面开始探究
的性质:
性质1:周期性 -
性质2:对称性 - ,证明只需要将式子展开后即可得到。
接下来是FFT的核心步骤:
用上述性质进行进一步化简,可以得到:
这样就把原问题等价转化为两个子问题,是FFT中体现的分治策略。最后是将频域转回时域,IDFT公式:
UOJ上有一道模板题。具体实现如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
#include <set> #include <map> #include <queue> #include <ctime> #include <cmath> #include <cstdio> #include <vector> #include <string> #include <cctype> #include <bitset> #include <cstring> #include <cstdlib> #include <utility> #include <iostream> #include <algorithm> #define lowbit(x) (x)&(-x) #define REP(i,a,b) for(int i=(a);i<=(b);i++) #define PER(i,a,b) for(int i=(a);i>=(b);i--) #define RVC(i,S) for(int i=0;i<(S).size();i++) #define RAL(i,u) for(int i=fr[u];i!=-1;i=e[i].next) using namespace std; typedef long long LL; typedef pair<int,int> pii; template<class T> inline void read(T& num) { bool start=false,neg=false; char c; num=0; while((c=getchar())!=EOF) { if(c=='-') start=neg=true; else if(c>='0' && c<='9') { start=true; num=num*10+c-'0'; } else if(start) break; } if(neg) num=-num; } /*============ Header Template ============*/ struct cpx { double x,y; cpx(double a=0,double b=0):x(a),y(b) {} }; cpx operator + (cpx a,cpx b) {return cpx(a.x+b.x,a.y+b.y);} cpx operator - (cpx a,cpx b) {return cpx(a.x-b.x,a.y-b.y);} cpx operator * (cpx a,cpx b) {return cpx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);} const double pi=acos(-1.0); const int maxn=(int)(4e5)+5; cpx A[maxn],B[maxn]; void DFT(cpx* a,int n,int d=1) { for(int i=(n>>1),j=1;j<n;j++) { if(i<j) swap(a[i],a[j]); int k;for(k=(n>>1);i&k;i^=k,k>>=1); i^=k; } for(int m=2;m<=n;m<<=1) { cpx w=cpx(cos(pi*2/m*d),sin(pi*2/m*d)); for(int i=0;i<n;i+=m) { cpx s=cpx(1,0); for(int j=i;j<(i+(m>>1));j++) { cpx u=a[j],v=a[j+(m>>1)]*s; a[j]=u+v;a[j+(m>>1)]=u-v; s=s*w; } } } if(d==-1) for(int i=0;i<n;i++) a[i].x=a[i].x/n; } int main() { int n,m,len=0; read(n);read(m); n++;m++; for(int i=0;i<n;i++) read(A[i].x); for(int i=0;i<m;i++) read(B[i].x); while((1<<len) < (max(n,m)<<1)) len++; DFT(A,1<<len);DFT(B,1<<len); for(int i=0;i<(1<<len);i++) A[i]=A[i]*B[i]; DFT(A,1<<len,-1); for(int i=0;i<n+m-1;i++) printf("%d ",(int)trunc(A[i].x+0.5)); printf("\n"); return 0; } |
[BZOJ 2179] 高精度乘法模板~~
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 |
/************************************************************** Problem: 2179 User: frank_c1 Language: C++ Result: Accepted Time:1188 ms Memory:21296 kb ****************************************************************/ #include <set> #include <map> #include <queue> #include <ctime> #include <cmath> #include <cstdio> #include <vector> #include <string> #include <cctype> #include <bitset> #include <cstring> #include <cstdlib> #include <utility> #include <iostream> #include <algorithm> #define lowbit(x) (x)&(-x) #define REP(i,a,b) for(int i=(a);i<=(b);i++) #define PER(i,a,b) for(int i=(a);i>=(b);i--) #define RVC(i,S) for(int i=0;i<(S).size();i++) #define RAL(i,u) for(int i=fr[u];i!=-1;i=e[i].next) using namespace std; typedef long long LL; typedef pair<int,int> pii; template<class T> inline void read(T& num) { bool start=false,neg=false; char c; num=0; while((c=getchar())!=EOF) { if(c=='-') start=neg=true; else if(c>='0' && c<='9') { start=true; num=num*10+c-'0'; } else if(start) break; } if(neg) num=-num; } /*============ Header Template ============*/ struct cpx { double x,y; cpx(double a=0,double b=0):x(a),y(b) {} }; cpx operator + (cpx a,cpx b) {return cpx(a.x+b.x,a.y+b.y);} cpx operator - (cpx a,cpx b) {return cpx(a.x-b.x,a.y-b.y);} cpx operator * (cpx a,cpx b) {return cpx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);} const int maxn=500005; const double pi=acos(-1.0); char buf[maxn]; cpx A[maxn],B[maxn]; LL num[maxn]; void DFT(cpx* a,int n,int d=1) { for(int i=(n>>1),j=1;j<n;j++) { if(i<j) swap(a[i],a[j]); int k;for(k=(n>>1);i&k;i^=k,k>>=1); i^=k; } for(int m=2;m<=n;m<<=1) { cpx w=cpx(cos(pi*2/m*d),sin(pi*2/m*d)); for(int i=0;i<n;i+=m) { cpx s=cpx(1,0); for(int j=i;j<(i+(m>>1));j++) { cpx u=a[j],v=s*a[j+(m>>1)]; a[j]=u+v;a[j+(m>>1)]=u-v; s=s*w; } } } if(d==-1) for(int i=0;i<n;i++) a[i].x=a[i].x/n; } int main() { int n; read(n); scanf("%s",buf); for(int i=0;i<n;i++) A[i].x=1.0*(buf[i]-'0'); scanf("%s",buf); for(int i=0;i<n;i++) B[i].x=1.0*(buf[i]-'0'); int len=0; while((1<<len) < (n<<1)) len++; DFT(A,1<<len);DFT(B,1<<len); for(int i=0;i<(1<<len);i++) A[i]=A[i]*B[i]; DFT(A,1<<len,-1); for(int i=((n-1)<<1);i>=0;i--) num[++num[0]]=(LL)trunc(A[i].x+0.5); LL x=0; for(int i=1;;i++,x/=10) { if(i>num[0] && x==0) { num[0]=i-1;break; } x=x+num[i];num[i]=x%10; } for(int i=num[0];i>0;i--) printf("%lld",num[i]); printf("\n"); return 0; } |
[BZOJ 2194] 卷积模板~~
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
/************************************************************** Problem: 2194 User: frank_c1 Language: C++ Result: Accepted Time:1920 ms Memory:16896 kb ****************************************************************/ #include <set> #include <map> #include <queue> #include <ctime> #include <cmath> #include <cstdio> #include <vector> #include <string> #include <cctype> #include <bitset> #include <cstring> #include <cstdlib> #include <utility> #include <iostream> #include <algorithm> #define REP(i,a,b) for(int i=(a);i<=(b);i++) #define PER(i,a,b) for(int i=(a);i>=(b);i--) #define RVC(i,S) for(int i=0;i<(S).size();i++) #define RAL(i,u) for(int i=fr[u];i!=-1;i=e[i].next) using namespace std; typedef long long LL; typedef pair<int,int> pii; template<class T> inline void read(T& num) { bool start=false,neg=false; char c; num=0; while((c=getchar())!=EOF) { if(c=='-') start=neg=true; else if(c>='0' && c<='9') { start=true; num=num*10+c-'0'; } else if(start) break; } if(neg) num=-num; } /*============ Header Template ============*/ struct cpx { double x,y; cpx(double a=0,double b=0):x(a),y(b) {} }; cpx operator + (cpx a,cpx b) {return cpx(a.x+b.x,a.y+b.y);} cpx operator - (cpx a,cpx b) {return cpx(a.x-b.x,a.y-b.y);} cpx operator * (cpx a,cpx b) {return cpx(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);} const double pi=acos(-1.0); const int maxn=(int)(5e5)+5; cpx A[maxn],B[maxn]; void FFT(cpx* a,int n,int d=1) { for(int i=(n>>1),j=1;j<n;j++) { if(i<j) swap(a[i],a[j]); int k;for(k=(n>>1);i&k;i^=k,k>>=1); i^=k; } for(int m=2;m<=n;m<<=1) { cpx w=cpx(cos(pi*2/m*d),sin(pi*2/m*d)); for(int i=0;i<n;i+=m) { cpx s=cpx(1,0); for(int j=i;j<(i+(m>>1));j++) { cpx u=a[j],v=a[j+(m>>1)]*s; a[j]=u+v;a[j+(m>>1)]=u-v; s=s*w; } } } if(d==-1) for(int i=0;i<n;i++) a[i].x/=n; } int main() { int n,len=0; read(n); while((1<<len) < (n<<1)) len++; for(int i=0;i<n;i++) scanf("%lf%lf",&A[n-i-1],&B[i]); FFT(A,1<<len);FFT(B,1<<len); for(int i=0;i<(1<<len);i++) A[i]=A[i]*B[i]; FFT(A,1<<len,-1); for(int i=0;i<n;i++) printf("%lld\n",(LL)trunc(A[n-i-1].x+0.5)); return 0; } |