「HNOI2016」序列 - 莫队 + RMQ

给定长度为 的序列:,记为 。类似的,)是指序列:。若 ,则称 的子序列。

现在有 个询问,每个询问给定两个数 ,求 的不同子序列的最小值之和。

链接

BZOJ 4540

题解

Sengxian 的题解

对于无修改的区间询问,我们可以将操作离线,采用莫队算法解决。

为便于叙述,定义 为位置 处的元素(即 ); 为左端点属于 ,右端点为 的所有子序列。

已知区间 的答案,考虑新加入的元素 对询问的影响。新元素加入后,产生了 个子序列,它们是

中存在一个最小值 ,使得 的最小值均为 ,举个例子

位置 1 2 3 4 5 6 7 8 9 10
5 7 3 1 5 2 8 3 6 4
标记

区间 的最小值均为 ,即

考虑剩下的 个子序列,从 向左走,经过的所有比 的元素,以这些元素的位置为左端点, 为右端点的所有子序列的最小值均为 。直到到达第一个比 小的元素(例子中的 ),其位置记做

位置 1 2 3 4 5 6 7 8 9 10
5 7 3 1 5 2 8 3 6 4
标记

右边有 个子序列,他们的最小值均为 ,这些子序列对答案的贡献为

仿照刚才的做法,继续向左找第一个小于 的元素,其值为 ,即

位置 1 2 3 4 5 6 7 8 9 10
5 7 3 1 5 2 8 3 6 4
标记

右边有 个子序列,他们的最小值均为 ,这些子序列对答案的贡献为

继续向左找,找到第一个比 小的元素,其值为 ,即 ,注意此时已经找到整个区间内的最小值,左端点在 及其左侧的所有子序列对答案的贡献已经被考虑过,恰好只剩下 个子序列,其最小值均为 ,这些子序列对答案的贡献为

若直接使用上述算法计算每个元素的贡献,单次计算的时间复杂度为 ,超时。

定义 为区间 的最小值所在的位置,设 为区间 内除最小值之外的其它值对答案的贡献,则有

注意到整个式子除边界条件中的 之外,和 是无关的。

为从 位置一直向左跳,直到跳到所有元素中的最小值,用上述方法计算出的贡献总和,则有

注意到刚才的例子中,,即最后一跳的位置最小值的右边一个位置向左跳跳到的位置相同。所以

使用单调栈算法(保持栈底到栈顶的元素单调递增)预处理出每个 ,之后可以在 的时间内递推出 。区间向左扩展时同理,向右边跳即可。

而对于 RMQ,使用稀疏表在 的时间内预处理后,即可在 的时间内回答每次查询。最终,每次转移的时间降为 ,总时间复杂度为

代码

#include <cstdio>
#include <cmath>
// #include <cassert>
#include <algorithm>
#include <stack>

const int MAXN = 100000;
const int MAXLOGN = 17; // log(100000, 2) = 16.609640474436812
const int MAXM = 100000;

struct Element {
    int val;
    Element *left, *right;
    long long sumLeft, sumRight;

    bool operator<(const Element &x) const { return val < x.val; }
    bool operator<=(const Element &x) const { return val <= x.val; }
} a[MAXN];

int n, m, logTable[MAXN + 1];
Element *st[MAXN][MAXLOGN + 1];
long long ans[MAXN];

struct Query {
    int l, r;
    long long *ans;

    bool operator<(const Query &x) const {
        static int blockSize = floor(sqrt(n));
        if (l / blockSize == x.l / blockSize) return r < x.r;
        else return l / blockSize < x.l / blockSize;
    }
} Q[MAXM];

inline Element *min(Element *const a, Element *const b) {
    if (!a) return b;
    if (!b) return a;
    return *a < *b ? a : b;
}

