Menci

眉眼如初,岁月如故

在那无法确定的未来
只愿真心如现在一般清澈


FFT 学习笔记

快速傅里叶变换(Fast Fourier Transform,FFT)是一种可在 O(nlogn) O(n \log n) 时间内完成的离散傅里叶变换(Discrete Fourier transform,DFT)算法,在 OI 中的主要应用之一是加速多项式乘法的计算。

定义

多项式

系数表示与点值表示

多项式的系数表示,设 A(x) A(x) 表示一个 n1 n - 1 次多项式,所有项的系数组成的 n n 维向量 唯一确定了这个多项式。

多项式的点值表示,将一组互不相同n n x x 带入多项式,得到的 n n 个值。设它们组成的 n n 维向量分别为

求值与插值

定理:一个 n1 n - 1 次多项式在 n n 个不同点的取值唯一确定了该多项式。

证明:假设命题不成立,存在两个不同的 n1 n - 1 次多项式 A(x) A(x) B(x) B(x) ,满足对于任何 i[0, n1] i \in [0,\ n - 1] ,有 A(xi)=B(xi) A(x_i) = B(x_i)

C(x)=A(x)B(x) C(x) = A(x) - B(x) ,则 C(x) C(x) 也是一个 n1 n - 1 次多项式。对于任何 i[0, n1] i \in [0,\ n - 1] ,有 C(xi)=0 C(x_i) = 0

C(x) C(x) n n 个根,这与代数基本定理(一个 n1 n - 1 次多项式在复数域上有且仅有 n1 n - 1 个根)相矛盾,故 C(x) C(x) 并不是一个 n1 n - 1 次多项式,原命题成立,证毕。

如果我们按照定义求一个多项式的点值表示,时间复杂度为 O(n2) O(n ^ 2)

已知多项式的点值表示,求其系数表示,可以使用插值。朴素的插值算法时间复杂度为 O(n2) O(n ^ 2)

多项式乘法

我们定义两个多项式 A(x)=i=0n1aixi A(x) = \sum\limits_{i = 0} ^ {n - 1} a_i x ^ i B(x)=i=0n1bixi B(x) = \sum\limits_{i = 0} ^ {n - 1} b_i x ^ i 相乘的结果为 C(x) C(x) (假设两个多项式次数相同,若不同可在后面补零)。

C(x)=A(x)×B(x)=k=02n2(k=i+jaibj)xk C(x) = A(x) \times B(x) = \sum\limits_{k = 0} ^ {2n - 2} (\sum\limits_{k = i + j} a_i b_j) x ^ k

两个 n1 n - 1 次多项式相乘,得到的是一个 2n2 2n - 2 次多项式,时间复杂度为 O(n2) O(n ^ 2)

如果使用两个多项式在 2n1 2n - 1 个点处取得的点值表示,那么

y3i=(j=02n1ajxij)×(j=02n1bjxij)=y1i×y2i {y_3}_i = (\sum\limits_{j = 0} ^ {2n - 1} a_j x_i ^ j) \times (\sum\limits_{j = 0} ^ {2n - 1} b_j x_i ^ j) = {y_1}_i \times {y_2}_i

时间复杂度为 O(n) O(n)

复数

a a b b 为实数,i2=1 i ^ 2 = -1 ,形如 a+bi a + bi 的数叫做复数,其中 i i 被称为虚数单位。复数域是已知最大的域。

复平面

在复平面中,x x 轴代表实数、y y 轴(除原点外的所有点)代表虚数。每一个复数 a+bi a + bi 对应复平面上一个从 (0, 0) (0,\ 0) 指向 (a, b) (a,\ b) 的向量。

该向量的长度 a2+b2 \sqrt {a ^ 2 + b ^ 2} 叫做模长。表示从 x x 轴正半轴到该向量的转角的有向(以逆时针为正方向)角叫做幅角。

复数相加遵循平行四边形定则。

复数相乘时,模长相乘,幅角相加。

单位根

下文中,如不特殊指明,均设 n n 2 2 的正整数次幂。

在复平面上,以原点为圆心,1 1 为半径作圆,所得的圆叫做单位圆。以原点为起点,单位圆的 n n 等分点为终点,作 n n 个向量。设所得的幅角为正且最小的向量对应的复数为 ωn \omega_n ,称为 n n 次单位根。

由复数乘法的定义(模长相乘,幅角相加)可知,其与的 n1 n - 1 个向量对应的复数分别为 ,其中 ωnn=ωn0=1 \omega_n ^ n = \omega_n ^ 0 = 1

单位根的幅角为周角的 1n 1 \over n ,这为我们提供了一个计算单位根及其幂的公式

ωnk=cosk2πn+isink2πn \omega_n ^ k = \cos k \frac{2 \pi}{n} + i\sin k \frac{2 \pi}{n}

单位根的性质

性质一:ω2n2k=ωnk \omega_{2n} ^ {2k} = \omega_n ^ k

从几何意义上看,在复平面上,二者表示的向量终点相同。

更具体的,有

