「Codeforces 650D」Zip-line - 树状数组 + LIS

给一个序列,每次求原序列的第 i i 个数修改为 x x 时的最长严格上升子序列长度。

链接

Codeforces 650D

题解

首先,预处理整个序列的 LIS 长度和以第 i i 个数结尾的 LIS 长度 li l_i 、以第 i i 个数开头的 LIS 长度 ri r_i

查询时,考虑求包含被修改的数的 LIS 长度。假设被修改的是第 i i 个数,那么找到其左边比新值 x x 小的数对应的 lj l_j 的最大值,加上其右边比新值 x x 大的数对应的 rj r_j 的最大值,再加上 1 1 (即把三段拼起来)。如果求得的数比原 LIS 长度更大,则即为答案,否则考虑被修改的数是否一定在 LIS 中
若被修改的数可以不在 LIS 中,则答案一定不会比原 LIS 长度更小。
若被修改的数一定在 LIS 中,则答案最多减小 1 1 (从原 LIS 中去掉这个数)。

现在的问题是,如何求出哪些数一定在 LIS 中。

  • 首先,用类似上文中的方法分别求出包含原序列中第 i i 个数的 LIS 长度,如果该长度等于原序列的 LIS 长度,则第 i i 个数可以在 LIS 中
  • 对于一个可以在 LIS 中的数 i i ,如果 i i 左边有比 i i 更大的数可以在 LIS 中,则它可以不在 LIS 中
  • 对于一个可以在 LIS 中的数 i i ,如果 i i 右边有比 i i 更小的数可以在 LIS 中,则它可以不在 LIS 中

「对于 i[1,k] i \in [1, k] ai<x a_i < x i i ,求 bi b_i 的最值」可在离散化后用离线加树状数组解决。

代码

#include <cstdio>
#include <algorithm>

const int MAXN = 400000;
const int MAXL = MAXN + MAXN;

struct Question {
    int pos, newVal, *ans;

    bool operator<(const Question &other) const {
        return pos < other.pos;
    }
} q[MAXN + 1];

int n, m, a[MAXN + 1], lisLeft[MAXN + 1], lisRight[MAXN + 1], required[MAXN + 1], initAns, cnt;

inline int processLis(int *a, int *lis) {
    static int tmp[MAXN + 1];
    int max = 0;
    for (int i = 1; i <= n; i++) {
        int x = std::lower_bound(tmp + 1, tmp + max + 1, a[i]) - tmp;
        if (x == max + 1) tmp[++max] = a[i];
        else tmp[x] = std::min(tmp[x], a[i]);

        lis[i] = x;
    }

    return max;
}

inline void discrete() {
    static int tmp[MAXL + 1];
    for (int i = 1; i <= n; i++) tmp[++cnt] = a[i];
    for (int i = 1; i <= m; i++) tmp[++cnt] = q[i].newVal;

    std::sort(tmp + 1, tmp + cnt + 1);
    int *end = std::unique(tmp + 1, tmp + cnt + 1);

    for (int i = 1; i <= n; i++) a[i] = std::lower_bound(tmp + 1, end, a[i]) - tmp;
    for (int i = 1; i <= m; i++) q[i].newVal = std::lower_bound(tmp + 1, end, q[i].newVal) - tmp;
}

struct BinaryIndexedTree {
    int a[MAXL + 1], n;

    static int lowbit(int x) {
        return x & -x;
    }

    void init(int n) {
        this->n = n;
        std::fill(a + 1, a + n + 1, 0);
    }

    int query(int pos) {
        int res = 0;
        for (int i = pos; i > 0; i -= lowbit(i)) res = std::max(res, a[i]);
        return res;
    }

    void update(int pos, int x) {
        for (int i = pos; i <= MAXL; i += lowbit(i)) a[i] = std::max(a[i], x);
    }
} bit;

inline void determineRequire() {
    static int len[MAXN + 1];

    bit.init(cnt);
    for (int i = 1; i <= n; i++) {
        len[i] += bit.query(a[i] - 1);
        bit.update(a[i], lisLeft[i]);
    }

    bit.init(cnt);
    for (int i = n; i >= 1; i--) {
        len[i] += bit.query(cnt - a[i] + 1 - 1);
        bit.update(cnt - a[i] + 1, lisRight[i]);
    }

    // for (int i = 1; i <= n; i++) printf("%d%c", len[i], " \n"[i == n]);

    static bool usable[MAXN + 1];
    for (int i = 1; i <= n; i++) if (len[i] + 1 == initAns) usable[i] = required[i] = true;

    int max = 0;
    for (int i = 1; i <= n; i++) {
        if (max >= a[i]) required[i] = false;
        if (usable[i]) max = std::max(max, a[i]);
    }

    int min = cnt + 1;
    for (int i = n; i >= 1; i--) {
        if (min <= a[i]) required[i] = false;
        if (usable[i]) min = std::min(min, a[i]);
    }
}

inline void solve() {
    std::sort(q + 1, q + m + 1);

    bit.init(cnt);
    for (int i = 1, j = 1; i <= n && j <= m; i++) {
        for (; j <= m && q[j].pos == i; j++) *q[j].ans += bit.query(q[j].newVal - 1);
        bit.update(a[i], lisLeft[i]);
    }

    bit.init(cnt);
    for (int i = n, j = m; i >= 1 && j >= 1; i--) {
        for (; j >= 1 && q[j].pos == i; j--) *q[j].ans += bit.query(cnt - q[j].newVal + 1 - 1);
        bit.update(cnt - a[i] + 1, lisRight[i]);
    }

    for (int i = 1; i <= m; i++) {
        (*q[i].ans)++;
        if (!required[q[i].pos]) *q[i].ans = std::max(*q[i].ans, initAns);
        else *q[i].ans = std::max(*q[i].ans, initAns - 1);
    }
}

int main() {
    scanf("%d %d", &n, &m);
    for (int i = 1; i <= n; i++) scanf("%d", &a[i]);

    static int ans[MAXN + 1];
    for (int i = 1; i <= m; i++) {
        scanf("%d %d", &q[i].pos, &q[i].newVal);
        q[i].ans = &ans[i];
    }

    discrete();

    initAns = processLis(a, lisLeft);

    static int revA[MAXN + 1];
    for (int i = 1; i <= n; i++) revA[i] = -a[n - i + 1];

    processLis(revA, lisRight);
    std::reverse(lisRight + 1, lisRight + n + 1);

    determineRequire();

    // for (int i = 1; i <= n; i++) printf("%d%c", lisLeft[i], " \n"[i == n]);
    // for (int i = 1; i <= n; i++) printf("%d%c", lisRight[i], " \n"[i == n]);
    // for (int i = 1; i <= n; i++) printf("%d%c", required[i], " \n"[i == n]);

    solve();

    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
}