본문 바로가기

알고리즘

FFT (Fast Fourier Transform)

문제

$n-1$차 다항식 2개의 곱을 구하는 경우를 생각해보자. 

$A(x) = a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}, B(x) = b_0+b_1x+b_2x^2+...+b_{n-1}x^{n-1}, C(x)=A(x)B(x)$

이를 naive하게 계산하면 $O(n^2)$이지만, FFT를 이용하면 $O(nlogn)$으로 해결할 수 있다.

아이디어

$m-1$차 다항식은 함숫값 $m$개가 결정되면 유일하게 결정된다. 따라서 $x_0, x_1, ..., x_{2m-1}$에 대해 $C(x_i)=A(x_i)B(x_i)$를 모두 구해주면 $2m-1$차 다항식 $C(x)$를 구할 수 있다. 이제 $A(x_i)$와 $B(x_i)$를 빠르게 구하는 방법을 알아보자.

 

2의 거듭제곱 $n$에 대하여, $n-1$차 이하의 다항식 $A$의 함숫값 $n$개를 구해보자. $\omega_n= e^{2\pi i / n} = \cos(2\pi / n) + i \sin(2\pi / n)$, $x=\omega_n^k$

$$A(x)=(a_0 + a_2x^2+...+a_{n-2}x^{n-2}) + x(a_1+a_3x^2+...+a_{n-1}x^{n-2}) = A_{even}(x^2)+xA_{odd}(x^2)$$

$$A_{even} = a_0+a_2x+...+a_{n-2}x^{n/2-1}, A_{odd}=a_1+a_3x+...+a_{n-1}x^{n/2-1}$$

또한 아래와 같은 성질이 성립한다. ($0\leq k<n/2$)

$$A_{even}((\omega_n^k)^2)=A_{even}(\omega_{n/2}^k), A_{odd}((\omega_n^k)^2)=A_{odd}(\omega_{n/2}^k)$$

따라서 $A_{even}$과 $A_{odd}$의 함숫값 $n/2$개씩을 구해주면 $A$의 함숫값 $n$개를 구할 수 있다. $n-1$차 다항식에서 $n$개의 함숫값을 구하는 시간을 $T(n)$이라 할 경우 $T(n)=2T(n/2)+O(n)$이다. 따라서 $T(n)=O(nlogn)$.

 

$n$개의 함숫값을 빠르게 구하는 방법을 알았으니 이제 이로부터 함수를 복원해내는 방법을 살펴보자.

$y_k=A(\omega ^k)$라 하면 아래와 같이 행렬로 표현할 수 있다.

역행렬로 $a_i$를 구하면 아래와 같다.

행렬이 거의 똑같이 생겼으므로 함수를 복원하는 것도 함숫값을 구할 때와 비슷한 방식으로 할 수 있다.

 

구현

위에서 설명한 알고리즘은 분할정복을 할 때 $a_i$의 $i$가 홀수인 경우와 짝수인 경우로 나누었다. 즉, least significant bit에 따라 분할을 한다. 이를 반복문으로 구현하기는 쉽지 않다 (분할된 구간이 연속되어있지 않기 때문이다). 재귀함수를 이용하면 구현할 수 있지만, 이는 느리기 때문에 권장하지 않는다. 따라서 least significant bit와 most significant를 바꿔주는 (bit를 좌우 반전시키는) 과정이 필요하다. 아래 코드의 첫 번째 for문이 그 내용이다.

 

 

 

아래 링크를 통해 공부하였다.

https://speakerdeck.com/wookayin/fast-fourier-transform-algorithm

 

아래는 FFT의 코드이다. https://github.com/koosaga/olympiad/blob/master/Library/의 코드에서 필요한 부분만 가져온 것이다.

#include <bits/stdc++.h>
using namespace std;

typedef long long lint;

#define sz(v) ((int)(v).size())
#define all(v) (v).begin(), (v).end()
 
namespace fft{
    using real_t = long double;
    using base = complex<real_t>;

    void fft(vector<base> &a, bool inv){
        int n = a.size(), j = 0;
        vector<base> roots(n/2);
        for(int i=1; i<n; i++){
            int bit = (n >> 1);
            while(j >= bit){
                j -= bit;
                bit >>= 1;
            }
            j += bit;
            if(i < j) swap(a[i], a[j]);
        }
        real_t ang = 2 * acos(real_t(-1)) / n * (inv ? -1 : 1);
        for(int i=0; i<n/2; i++){
            roots[i] = base(cos(ang * i), sin(ang * i));
        }
        for(int i=2; i<=n; i<<=1){
            int step = n / i;
            for(int j=0; j<n; j+=i){
                for(int k=0; k<i/2; k++){
                    base u = a[j+k], v = a[j+k+i/2] * roots[step * k];
                    a[j+k] = u+v;
                    a[j+k+i/2] = u-v;
                }
            }
        }
        if(inv) for(int i=0; i<n; i++) a[i] /= n;
    }
    template<typename T>
    vector<T> multiply(vector<T> &v, const vector<T> &w){ // 계수가 v와 w인 함수를 곱한다
        vector<base> fv(all(v)), fw(all(w));
        int n = 2;
        while(n < sz(v) + sz(w)) n <<= 1;
        fv.resize(n); fw.resize(n);
        fft(fv, 0); fft(fw, 0);
        for(int i=0; i<n; i++) fv[i] *= fw[i];
        fft(fv, 1);
        vector<T> ret(n);
        for(int i=0; i<n; i++) ret[i] = (T)llround(fv[i].real());
        return ret;
    }
}

'알고리즘' 카테고리의 다른 글

MCMF (Minimum Cost Maximum Flow)  (0) 2021.06.23
SPFA (Shortest Path Faster Algorithm)  (0) 2021.06.23
Edmond's algorithm (Directed-MST)  (0) 2020.12.30
KMP(Knuth-Morris-Pratt String Matching Algorithm)  (0) 2019.12.12
Suffix Array와 LCP  (0) 2019.10.25