cos2k2π2n+isin2k2π2n=cosk2πn+isink2πn \cos 2k \frac{2 \pi}{2n} + i\sin 2k \frac{2 \pi}{2n} = \cos k \frac{2 \pi}{n} + i\sin k \frac{2 \pi}{n}

性质二:ωnk+n2=ωnk \omega_n ^ { k + \frac{n}{2} } = -\omega_n ^ k

等式左边相当于 ωnk \omega_n ^ k 乘上 ωnn2 \omega_n ^ { \frac{n}{2} } ,考虑其值

快速傅里叶变换

考虑多项式 A(x) A(x) 的表示。将 n n 次单位根的 0 0 n1 n - 1 次幂带入多项式的系数表示,所得点值向量 称为其系数向量 离散傅里叶变换

按照朴素算法来求离散傅里叶变换,时间复杂度仍然为 O(n2) O(n ^ 2)

考虑将多项式按照系数下标的奇偶分为两部分

则有

A(x)=A1(x2)+xA2(x2) A(x) = A_1(x ^ 2) + x A_2(x ^ 2)

假设 k<n2 k < \frac{n}{2} ,现在要求 A(ωnk) A(\omega_n ^ k)

这一步转化利用了单位根的性质一。

对于 A(ωnk+n2) A(\omega_n ^ {k + \frac{n}{2}})

这一步转化除性质一外,还利用到了性质二和 ωnn=1 \omega_n ^ n = 1 这一显然的结论。

注意到,当 k k 取遍 [0, n21] [0,\ \frac{n}{2} - 1] 时,k k k+n2 k + \frac{n}{2} 取遍了 [0, n1] [0,\ n - 1]

也就是说,如果已知 A1(x) A_1(x) A2(x) A_2(x) 处的点值,就可以在 O(n) O(n) 的时间内求得 A(x) A(x) 处的取值。而关于 A1(x) A_1(x) A2(x) A_2(x) 的问题都是相对于原问题规模缩小了一半的子问题,分治的边界为一个常数项 a0 a_0

根据主定理,该分治算法的时间复杂度为

T(n)=2T(n2)+O(n)=O(nlogn) T(n) = 2T( \frac{n}{2} ) + O(n) = O(n \log n)

这就是最常用的 FFT 算法 —— Cooley-Tukey 算法。

傅里叶逆变换

将点值表示的多项式转化为系数表示,同样可以使用快速傅里叶变换,这个过程叫做傅里叶逆变换

的傅里叶变换。考虑另一个向量 满足

ck=i=0n1yi(ωnk)i c_k = \sum\limits_{i = 0} ^ {n - 1} y_i (\omega_n ^ {-k}) ^ i

即多项式 处的点值表示。

将上式展开,得

考虑一个式子

k0 k \neq 0 时,两边同时乘上 ωnk \omega_n ^ k

两式相减,整理后得

分子为零,分母不为零,所以

S(ωnk)=0 S(\omega_n ^ k) = 0

k=0 k = 0 时,显然 S(ωnk)=n S(\omega_n ^ k) = n

继续考虑上式

j=k j = k 时,S(ωnjk)=n S(\omega_n ^ {j - k}) = n ,否则 S(ωnjk)=0 S(\omega_n ^ {j - k}) = 0 ,即

所以,使用单位根的倒数代替单位根,做一次类似快速傅里叶变换的过程,再将结果每个数除以 n n ,即为傅里叶逆变换的结果。

实现

C++ 的 STL 在头文件 complex 中提供一个复数的模板实现 std::complex<T>,其中 T 为实数类型,一般取 double,在对精度要求较高的时候可以使用 long double__float128(不常用)。

考虑到单位根的倒数等于其共轭复数,在程序实现中,为了减小误差,通常使用 std::conj() 取得 IDFT 所需的「单位根的倒数」。

递归实现

直接按照上面得到的结论来实现即可,比较直观。

代码
const double PI = acos(-1);

bool inversed = false;

inline std::complex<double> omega(const int n, const int k) {
    if (!inversed) return std::complex<double>(cos(2 * PI / n * k), sin(2 * PI / n * k));
    else return std::conj(std::complex<double>(cos(2 * PI / n * k), sin(2 * PI / n * k)));
}

inline void transform(std::complex<double> *a, const int n) {
    if (n == 1) return;

    static std::complex<double> buf[MAXN];
    const int m = n / 2;
    // 按照系数奇偶划分为两半
    for (int i = 0; i < m; i++) {
        buf[i] = a[i * 2];
        buf[i + m] = a[i * 2 + 1];
    }
    std::copy(buf, buf + n, a);

    // 分治
    std::complex<double> *a1 = a, *a2 = a + m;
    fft(a1, m);
    fft(a2, m);

    // 合并两个子问题
    for (int i = 0; i < m; i++) {
        std::complex<double> x = omega(n, i);
        buf[i] = a1[i] + x * a2[i];
        buf[i + m] = a1[i] - x * a2[i];
    }

    std::copy(buf, buf + n, a);
}

迭代实现

递归实现的 FFT 效率不高,实际中一般采用迭代实现。

