主席树是一种数据结构,其主要应用是区间第 大问题。
权值线段树
传统的线段树用于维护一条线段上的区间,可以方便地查询区间信息。而如果将线段树转化为『权值线段树』,每个叶子节点存储某个元素出现次数,一条线段的总和表示区间内所有数出现次数的总和。
利用权值线段树可以方便地求出整体第 大 —— 从根节点向下走,如果 小于等于左子树大小,说明第 大在左子树的区间中,在左子树中继续查找即可;否则,说明第 大在右子树的区间中,此时将 减去左子树大小,并在右子树中继续查找。
查找过程类似平衡树,时间复杂度为 。
前缀和
上述算法可以用来处理整个序列上的第 大,而我们可以对于一个长度为 的序列 建立 棵上述的权值线段树,第 棵表示『 ~ 的所有数』组成的权值线段树。如果要查询 中的第 大,可以使用第 棵线段树减去第 棵线段树,得到整个区间组成的权值线段树,并进行上述算法得到区间中的第 大。
这个算法存在两个问题:
- 每个线段树要占用 的空间,算法的空间复杂度为 ,占用空间过多;
- 建立每棵线段树至少要用 的时间,每次查询又要用 的时间构建区间的权值线段树,总时间复杂度 。
看上去还不如每次直接提取出区间,并使用后线性选择得到答案的 的朴素算法优秀。
主席树
仔细思考,发现上述算法的 棵线段树中,相邻的两棵线段树仅有 个节点不同,因此本质不同的节点只有 个。我们可以充分利用这一特点,每次只重新创建与上次所不同的节点,相同的节点直接使用前一棵的即可。
为了节省空间,可以将第 棵线段树置为空,每次插入一个新叶子节点时接入一条长度为 的链。总空间、时间复杂度仍为
查询时构造整棵线段树,需要构造 个节点,但每次查询只会用到 个节点,直接动态构造这些节点即可。为了方便,可以不显式构造这些节点,而是直接用两棵线段树上的值相减。
模板
POJ 2104
动态分配内存会超时,需要静态分配内存。
#include <cstdio>
#include <climits>
#include <algorithm>
#include <new>
const int MAXN = 100000;
const int MAXM = 5000;
template <size_t SIZE>
struct MemoryPool {
char buf[SIZE], *cur;
MemoryPool() : cur(buf) {}
void *alloc(const int size) {
if (cur == buf + SIZE) return malloc(size);
else {
char *p = cur;
cur += size;
return p;
}
}
};
MemoryPool<(4 + 4 + 8 + 8 + 4) * MAXN * 10> pool;
struct ChairmanTree {
struct Node {
int l, r;
Node *lc, *rc;
int cnt;
Node(const int l, const int r, Node *lc = NULL, Node *rc = NULL) : l(l), r(r), lc(lc), rc(rc), cnt((lc ? lc->cnt : 0) + (rc ? rc->cnt : 0)) {}
Node(const int l, const int r, const int cnt) : l(l), r(r), lc(NULL), rc(NULL), cnt(cnt) {}
void pushDown() {
if (lc && rc) return;
int mid = l + ((r - l) >> 1);
if (!lc) lc = new (pool.alloc(sizeof(Node))) Node(l, mid);
if (!rc) rc = new (pool.alloc(sizeof(Node))) Node(mid + 1, r);
}
Node *insert(const int num) {
if (num < l || num > r) return this;
else if (num == l && num == r) return new (pool.alloc(sizeof(Node))) Node(l, r, this->cnt + 1);
else {
const int mid = l + ((r - l) >> 1);
pushDown();
if (num <= mid) return new (pool.alloc(sizeof(Node))) Node(l, r, lc->insert(num), rc);
else return new (pool.alloc(sizeof(Node))) Node(l, r, lc, rc->insert(num));
}
}
int rank() const {
return lc ? lc->cnt : 0;
}
} *root[MAXN + 1];
int n;
void build(const int a[], const int n) {
this->n = n;
root[0] = new (pool.alloc(sizeof(Node))) Node(0, n - 1);
for (int i = 1; i <= n; i++) {
root[i] = root[i - 1]->insert(a[i - 1]);
}
}
int query(const int l, const int r, int k) {
Node *L = root[l - 1], *R = root[r];
int min = 0, max = n - 1;
while (min != max) {
L->pushDown(), R->pushDown();
int mid = min + ((max - min) >> 1), t = R->rank() - L->rank();
if (k <= t) L = L->lc, R = R->lc, max = mid;
else k -= t, L = L->rc, R = R->rc, min = mid + 1;
}
return min;
}
} t;
int n, m, a[MAXN];
int main() {
scanf("%d %d", &n, &m);
for (int i = 0; i < n; i++) scanf("%d", &a[i]);
static int set[MAXN];
std::copy(a, a + n, set);
std::sort(set, set + n);
int *end = std::unique(set, set + n);
for (int i = 0; i < n; i++) a[i] = std::lower_bound(set, end, a[i]) - set;
t.build(a, n);
for (int i = 0; i < m; i++) {
int l, r, k;
scanf("%d %d %d", &l, &r, &k);
int ans = t.query(l, r, k);
printf("%d\n", set[ans]);
}
return 0;
}