「BZOJ 2194」快速傅立叶之二 - FFT

给定两个长度为 n 的序列 A 和 B,求长度为 n 的序列 C,满足 C_k = \sum\limits_{i = k} ^ {n - 1} A_i \times B_{i - k}

链接

BZOJ 2194

题解

将 A 反转得到 A'(即 A'_i = A_{n - i - 1}),则有

\begin{align*} C_k &= \sum\limits_{i = k} ^ {n - 1} A_i \times B_{i - k} \\ &= \sum\limits_{i = k} ^ {n - 1} A'_{n - i - 1} \times B_{i - k} \end{align*}

调整求和指标(用 i 代替原来的 i - k),使 i 从 0 开始枚举

C_k = \sum\limits_{i = 0} ^ {n - k - 1} A'_{n - k - 1 - i} \times B_{i}

现在表达式已经类似多项式乘法的形式了,考虑这样一个多项式乘法

c_t = \sum\limits_{i = 0} ^ {t} a_{t - i} b_{i}

令 t = n - k - 1,则上式化为

C_{n - t - 1} = \sum\limits_{i = 0} ^ {t} A'_{t - i} B_{i}

考虑将 C 翻转为 C'

C'_{t} = \sum\limits_{i = 0} ^ {t} A'_{t - i} B_{i}

至此,原式已被化为多项式乘法的形式,只需将 A' 和 B 作为多项式系数,求出 C' (只需前 n 项),翻转后即为 C

总时间复杂度为 FFT 的时间复杂度, O(n \log n)

代码

#include <cmath>
#include <cstdio>
#include <algorithm>
#include <complex>

// Capacity of every buffer used by the transform; the padded transform length
// (a power of two) must not exceed this value.
const int MAXN = 131072 * 2;
const double PI = std::acos(-1.0);

// Multiplies two integer-coefficient polynomials using an iterative radix-2
// FFT over std::complex<double>. A single global instance `fft` is provided.
struct FastFourierTransform {
    // Tables of the n-th roots of unity and their conjugates, filled by init().
    std::complex<double> omega[MAXN], omegaInverse[MAXN];

    // Fills omega[i] = exp(2*pi*i/n * sqrt(-1)) and its conjugate for 0 <= i < n.
    // n must not exceed MAXN.
    void init(const int n) {
        for (int i = 0; i < n; i++) {
            const double angle = 2 * PI / n * i;
            omega[i] = std::complex<double>(std::cos(angle), std::sin(angle));
            omegaInverse[i] = std::conj(omega[i]);
        }
    }

    // In-place transform of a[0 .. n-1] using the given root table.
    // n must be a power of two: the loop computing k below only terminates
    // when (1 << k) hits n exactly.
    void transform(std::complex<double> *a, const int n, const std::complex<double> *omega) {
        int k = 0;
        while ((1 << k) != n) k++;

        // Bit-reversal permutation: element i moves to the index obtained by
        // reversing the k low bits of i. The `t > i` test swaps each pair once.
        for (int i = 0; i < n; i++) {
            int t = 0;
            for (int j = 0; j < k; j++) if (i & (1 << (k - j - 1))) t |= (1 << j);
            if (t > i) std::swap(a[i], a[t]);
        }

        // Butterfly passes over blocks of length l = 2, 4, ..., n.
        for (int l = 2; l <= n; l *= 2) {
            const int m = l / 2;
            for (std::complex<double> *p = a; p != a + n; p += l) {
                for (int i = 0; i < m; i++) {
                    // omega[n / l * i] is the i-th power of the l-th root of unity.
                    const std::complex<double> t = omega[n / l * i] * p[i + m];
                    p[i + m] = p[i] - t;
                    p[i] += t;
                }
            }
        }
    }

    // Forward transform with the table built by the last init(n).
    void dft(std::complex<double> *a, const int n) {
        transform(a, n, omega);
    }

    // Inverse transform: conjugate roots followed by division by n.
    void idft(std::complex<double> *a, const int n) {
        transform(a, n, omegaInverse);
        for (int i = 0; i < n; i++) a[i] /= n;
    }

    // Computes the product of the polynomials a[0 .. n1-1] and b[0 .. n2-1]
    // and stores (overwrites) its coefficients in c[0 .. n1+n2-1]; the last
    // slot, index n1+n2-1, is always 0 because the product has degree
    // n1+n2-2.
    //
    // Only a[0 .. n1-1] and b[0 .. n2-1] are read. The smallest power of two
    // that is >= n1 + n2 must not exceed MAXN.
    void operator()(const int *a, const int n1, const int *b, const int n2, int *c) {
        static std::complex<double> ca[MAXN], cb[MAXN];

        const int total = n1 + n2;

        // Pad to a power of two large enough to hold the full product so the
        // cyclic convolution computed by the FFT equals the linear one.
        int n = 1;
        while (n < total) n *= 2;

        for (int i = 0; i < n; i++) ca[i] = std::complex<double>(i < n1 ? a[i] : 0, 0);
        for (int i = 0; i < n; i++) cb[i] = std::complex<double>(i < n2 ? b[i] : 0, 0);

        init(n);
        dft(ca, n);
        dft(cb, n);

        // Pointwise product in the frequency domain.
        for (int i = 0; i < n; i++) ca[i] *= cb[i];

        idft(ca, n);

        // Round to the nearest integer to remove floating-point noise.
        for (int i = 0; i < total; i++) c[i] = static_cast<int>(std::floor(ca[i].real() + 0.5));
    }
} fft;

// Quadratic-time reference implementation, handy for checking other results.
//
// For every k in [0, n) this adds
//     sum over i in [0, n - k) of a[n - k - 1 - i] * b[i]
// onto c[k]; c is accumulated into, not cleared. This is the same sum as
//     sum over i in [k, n) of a[n - i - 1] * b[i - k]
// with the summation index shifted so that it starts at 0.
inline void force(const int *a, const int *b, int *c, const int n) {
    for (int shift = 0; shift < n; ++shift) {
        // `hi` walks a downwards from index n - shift - 1 while `lo` walks b
        // upwards from index 0.
        for (int lo = 0, hi = n - shift - 1; hi >= 0; ++lo, --hi) {
            c[shift] += a[hi] * b[lo];
        }
    }
}



// Reads n followed by n pairs (A_i, B_i), then prints C_0 .. C_{n-1}, one per
// line. Returns 1 if the input is malformed or n does not fit the buffers.
int main() {
    int n;
    // The product of two length-n inputs needs 2n slots, and the arrays below
    // hold MAXN entries, so anything above MAXN / 2 would overflow them.
    if (scanf("%d", &n) != 1 || n < 0 || n > MAXN / 2) return 1;

    static int a[MAXN], b[MAXN], c[MAXN];
    for (int i = 0; i < n; i++) {
        // A is stored reversed so the requested sum becomes a plain
        // polynomial product of a and b.
        if (scanf("%d %d", &a[n - i - 1], &b[i]) != 2) return 1;
    }

    // force(a, b, c, n);
    fft(a, n, b, n, c);

    // Coefficient n - 1 - k of the product is C_k, so print the first n
    // coefficients in reverse order.
    for (int i = 0; i < n; i++) printf("%d\n", c[n - i - 1]);

    return 0;
}