문제
$n-1$차 다항식 2개의 곱을 구하는 경우를 생각해보자.
$A(x) = a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}, B(x) = b_0+b_1x+b_2x^2+...+b_{n-1}x^{n-1}, C(x)=A(x)B(x)$
이를 naive하게 계산하면 $O(n^2)$이지만, FFT를 이용하면 $O(nlogn)$으로 해결할 수 있다.
아이디어
$m-1$차 다항식은 함숫값 $m$개가 결정되면 유일하게 결정된다. 따라서 $x_0, x_1, ..., x_{2m-1}$에 대해 $C(x_i)=A(x_i)B(x_i)$를 모두 구해주면 $2m-1$차 다항식 $C(x)$를 구할 수 있다. 이제 $A(x_i)$와 $B(x_i)$를 빠르게 구하는 방법을 알아보자.
2의 거듭제곱 $n$에 대하여, $n-1$차 이하의 다항식 $A$의 함숫값 $n$개를 구해보자. $\omega_n= e^{2\pi i / n} = \cos(2\pi / n) + i \sin(2\pi / n)$, $x=\omega_n^k$
$$A(x)=(a_0 + a_2x^2+...+a_{n-2}x^{n-2}) + x(a_1+a_3x^2+...+a_{n-1}x^{n-2}) = A_{even}(x^2)+xA_{odd}(x^2)$$
$$A_{even} = a_0+a_2x+...+a_{n-2}x^{n/2-1}, A_{odd}=a_1+a_3x+...+a_{n-1}x^{n/2-1}$$
또한 아래와 같은 성질이 성립한다. ($0\leq k<n/2$)
$$A_{even}((\omega_n^k)^2)=A_{even}(\omega_{n/2}^k), A_{odd}((\omega_n^k)^2)=A_{odd}(\omega_{n/2}^k)$$
따라서 $A_{even}$과 $A_{odd}$의 함숫값 $n/2$개씩을 구해주면 $A$의 함숫값 $n$개를 구할 수 있다. $n-1$차 다항식에서 $n$개의 함숫값을 구하는 시간을 $T(n)$이라 할 경우 $T(n)=2T(n/2)+O(n)$이다. 따라서 $T(n)=O(nlogn)$.
$n$개의 함숫값을 빠르게 구하는 방법을 알았으니 이제 이로부터 함수를 복원해내는 방법을 살펴보자.
$y_k=A(\omega ^k)$라 하면 아래와 같이 행렬로 표현할 수 있다.
역행렬로 $a_i$를 구하면 아래와 같다.
행렬이 거의 똑같이 생겼으므로 함수를 복원하는 것도 함숫값을 구할 때와 비슷한 방식으로 할 수 있다.
구현
위에서 설명한 알고리즘은 분할정복을 할 때 $a_i$의 $i$가 홀수인 경우와 짝수인 경우로 나누었다. 즉, least significant bit에 따라 분할을 한다. 이를 반복문으로 구현하기는 쉽지 않다 (분할된 구간이 연속되어있지 않기 때문이다). 재귀함수를 이용하면 구현할 수 있지만, 이는 느리기 때문에 권장하지 않는다. 따라서 least significant bit와 most significant를 바꿔주는 (bit를 좌우 반전시키는) 과정이 필요하다. 아래 코드의 첫 번째 for문이 그 내용이다.
아래 링크를 통해 공부하였다.
https://speakerdeck.com/wookayin/fast-fourier-transform-algorithm
아래는 FFT의 코드이다. https://github.com/koosaga/olympiad/blob/master/Library/의 코드에서 필요한 부분만 가져온 것이다.
#include <bits/stdc++.h>
using namespace std;
typedef long long lint;
#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
namespace fft{
using real_t = long double;
using base = complex<real_t>;
void fft(vector<base> &a, bool inv){
int n = a.size(), j = 0;
vector<base> roots(n/2);
for(int i=1; i<n; i++){
int bit = (n >> 1);
while(j >= bit){
j -= bit;
bit >>= 1;
}
j += bit;
if(i < j) swap(a[i], a[j]);
}
real_t ang = 2 * acos(real_t(-1)) / n * (inv ? -1 : 1);
for(int i=0; i<n/2; i++){
roots[i] = base(cos(ang * i), sin(ang * i));
}
for(int i=2; i<=n; i<<=1){
int step = n / i;
for(int j=0; j<n; j+=i){
for(int k=0; k<i/2; k++){
base u = a[j+k], v = a[j+k+i/2] * roots[step * k];
a[j+k] = u+v;
a[j+k+i/2] = u-v;
}
}
}
if(inv) for(int i=0; i<n; i++) a[i] /= n;
}
template<typename T>
vector<T> multiply(vector<T> &v, const vector<T> &w){ // 계수가 v와 w인 함수를 곱한다
vector<base> fv(all(v)), fw(all(w));
int n = 2;
while(n < sz(v) + sz(w)) n <<= 1;
fv.resize(n); fw.resize(n);
fft(fv, 0); fft(fw, 0);
for(int i=0; i<n; i++) fv[i] *= fw[i];
fft(fv, 1);
vector<T> ret(n);
for(int i=0; i<n; i++) ret[i] = (T)llround(fv[i].real());
return ret;
}
}
'알고리즘' 카테고리의 다른 글
MCMF (Minimum Cost Maximum Flow) (0) | 2021.06.23 |
---|---|
SPFA (Shortest Path Faster Algorithm) (0) | 2021.06.23 |
Edmond's algorithm (Directed-MST) (0) | 2020.12.30 |
KMP(Knuth-Morris-Pratt String Matching Algorithm) (0) | 2019.12.12 |
Suffix Array와 LCP (0) | 2019.10.25 |