FFT 学习笔记

快速傅里叶变换(Fast Fourier Transform,FFT)是一种可在 时间内完成的离散傅里叶变换(Discrete Fourier transform,DFT)算法,在 OI 中的主要应用之一是加速多项式乘法的计算。

定义

多项式

系数表示与点值表示

多项式的系数表示,设 表示一个 次多项式,所有项的系数组成的 维向量 唯一确定了这个多项式。

多项式的点值表示,将一组互不相同 带入多项式,得到的 个值。设它们组成的 维向量分别为

求值与插值

定理:一个 次多项式在 个不同点的取值唯一确定了该多项式。

证明:假设命题不成立,存在两个不同的 次多项式 ,满足对于任何 ,有

,则 也是一个 次多项式。对于任何 ,有

个根,这与代数基本定理(一个 次多项式在复数域上有且仅有 个根)相矛盾,故 并不是一个 次多项式,原命题成立,证毕。

如果我们按照定义求一个多项式的点值表示,时间复杂度为

已知多项式的点值表示,求其系数表示,可以使用插值。朴素的插值算法时间复杂度为

多项式乘法

我们定义两个多项式 相乘的结果为 (假设两个多项式次数相同,若不同可在后面补零)。

两个 次多项式相乘,得到的是一个 次多项式,时间复杂度为

如果使用两个多项式在 个点处取得的点值表示,那么

时间复杂度为

复数

为实数,,形如 的数叫做复数,其中 被称为虚数单位。复数域是已知最大的域。

复平面

在复平面中, 轴代表实数、 轴(除原点外的所有点)代表虚数。每一个复数 对应复平面上一个从 指向 的向量。

该向量的长度 叫做模长。表示从 轴正半轴到该向量的转角的有向(以逆时针为正方向)角叫做幅角。

复数相加遵循平行四边形定则。

复数相乘时,模长相乘,幅角相加。

单位根

下文中,如不特殊指明,均设 的正整数次幂。

在复平面上,以原点为圆心, 为半径作圆,所得的圆叫做单位圆。以原点为起点,单位圆的 等分点为终点,作 个向量。设所得的幅角为正且最小的向量对应的复数为 ,称为 次单位根。

由复数乘法的定义(模长相乘,幅角相加)可知,其与的 个向量对应的复数分别为 ,其中

单位根的幅角为周角的 ,这为我们提供了一个计算单位根及其幂的公式

单位根的性质

性质一:

从几何意义上看,在复平面上,二者表示的向量终点相同。

更具体的,有

性质二:

等式左边相当于 乘上 ,考虑其值

快速傅里叶变换

考虑多项式 的表示。将 次单位根的 次幂带入多项式的系数表示,所得点值向量 称为其系数向量 离散傅里叶变换

按照朴素算法来求离散傅里叶变换,时间复杂度仍然为

考虑将多项式按照系数下标的奇偶分为两部分

则有

假设 ,现在要求

这一步转化利用了单位根的性质一。

对于

这一步转化除性质一外,还利用到了性质二和 这一显然的结论。

注意到,当 取遍 时, 取遍了

也就是说,如果已知 处的点值,就可以在 的时间内求得 处的取值。而关于 的问题都是相对于原问题规模缩小了一半的子问题,分治的边界为一个常数项

根据主定理,该分治算法的时间复杂度为

这就是最常用的 FFT 算法 —— Cooley-Tukey 算法。

傅里叶逆变换

将点值表示的多项式转化为系数表示,同样可以使用快速傅里叶变换,这个过程叫做傅里叶逆变换

的傅里叶变换。考虑另一个向量 满足

即多项式 处的点值表示。

将上式展开,得

考虑一个式子

时,两边同时乘上

两式相减,整理后得

分子为零,分母不为零,所以

时,显然

继续考虑上式

时,,否则 ,即

所以,使用单位根的倒数代替单位根,做一次类似快速傅里叶变换的过程,再将结果每个数除以 ,即为傅里叶逆变换的结果。

实现

C++ 的 STL 在头文件 complex 中提供一个复数的模板实现 std::complex<T>,其中 T 为实数类型,一般取 double,在对精度要求较高的时候可以使用 long double__float128(不常用)。

考虑到单位根的倒数等于其共轭复数,在程序实现中,为了减小误差,通常使用 std::conj() 取得 IDFT 所需的「单位根的倒数」。

递归实现

直接按照上面得到的结论来实现即可,比较直观。

代码
const double PI = acos(-1);

bool inversed = false;

inline std::complex<double> omega(const int n, const int k) {
    if (!inversed) return std::complex<double>(cos(2 * PI / n * k), sin(2 * PI / n * k));
    else return std::conj(std::complex<double>(cos(2 * PI / n * k), sin(2 * PI / n * k)));
}

