「BZOJ 2253」纸箱堆叠 - CDQ 分治 + DP

给定 n 个三元组 A_i = (a_i, b_i, c_i) ,定义 A_i < A_j 当且仅当 a_i < a_j 且 b_i < b_j 且 c_i < c_j 。求 A 的最长上升子序列长度。

链接

BZOJ 2253

题解

CDQ 分治,对于一段区间 (l, r) ,先递归计算好 (l, \text{mid}) 的答案,然后处理前一半对后一半的 DP 值的贡献,然后递归计算 (\text{mid} + 1, r) 的答案。

因为 DP 需要有序,所以不能先将两边处理完后再处理本层的影响,也不能归并排序。

注意 \text{mid} 的划分,因为我们整个序列一开始按照 a_i 排序,所以必须保证划分的两个区间中,左边的所有 a 要小于右边,所以要在一个使两边 a_i 不同的位置划分 \text{mid} 。

代码

#include <cstdio>
#include <algorithm>

// Upper bound on the number of triples; the static arrays below are sized from it.
const int MAXN = 50000;

// One element of the input: three integer keys (a, b, c) plus a pointer to the
// DP cell that holds this element's answer. Keeping a pointer instead of a value
// lets the DP cell stay attached to the element while the array is re-sorted.
struct Triple {
    int a, b, c, *dp;

    // Zero every field so a default-constructed Triple never carries
    // indeterminate keys or a wild dp pointer.
    Triple() : a(0), b(0), c(0), dp(NULL) {}

    // Builds a triple from its three keys and the address of its DP cell.
    Triple(int a, int b, int c, int *dp) : a(a), b(b), c(c), dp(dp) {}
} a[MAXN + 1];  // global storage for the triples; init() fills indices 1..n

// n: number of triples (read in main). f[i]: DP cell of the i-th generated triple;
// init() points each Triple::dp at one of these cells and main reports max(f[i] + 1).
int n, f[MAXN + 1];

// Sorts the three referenced integers in place so that a <= b <= c on return.
// Duplicated values are handled correctly (e.g. (2, 1, 2) becomes (1, 2, 2)).
inline void sort3(int &a, int &b, int &c) {
    const int lo = std::min(a, std::min(b, c));
    const int hi = std::max(a, std::max(b, c));
    // Median of three using only comparisons: no "is it neither min nor max"
    // test (which breaks on ties) and no a + b + c sum (which could overflow).
    const int mid = std::max(std::min(a, b), std::min(std::max(a, b), c));

    a = lo, b = mid, c = hi;
}

inline void init(int x, int p) {
    static long long exp[MAXN * 3 + 1];
    exp[1] = x;
    for (int i = 2; i <= 3 * n; i++) {
        exp[i] = exp[i - 1] * x % p;
    }

    for (int i = 1; i <= n; i++) {
        int x = exp[i * 3 - 2], y = exp[i * 3 - 1], z = exp[i * 3];
        sort3(x, y, z);
        a[i] = Triple(x, y, z, &f[i]);
    }
}

inline void discrete() {
    static int set[MAXN + 1];
    for (int i = 1; i <= n; i++) set[i] = a[i].c;
    std::sort(set + 1, set + n + 1);
    int *end = std::unique(set + 1, set + n + 1);
    for (int i = 1; i <= n; i++) a[i].c = std::lower_bound(set + 1, end, a[i].c) - set;
}

// Binary indexed tree over positions 1..n that answers prefix-maximum queries.
// An entry of 0 means "nothing stored".
struct BIT {
    int a[MAXN + 1];

    // Lowest set bit of x.
    static int lowbit(int x) {
        return x & -x;
    }

    // Raises every tree node covering `pos` to at least x.
    void update(int pos, int x) {
        int i = pos;
        while (i <= n) {
            if (a[i] < x) a[i] = x;
            i += lowbit(i);
        }
    }

    // Returns the maximum stored over positions 1..pos (0 if pos <= 0 or nothing is stored).
    int query(int pos) {
        int best = 0;
        for (; pos > 0; pos -= lowbit(pos)) {
            if (best < a[pos]) best = a[pos];
        }
        return best;
    }

    // Resets the nodes covering `pos` to 0, stopping at the first node that is
    // already 0 (the walk ends there, so nodes further up are left untouched).
    void clear(int pos) {
        for (int i = pos; i <= n && a[i]; i += lowbit(i)) {
            a[i] = 0;
        }
    }
} bit;