inline void sparseTable() {
    for (int i = 0; i < n - 1; i++) st[i][0] = min(&a[i], &a[i + 1]);
    st[n - 1][0] = &a[n - 1];

    for (int j = 1; (1 << j) <= n; j++) {
        for (int i = 0; i < n; i++) {
            if (i + (1 << (j - 1)) < n) {
                st[i][j] = min(st[i][j - 1], st[i + (1 << (j - 1))][j - 1]);
            }
        }
    }

    for (int i = 0; i <= n; i++) {
        logTable[i] = floor(log2(i));
    }
}

inline Element *rmq(const int l, const int r) {
    if (l == r) return &a[l];
    else {
        int t = logTable[r - l];
        return min(st[l][t], st[r - (1 << t)][t]);
    }
}

inline void prepare() {
    std::stack<Element *> s;
    s.push(&a[0]);

    for (int i = 1; i < n; i++) {
        while (!s.empty() && a[i] <= *s.top()) s.pop();
        if (!s.empty()) a[i].left = s.top();
        else a[i].left = NULL;
        s.push(&a[i]);
    }

    for (int i = 0; i < n; i++) {
        Element *x = &a[i];
        if (x->left == NULL) {
            x->sumLeft = 0;
        } else {
            x->sumLeft = x->left->sumLeft + (x - x->left) * static_cast<long long>(x->val);
        }
    }

    s.push(&a[n - 1]);
    for (int i = n - 2; i >= 0; i--) {
        while (!s.empty() && a[i] <= *s.top()) s.pop();
        if (!s.empty()) a[i].right = s.top();
        else a[i].right = NULL;
        s.push(&a[i]);
    }

    for (int i = n - 1; i >= 0; i--) {
        Element *x = &a[i];
        if (x->right == NULL) {
            x->sumRight = 0;
        } else {
            x->sumRight = x->right->sumRight + (x->right - x) * static_cast<long long>(x->val);
        }
    }

    sparseTable();
    std::sort(Q, Q + m);

    for (int i = 0; i < n; i++) {
        // printf("%lld %lld\n", a[i].sumLeft, a[i].sumRight);
        // printf("%d: sumLeft = %lld, sumRight = %lld, ", a[i].val, a[i].sumLeft, a[i].sumRight);
        // if (a[i].left == NULL) printf("left = NULL, ");
        // else printf("left = %ld[%d], ", a[i].left - a, a[i].left->val);
        // if (a[i].right == NULL) printf("right = NULL\n");
        // else printf("right = %ld[%d]\n", a[i].right - a, a[i].right->val);
    }
}

inline long long expandRight(const int l, const int r) {
    // printf("[%d, %d]\n", l, r);
    Element *pos = rmq(l, r);
    return (pos - &a[l] + 1) * static_cast<long long>(pos->val)
          + a[r].sumLeft - pos->sumLeft;
}

inline long long expandLeft(const int l, const int r) {
    // printf("[%d, %d]\n", l, r);
    Element *pos = rmq(l, r);
    return (&a[r] - pos + 1) * static_cast<long long>(pos->val)
          + a[l].sumRight - pos->sumRight;
}

inline void mo() {
    int l = 0, r = 0;
    long long ans = a[0].val;
    for (int i = 0; i < m; i++) {
        const Query &q = Q[i];
        // assert(l <= r);
        while (r < q.r) r++, ans += expandRight(l, r);
        while (l > q.l) l--, ans += expandLeft(l, r);
        while (r > q.r) ans -= expandRight(l, r), r--;
        while (l < q.l) ans -= expandLeft(l, r), l++;
        *q.ans = ans;
    }
}

int main() {
    scanf("%d %d", &n, &m);
    for (int i = 0; i < n; i++) scanf("%d", &a[i].val);
    for (int i = 0; i < m; i++) {
        scanf("%d %d", &Q[i].l, &Q[i].r);
        Q[i].l--, Q[i].r--;
        Q[i].ans = &ans[i];
    }

    prepare();
    mo();

    for (int i = 0; i < m; i++) printf("%lld\n", ans[i]);

    return 0;
}