「BZOJ 3744」Gty 的妹子序列 - 分块 + 树状数组

给一个序列,每次询问 [l, r] 的逆序对数,强制在线。

链接

BZOJ 3744

题解

分块,建立 \sqrt n 个树状数组,第 i 个维护从第一块到第 i 块内的数,每个树状数组中维护每个数的出现次数,便于(通过两个树状数组相减)统计第 [l, r] 块之间大于或小于 x 的数的数量。

预处理 [l, r] 块之间的答案或暴力计算答案时,维护一个树状数组,每加入一个数时更新树状数组。

询问跨过多个块时,假设跨过了(完整的) [\text{lb} + 1, \text{rb} - 1] 块,我们要将不在块内的部分加入到块内,具体可以加入到 \text{rb} - 1 块的树状数组中,最后再撤销。

代码

#include <cstdio>
#include <cmath>
#include <cassert>
#include <algorithm>

// Upper bound on the sequence length; arrays are sized MAXN + 1 because
// positions and ranks are 1-based.
const int MAXN = 50000;
// Upper bound on the number of blocks: sqrt(50000) is about 223.6, so with a
// block size of ceil(sqrt(n)) there are at most 224 blocks.
const int MAXN_SQRT = 223 + 1; // Math.sqrt(50000) = 223.60679774997897

/*
extern struct SegT *null;

struct SegT {
    SegT *lc, *rc;
    int cnt;

    SegT() : lc(this), rc(this), cnt(0) {}
    SegT(SegT *lc, SegT *rc) : lc(lc), rc(rc), cnt(lc->cnt + rc->cnt) {}
    SegT(int cnt) : lc(null), rc(null), cnt(cnt) {}

    SegT *update(int l, int r, int pos, int delta) {
        if (l == r) return new SegT(cnt + delta);
        else {
            int mid = l + (r - l) / 2;
            if (pos > mid) return new SegT(lc, rc->update(mid + 1, r, pos, delta));
            else return new SegT(lc->update(l, mid, pos, delta), rc);
        }
    }

    int queryPrefix(int l, int r, int pos) {
        if (l == r) return cnt;
        else {
            int mid = l + (r - l) / 2;
            if (pos > mid) return lc->cnt + rc->queryPrefix(mid + 1, r, pos);
            else return lc->queryPrefix(l, mid, pos);
        }
    }
} *rt[MAXN + 1], *null = new SegT;
*/

// n         - length of the sequence (read in main).
// blockSize - number of positions per block; blockCnt - number of blocks.
// a[1..n]   - the sequence; discrete() replaces each value by its 1-based rank.
int n, blockSize, blockCnt, a[MAXN + 1];
// blockAns[i][j] - number of inversions inside blocks i..j, filled by prepare()
// for i <= j. Entries with j < i are never written and stay zero.
long long blockAns[MAXN_SQRT + 1][MAXN_SQRT + 1];

inline void discrete() {
    static int set[MAXN + 1];
    std::copy(a + 1, a + n + 1, set + 1);
    std::sort(set + 1, set + n + 1);
    int *end = std::unique(set + 1, set + n + 1);
    for (int i = 1; i <= n; i++) a[i] = std::lower_bound(set + 1, end, a[i]) - set;
}

// Returns the 1-based index of the block that contains position `pos`
// (positions are 1-based, each block spans `blockSize` positions).
inline int blockID(int pos) {
    const int zeroBased = pos - 1;
    return zeroBased / blockSize + 1;
}

// Writes the inclusive position range [l, r] covered by block `bid` into the
// output parameters. The right end is capped at n for the final, possibly
// shorter, block.
inline void getBlock(int bid, int &l, int &r) {
    const int first = (bid - 1) * blockSize + 1;
    const int last = first + blockSize - 1;
    l = first;
    r = last < n ? last : n;
}

// Returns the first (leftmost) position that belongs to block `bid`.
inline int getBlockL(int bid) {
    const int blocksBefore = bid - 1;
    return blocksBefore * blockSize + 1;
}

// Returns the last (rightmost) position that belongs to block `bid`,
// capped at n so the final block may be shorter than blockSize.
inline int getBlockR(int bid) {
    const int last = getBlockL(bid) + blockSize - 1;
    return last < n ? last : n;
}

/*
inline void validateBlocks() {
    for (int i = 1; i <= blockCnt; i++) {
        int l, r;
        getBlock(i, l, r);
        printf("block[%d] = [%d, %d]\n", i, l, r);
        for (int j = l; j <= r; j++) {
            assert(blockID(j) == i);
        }
    }
}
*/