inline bool compareByA(const Triple &a, const Triple &b) {
    return a.a < b.a;
}

inline bool compareByB(const Triple &a, const Triple &b) {
    return a.b < b.b;
}

// Debug helper: dumps every triple in [l, r] (inclusive) with its index in the
// global array and its current DP value, and marks the arithmetic midpoint.
inline void print(Triple *l, Triple *r) {
    Triple *const middle = l + (r - l) / 2;
    for (Triple *p = l; p <= r; p++) {
        // p - a is a ptrdiff_t; cast it to long so it matches the %ld conversion.
        printf("%ld: (%d, %d, %d, f = %d)\n", static_cast<long>(p - a), p->a, p->b, p->c, *p->dp);
        if (p == middle) puts("[mid]");
    }
    puts("");
}

// Chooses the split point for the range [l, r] (inclusive, sorted by a ascending).
//
// Returns a pointer `mid` such that every element of [l, mid] has a strictly
// smaller a than every element of [mid + 1, r], picking the boundary closest to
// the arithmetic middle (looking right first, then left). Returns NULL when no
// such boundary exists, i.e. all elements of the range share the same a (or the
// range is empty).
inline Triple *divide(Triple *l, Triple *r) {
    Triple *const middle = l + (r - l) / 2;

    // Look to the right of the middle for the end of the current equal-a run.
    Triple *mid = middle;
    while (mid < r && mid->a == (mid + 1)->a) mid++;
    if (mid < r) return mid;

    // Everything from the middle to r shares one a; look left for the start of that run.
    mid = middle;
    while (mid > l && mid->a == (mid - 1)->a) mid--;
    if (mid == l) return NULL;

    // `mid` is the first element of the equal-a run, so it belongs to the right
    // half: the left half must end just before it.
    return mid - 1;
}

// CDQ divide-and-conquer DP over the range [l, r] (inclusive).
//
// Precondition: [l, r] is sorted by a ascending. On return, *dp of every element
// in the range holds the largest DP value reachable from predecessors inside the
// range (a predecessor must be strictly smaller in a, b and c).
//
// Order matters: the left half is solved first, then its finished DP values are
// pushed into the right half, and only then is the right half solved.
inline void cdq(Triple *l, Triple *r) {
    if (l == r) {
        return;
    }

    Triple *mid = divide(l, r);
    if (!mid) return;  // all a equal: no element can precede another

    cdq(l, mid);

    std::sort(l, mid + 1, compareByB);
    std::sort(mid + 1, r + 1, compareByB);

    // print(l, r);

    // Sweep the right half in increasing b. Before answering an element, insert
    // every left element whose b is STRICTLY smaller; the query on c - 1 then
    // enforces the strict condition on c.
    Triple *left = l;
    for (Triple *right = mid + 1; right <= r; right++) {
        while (left <= mid && left->b < right->b) {
            bit.update(left->c, *left->dp + 1);
            left++;
        }
        *right->dp = std::max(*right->dp, bit.query(right->c - 1));
    }

    // Undo only the insertions that were actually made, leaving the tree empty
    // for the next call.
    for (Triple *q = l; q < left; q++) {
        bit.clear(q->c);
    }

    // Restore the a-order of the right half only; the left half is finished and
    // must not be mixed into it.
    std::sort(mid + 1, r + 1, compareByA);

    cdq(mid + 1, r);
}

/*
inline void force() {
    for (int i = 1; i <= n; i++) {
        for (int j = 0; j < i; j++) {
            if (a[j].a < a[i].a && a[j].b < a[i].b && a[j].c < a[i].c) f[i] = std::max(f[i], f[j] + 1);
        }
        // printf("f[%d] = %d\n", i, f[i]);
    }
}
*/

int main() {
    int x, p;
    scanf("%d %d %d", &x, &p, &n);
    init(x, p);

    // std::random_shuffle(a + 1, a + n + 1);
    std::sort(a + 1, a + n + 1, compareByA);

    discrete();

    cdq(a + 1, a + n);

    // force();

    int ans = 0;
    for (int i = 1; i <= n; i++) ans = std::max(ans, f[i] + 1); //, printf("(%d, %d, %d)\n", a[i].a, a[i].b, a[i].c);
    printf("%d\n", ans);

    return 0;
}