inline void transform(std::complex<double> *a, const int n) {
    if (n == 1) return;

    static std::complex<double> buf[MAXN];
    const int m = n / 2;
    // 按照系数奇偶划分为两半
    for (int i = 0; i < m; i++) {
        buf[i] = a[i * 2];
        buf[i + m] = a[i * 2 + 1];
    }
    std::copy(buf, buf + n, a);

    // 分治
    std::complex<double> *a1 = a, *a2 = a + m;
    fft(a1, m);
    fft(a2, m);

    // 合并两个子问题
    for (int i = 0; i < m; i++) {
        std::complex<double> x = omega(n, i);
        buf[i] = a1[i] + x * a2[i];
        buf[i + m] = a1[i] - x * a2[i];
    }

    std::copy(buf, buf + n, a);
}

迭代实现

递归实现的 FFT 效率不高,实际中一般采用迭代实现。

二进制位翻转

考虑递归 FFT 分治到边界时,每个数的顺序,及其二进制位。

000 001 010 011 100 101 110 111
 0   1   2   3   4   5   6   7
 0   2   4   6 - 1   3   5   7
 0   4 - 2   6 - 1   5 - 3   7
 0 - 4 - 2 - 6 - 1 - 5 - 3 - 7
000 100 010 110 001 101 011 111

发现规律,分治到边界后的下标等于原下标的二进制位翻转。

代码实现,枚举每个二进制位即可。

int k = 0;
while ((1 << k) < n) k++;
for (int i = 0; i < n; i++) {
    int t = 0;
    for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));
    if (i < t) std::swap(a[i], a[t]);
}
蝴蝶操作

考虑合并两个子问题的过程,假设 分别存在 中, 将要被存放在 中,合并的单位操作可表示为

考虑加入一个临时变量 ,使得这个过程可以在原地完成,不需要另一个数组 ,也就是说,将 存放在 中,覆盖原来的值

这一过程被称为蝴蝶操作

代码

omega[k] 中保存 (IDFT 时保存 )。

枚举 ,表示一次要将 长度的序列答案合并为长度为 的,根据单位根的性质一,有

void transform(std::complex<double> *a, const int n, const std::complex<double> *omega) {
    // 此处省略二进制位翻转的代码
    for (int l = 2; l <= n; l *= 2) {
        int m = l / 2;
        // 将两个长度为 m 的序列的答案合并为长度为 l 的序列的答案
        for (std::complex<double> *p = a; p != a + n; p += l) {
            for (int i = 0; i < m; i++) {
                // 蝴蝶操作
                std::complex<double> t = omega[n / l * i] * p[m + i];
                p[m + i] = p[i] - t;
                p[i] += t;
            }
        }
    }
}

模板

需要注意的是,在求两个次数分别为 的多项式的乘积时,需要分别求出其在至少 个点处的点值,因为这样才能保证相乘后的点值能唯一确定一个 次多项式。

struct FastFourierTransform {
    std::complex<double> omega[MAXN], omegaInverse[MAXN];

    void init(const int n) {
        for (int i = 0; i < n; i++) {
            omega[i] = std::complex<double>(cos(2 * PI / n * i), sin(2 * PI / n * i));
            omegaInverse[i] = std::conj(omega[i]);
        }
    }

    void transform(std::complex<double> *a, const int n, const std::complex<double> *omega) {
        int k = 0;
        while ((1 << k) < n) k++;
        for (int i = 0; i < n; i++) {
            int t = 0;
            for (int j = 0; j < k; j++) if (i & (1 << j)) t |= (1 << (k - j - 1));
            if (i < t) std::swap(a[i], a[t]);
        }

        for (int l = 2; l <= n; l *= 2) {
            int m = l / 2;
            for (std::complex<double> *p = a; p != a + n; p += l) {
                for (int i = 0; i < m; i++) {
                    std::complex<double> t = omega[n / l * i] * p[m + i];
                    p[m + i] = p[i] - t;
                    p[i] += t;
                }
            }
        }
    }

    void dft(std::complex<double> *a, const int n) {
        transform(a, n, omega);
    }

    void idft(std::complex<double> *a, const int n) {
        transform(a, n, omegaInverse);
        for (int i = 0; i < n; i++) a[i] /= n;
    }
} fft;

inline void multiply(const int *a1, const int n1, const int *a2, const int n2, int *res) {
    int n = 1;
    while (n < n1 + n2) n *= 2;
    static std::complex<double> c1[MAXN], c2[MAXN];
    for (int i = 0; i < n1; i++) c1[i].real(a1[i]);
    for (int i = 0; i < n2; i++) c2[i].real(a2[i]);
    fft.init(n);
    fft.dft(c1, n), fft.dft(c2, n);
    for (int i = 0; i < n; i++) c1[i] *= c2[i];
    fft.idft(c1, n);
    for (int i = 0; i < n1 + n2 - 1; i++) res[i] = static_cast<int>(floor(c1[i].real() + 0.5));
}

参考资料