/*
inline void buildSegT() {
    rt[0] = null;
    for (int i = 1; i <= n; i++) {
        rt[i] = rt[i - 1]->update(1, n, a[i], 1);
    }
}

inline int queryLessThan(int l, int r, int x) {
    if (l > r) return 0;
    return rt[r]->queryPrefix(1, n, x - 1) - rt[l - 1]->queryPrefix(1, n, x - 1);
}

inline int queryGreaterThan(int l, int r, int x) {
    if (l > r) return 0;
    return (r - l + 1) - queryLessThan(l, r, x + 1);
}
*/

// Fenwick tree (binary indexed tree) over positions 1..n, where n is the
// global sequence length. Used here as a counter of how many times each
// rank occurs.
//
// bits[i] holds the values of blocks 1..i (bits[0] stays empty);
// `bit` is a shared scratch tree.
struct BinaryIndexedTree {
    int a[MAXN + 1];

    // Lowest set bit of x.
    static int lowbit(int x) {
        return x & -x;
    }

    // Adds `delta` to the counter at position `pos` (1 <= pos expected).
    void update(int pos, int delta) {
        int i = pos;
        while (i <= n) {
            a[i] += delta;
            i += lowbit(i);
        }
    }

    // Sum of the counters at positions 1..pos (0 when pos <= 0).
    int query(int pos) {
        int sum = 0;
        while (pos > 0) {
            sum += a[pos];
            pos -= lowbit(pos);
        }
        return sum;
    }

    // Sum of the counters at positions l..r; an empty range yields 0.
    int query(int l, int r) {
        return l > r ? 0 : query(r) - query(l - 1);
    }

    // Resets the counters at positions 1..n to zero.
    void clear() {
        int i = 1;
        while (i <= n) a[i++] = 0;
    }
} bits[MAXN_SQRT + 1], bit;

inline void buildBIT() {
    for (int i = 1; i <= blockCnt; i++) {
        bits[i] = bits[i - 1];

        int l, r;
        getBlock(i, l, r);

        for (int j = l; j <= r; j++) bits[i].update(a[j], 1);
    }
}

inline void prepare() {
    discrete();
    buildBIT();
    // buildSegT();

    for (int i = 1; i <= blockCnt; i++) {
        int l0 = getBlockL(i);

        bit.clear();

        for (int j = i; j <= blockCnt; j++) {
            blockAns[i][j] = blockAns[i][j - 1];

            int l, r;
            getBlock(j, l, r);

            for (int k = l; k <= r; k++) {
                // blockAns[i][j] += queryGreaterThan(l0, k - 1, a[k]);
                blockAns[i][j] += bit.query(a[k] + 1, n);
                // assert(bit.query(a[k] + 1, n) == queryGreaterThan(l0, k - 1, a[k]));
                bit.update(a[k], 1);
            }
        }
    }
}

inline long long force(int l, int r) {
    long long ans = 0;
    bit.clear();
    for (int i = l; i <= r; i++) {
        ans += bit.query(a[i] + 1, n);
        bit.update(a[i], 1);
        // ans += queryGreaterThan(l, i - 1, a[i]);
    }
    return ans;
}

inline long long query(int l, int r) {
    int lb = blockID(l), rb = blockID(r);

    if (rb - lb <= 1) return force(l, r);

    long long ans = blockAns[lb + 1][rb - 1];

    int lbr = getBlockR(lb), rbl = getBlockL(rb);
    for (int i = lbr; i >= l; i--) {
        bits[rb - 1].update(a[i], 1);
        ans += bits[rb - 1].query(1, a[i] - 1) - bits[lb].query(1, a[i] - 1);
        // ans += queryLessThan(i + 1, rbl - 1, a[i]);
    }

    for (int i = rbl; i <= r; i++) {
        bits[rb - 1].update(a[i], 1);
        ans += bits[rb - 1].query(a[i] + 1, n) - bits[lb].query(a[i] + 1, n);
        // ans += queryGreaterThan(l, i - 1, a[i]);
    }

    for (int i = rbl; i <= r; i++) bits[rb - 1].update(a[i], -1);
    for (int i = lbr; i >= l; i--) bits[rb - 1].update(a[i], -1);

    return ans;
}

int main() {
    scanf("%d", &n);

    blockSize = ceil(sqrt(n));
    blockCnt = n / blockSize + (n % blockSize != 0);

    // validateBlocks();

    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    prepare();

    int m;
    scanf("%d", &m);

    int lastAns = 0;
    while (m--) {
        int l, r;
        scanf("%d %d", &l, &r);
        l ^= lastAns, r ^= lastAns;
        printf("%d\n", lastAns = query(l, r));
    }
}