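给定 $n$ 个三元组 $(a_i, b_i, c_i)$,定义 $(a_i, b_i, c_i) < (a_j, b_j, c_j)$ 当且仅当 $a_i < a_j$ 且 $b_i < b_j$ 且 $c_i < c_j$。求这 $n$ 个三元组(可任意重排)的最长上升子序列长度。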
链接
题解
CDQ 分治,对于一段区间 $[l, r]$,先递归计算好 $[l, mid]$ 的答案,然后处理前一半对后一半的 DP 值的贡献,然后递归计算 $[mid + 1, r]$ 的答案。
因为 DP 需要有序,所以不能先将两边处理完后再处理本层的影响,也不能归并排序。
注意 $mid$ 的划分,因为我们整个序列一开始按照 $a$ 排序,所以必须保证划分的两个区间中,左边的所有 $a$ 要小于右边,所以要在一个使两边 $a$ 不同的位置划分 $mid$。
代码
#include <cstdio>
#include <algorithm>
// Upper bound on the number of triples; the static arrays below are sized from it.
const int MAXN = 50000;
// One triple together with a pointer to the DP cell that belongs to it.
// Keeping a pointer (instead of an index) lets the DP value follow the element
// while the array is repeatedly re-sorted by different keys.
struct Triple {
    int a, b, c;  // the three keys; init() stores them in non-decreasing order
    int *dp;      // DP slot of this triple (init() points it into f[])

    // Zero-initialise every member so a default-constructed Triple never
    // carries indeterminate values, regardless of its storage duration.
    Triple() : a(0), b(0), c(0), dp(0) {}

    Triple(int a, int b, int c, int *dp) : a(a), b(b), c(c), dp(dp) {}
} a[MAXN + 1];  // triples are stored 1-based in a[1..n]
// Number of triples, read by main().
int n;

// DP cells. Each Triple points at one of them through its dp member; the cell
// holds the best chain length ending at that triple, not counting the triple
// itself (main() adds the final 1).
int f[MAXN + 1];
// Sorts three integers in place so that a <= b <= c on return.
//
// The middle value is computed with a median-of-three expression rather than
// by looking for "the argument that is neither the minimum nor the maximum":
// that search picks the wrong value when two arguments are equal (for example
// it turns (1, 2, 1) into (1, 2, 2)).
inline void sort3(int &a, int &b, int &c) {
    const int lo = std::min(a, std::min(b, c));
    const int hi = std::max(a, std::max(b, c));
    // median(a, b, c) = max(min(a, b), min(max(a, b), c))
    const int mid = std::max(std::min(a, b), std::min(std::max(a, b), c));
    a = lo;
    b = mid;
    c = hi;
}
// Builds the n triples a[1..n] from the sequence x^1, x^2, ..., x^(3n), every
// term taken modulo p.
//
// Triple i consists of terms 3i-2, 3i-1 and 3i, sorted ascending by sort3(),
// and is bound to the DP cell f[i].
//
// x: base of the power sequence.
// p: modulus; must be non-zero.
//
// The terms are produced by a running product, so the first term is reduced
// modulo p like all the others and no scratch array is needed.
inline void init(int x, int p) {
    long long power = 1;  // 64-bit so that power * x cannot overflow before the reduction
    for (int i = 1; i <= n; i++) {
        int side[3];
        for (int k = 0; k < 3; k++) {
            power = power * x % p;
            side[k] = static_cast<int>(power);
        }
        sort3(side[0], side[1], side[2]);
        a[i] = Triple(side[0], side[1], side[2], &f[i]);
    }
}
inline void discrete() {
static int set[MAXN + 1];
for (int i = 1; i <= n; i++) set[i] = a[i].c;
std::sort(set + 1, set + n + 1);
int *end = std::unique(set + 1, set + n + 1);
for (int i = 1; i <= n; i++) a[i].c = std::lower_bound(set + 1, end, a[i].c) - set;
}
// Binary indexed tree over positions 1..n that answers prefix-maximum queries.
// An empty tree is all zeros, so stored values are expected to be positive.
struct BIT {
    int a[MAXN + 1];

    // Lowest set bit of x.
    static int lowbit(int x) {
        return x & -x;
    }

    // Raises the value at position pos (1-based) to at least x.
    void update(int pos, int x) {
        while (pos <= n) {
            if (a[pos] < x) a[pos] = x;
            pos += lowbit(pos);
        }
    }

    // Returns the maximum over positions 1..pos, or 0 when pos <= 0 or nothing
    // has been stored there.
    int query(int pos) {
        int best = 0;
        while (pos > 0) {
            if (best < a[pos]) best = a[pos];
            pos -= lowbit(pos);
        }
        return best;
    }

    // Undoes the effect of earlier update(pos, ...) calls by zeroing the nodes
    // on pos's update path. Stops at the first node that is already zero:
    // with positive stored values, such a node was reset by a previous clear()
    // whose path merged into this one, so the remainder is already zero too.
    void clear(int pos) {
        while (pos <= n && a[pos] != 0) {
            a[pos] = 0;
            pos += lowbit(pos);
        }
    }
} bit;
// Strict-weak ordering on the first key: true when a's `a` field precedes b's.
inline bool compareByA(const Triple &a, const Triple &b) {
    return b.a > a.a;
}
// Strict-weak ordering on the second key: true when a's `b` field precedes b's.
inline bool compareByB(const Triple &a, const Triple &b) {
    return b.b > a.b;
}
// Debug helper: dumps the triples in [l, r] (both inclusive), one per line, as
// "index: (a, b, c, f = dp)", where index is the position in the global array.
// A "[mid]" marker follows the element at the arithmetic middle of the range
// (which is not necessarily the split point chosen by divide()).
inline void print(Triple *l, Triple *r) {
    Triple *const center = l + (r - l) / 2;
    for (Triple *p = l; p <= r; p++) {
        // p - a is a ptrdiff_t; convert it explicitly so it matches the %ld
        // conversion instead of being passed to an unsigned specifier.
        printf("%ld: (%d, %d, %d, f = %d)\n", static_cast<long>(p - a), p->a, p->b, p->c, *p->dp);
        if (p == center) puts("[mid]");
    }
    puts("");
}
// Chooses the split point for cdq() on the inclusive range [l, r], which must
// be sorted by the `a` key in non-decreasing order.
//
// Returns a pointer `mid` such that [l, mid] and [mid + 1, r] are both
// non-empty and mid->a != (mid + 1)->a, i.e. every `a` on the left is strictly
// smaller than every `a` on the right. The boundary closest to the middle on
// the right-hand side is preferred; if there is none, the closest one on the
// left-hand side is used. Returns NULL when all elements share the same `a`
// (or the range is empty), because no valid split exists.
inline Triple *divide(Triple *l, Triple *r) {
    Triple *const center = l + (r - l) / 2;

    // Look for a boundary at or to the right of the centre.
    Triple *mid = center;
    while (mid < r && mid->a == (mid + 1)->a) mid++;
    if (mid < r) return mid;

    // Everything from the centre up to r has the same `a`. Walk left to the
    // first element of that run.
    mid = center;
    while (mid > l && mid->a == (mid - 1)->a) mid--;
    if (mid == l) return NULL;

    // `mid` is the first element of the run, so the left half must end just
    // before it; returning `mid` itself would put equal keys on both sides.
    return mid - 1;
}
// CDQ divide-and-conquer over the inclusive range [l, r], which must be sorted
// by `a` on entry and whose `c` keys must be ranks in 1..n (see discrete()).
//
// After the call, the DP cell of every element x in the range holds the best
// value of (DP of y) + 1 over elements y in the range with y.a < x.a,
// y.b < x.b and y.c < x.c, combined with whatever the cell held before.
//
// The order is: solve the left half, push its contribution into the right
// half, and only then solve the right half, because the right half's DP values
// must already include the left half's contribution before they are used.
// The range is left in an unspecified order.
inline void cdq(Triple *l, Triple *r) {
    // Nothing to do for a single element or an empty/inverted range.
    if (l >= r) return;

    Triple *mid = divide(l, r);
    if (!mid) return;  // all `a` equal: no element can precede another

    cdq(l, mid);

    // Every left `a` is smaller than every right `a`, so within this step only
    // b and c matter. Sweep both halves in increasing b.
    std::sort(l, mid + 1, compareByB);
    std::sort(mid + 1, r + 1, compareByB);

    Triple *p1 = l, *p2 = mid + 1;
    while (p1 <= mid || p2 <= r) {
        // p2 > r is tested first so p2 is never dereferenced past the range.
        // A left element is inserted only while its b is strictly smaller than
        // the next right element's b, which keeps the b comparison strict.
        if (p2 > r || (p1 <= mid && p1->b < p2->b)) {
            bit.update(p1->c, *p1->dp + 1);
            p1++;
        } else {
            // Strict comparison on c: only ranks below p2->c are queried.
            *p2->dp = std::max(*p2->dp, bit.query(p2->c - 1));
            p2++;
        }
    }

    // Reset the BIT for the next use; every left element was inserted above.
    for (Triple *q = l; q <= mid; q++) {
        bit.clear(q->c);
    }

    // Restore the `a` order of the right half only, then solve it.
    std::sort(mid + 1, r + 1, compareByA);
    cdq(mid + 1, r);
}
/*
inline void force() {
for (int i = 1; i <= n; i++) {
for (int j = 0; j < i; j++) {
if (a[j].a < a[i].a && a[j].b < a[i].b && a[j].c < a[i].c) f[i] = std::max(f[i], f[j] + 1);
}
// printf("f[%d] = %d\n", i, f[i]);
}
}
*/
int main() {
int x, p;
scanf("%d %d %d", &x, &p, &n);
init(x, p);
// std::random_shuffle(a + 1, a + n + 1);
std::sort(a + 1, a + n + 1, compareByA);
discrete();
cdq(a + 1, a + n);
// force();
int ans = 0;
for (int i = 1; i <= n; i++) ans = std::max(ans, f[i] + 1); //, printf("(%d, %d, %d)\n", a[i].a, a[i].b, a[i].c);
printf("%d\n", ans);
return 0;
}