二进制位翻转

考虑递归 FFT 分治到边界时,每个数的顺序,及其二进制位。

000 001 010 011 100 101 110 111
 0   1   2   3   4   5   6   7
 0   2   4   6 - 1   3   5   7
 0   4 - 2   6 - 1   5 - 3   7
 0 - 4 - 2 - 6 - 1 - 5 - 3 - 7
000 100 010 110 001 101 011 111

发现规律,分治到边界后的下标等于原下标的二进制位翻转。

代码实现,枚举每个二进制位即可。

int k = 0;
while ((1 << k) < n) k++;
for (int i = 0; i < n; i++) {
    int t = 0;
    for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));
    if (i < t) std::swap(a[i], a[t]);
}
蝴蝶操作

考虑合并两个子问题的过程,假设 A1(ωn2k) A_1(\omega_{ \frac{n}{2} } ^ k) A2(ωn2k) A_2(\omega_{ \frac{n}{2} } ^ k) 分别存在 a(k) a(k) a(n2+k) a(\frac{n}{2} + k) 中,A(ωnk) A(\omega_n ^ {k}) A(ωnk+n2) A(\omega_n ^ {k + \frac{n}{2} }) 将要被存放在 b(k) b(k) b(n2+k) b(\frac{n}{2} + k) 中,合并的单位操作可表示为

考虑加入一个临时变量 t t ,使得这个过程可以在原地完成,不需要另一个数组 b b ,也就是说,将 A(ωnk) A(\omega_n ^ {k}) A(ωnk+n2) A(\omega_n ^ {k + \frac{n}{2} }) 存放在 a(k) a(k) a(n2+k) a(\frac{n}{2} + k) 中,覆盖原来的值

这一过程被称为蝴蝶操作

代码

omega[k] 中保存 ωnk \omega_n ^ k (IDFT 时保存 ωnk \omega_n ^ {-k} )。

枚举 l l ,表示一次要将 l2 \frac{l}{2} 长度的序列答案合并为长度为 l l 的,根据单位根的性质一,有 ωlk=ωnnlk \omega_l ^ k = \omega_n ^ { \frac{n}{l} k }

void transform(std::complex<double> *a, const int n, const std::complex<double> *omega) {
    // 此处省略二进制位翻转的代码
    for (int l = 2; l <= n; l *= 2) {
        int m = l / 2;
        // 将两个长度为 m 的序列的答案合并为长度为 l 的序列的答案
        for (std::complex<double> *p = a; p != a + n; p += l) {
            for (int i = 0; i < m; i++) {
                // 蝴蝶操作
                std::complex<double> t = omega[n / l * i] * p[m + i];
                p[m + i] = p[i] - t;
                p[i] += t;
            }
        }
    }
}

模板

需要注意的是,在求两个次数分别为 n11 n_1 - 1 n21 n_2 - 1 的多项式的乘积时,需要分别求出其在至少 n1+n21 n_1 + n_2 - 1 个点处的点值,因为这样才能保证相乘后的点值能唯一确定一个 n1+n22 n_1 + n_2 - 2 次多项式。

struct FastFourierTransform {
    std::complex<double> omega[MAXN], omegaInverse[MAXN];

    void init(const int n) {
        for (int i = 0; i < n; i++) {
            omega[i] = std::complex<double>(cos(2 * PI / n * i), sin(2 * PI / n * i));
            omegaInverse[i] = std::conj(omega[i]);
        }
    }

    void transform(std::complex<double> *a, const int n, const std::complex<double> *omega) {
        int k = 0;
        while ((1 << k) < n) k++;
        for (int i = 0; i < n; i++) {
            int t = 0;
            for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));
            if (i < t) std::swap(a[i], a[t]);
        }

        for (int l = 2; l <= n; l *= 2) {
            int m = l / 2;
            for (std::complex<double> *p = a; p != a + n; p += l) {
                for (int i = 0; i < m; i++) {
                    std::complex<double> t = omega[n / l * i] * p[m + i];
                    p[m + i] = p[i] - t;
                    p[i] += t;
                }
            }
        }
    }

    void dft(std::complex<double> *a, const int n) {
        transform(a, n, omega);
    }

    void idft(std::complex<double> *a, const int n) {
        transform(a, n, omegaInverse);
        for (int i = 0; i < n; i++) a[i] /= n;
    }
} fft;

inline void multiply(const int *a1, const int n1, const int *a2, const int n2, int *res) {
    int n = 1;
    while (n < n1 + n2) n *= 2;
    static std::complex<double> c1[MAXN], c2[MAXN];
    for (int i = 0; i < n1; i++) c1[i].real(a1[i]);
    for (int i = 0; i < n2; i++) c2[i].real(a2[i]);
    fft.init(n);
    fft.dft(c1, n), fft.dft(c2, n);
    for (int i = 0; i < n; i++) c1[i] *= c2[i];
    fft.idft(c1, n);
    for (int i = 0; i < n1 + n2 - 1; i++) res[i] = static_cast<int>(floor(c1[i].real() + 0.5));
}

参考资料