「BZOJ 3674」可持久化并查集加强版 - 可持久化线段树

n 个集合, m 个操作:

  • 1 a b 合并 a b 所在集合;
  • 2 k 回到第 k 次操作之后的状态(查询算作操作);
  • 3 a b 询问 a b 是否属于同一集合,是则输出 1 否则输出 0。

强制在线:每次操作读入的 a、b、k 都需要异或上一次询问的答案 lastAns(初始为 0)后才是真实值。

链接

BZOJ 3674

题解

用可持久化线段树实现可持久化数组,分别维护每个元素的父亲 fa 和秩 rank;在此基础上用按秩合并(不做路径压缩)实现可持久化并查集。

注意版本之间的映射关系:操作编号 → 并查集版本 → fa / rank 两棵线段树各自的版本。rank 并不是每次合并都会被修改,所以两棵线段树的版本号并不同步,需要分别记录。

代码

#include <cstdio>
#include <cassert>
#include <climits>
#include <algorithm>
#include <new>

// Compile-time capacities for the statically sized arrays in this file.
// MAXM bounds the per-operation arrays (indexed 0..m); MAXN is used when
// sizing the segment-tree node pool (presumably the largest element count n).
const int MAXN = 200000;
const int MAXM = 200000;

/**
 * Persistent int array over the index range [l, r] chosen by init(),
 * implemented as a path-copying segment tree.
 *
 * Version 0 is the empty array: every position reads as 0. Each update()
 * creates the next version (1, 2, ...) from any existing version by copying
 * only the root-to-leaf path of the written position; every other subtree is
 * shared with the source version, so older versions stay readable.
 *
 * Nodes are taken from a fixed pool and are never released.
 */
struct PSegT
{
    /**
     * Tree node. The index range a node covers is not stored in the node; it
     * is recomputed while descending from the root. That keeps a node down to
     * two pointers and an int, which matters because the pool holds millions
     * of them.
     */
    struct Node
    {
        Node *lc, *rc;
        int val; // payload of a leaf; 0 in internal nodes and in the sentinel

        Node() {}
        Node(Node *lc, Node *rc, int val) : lc(lc), rc(rc), val(val) {}
    };

    enum
    {
        // Upper bound on the nodes copied by one update: a root-to-leaf path
        // has at most ceil(log2(size)) + 1 nodes, which is <= 20 as long as
        // the range holds at most 2^19 positions (checked in init()).
        PATH_LEN = 20,

        // Enough for MAXN + MAXM updates. The union-find in this file writes
        // each element once during initialisation and then at most once per
        // merge into each tree, so MAXN * 30 nodes could be exceeded there.
        POOL_SIZE = (MAXN + MAXM) * PATH_LEN,

        // Highest version number that roots[] can hold.
        MAX_VERSIONS = MAXM * 10
    };

    // roots[t] is the root of version t; `null` points at `_nullNode`, the
    // shared sentinel that stands for "subtree never written".
    Node *roots[MAX_VERSIONS + 1], *null, _nullNode, _pool[POOL_SIZE], *_cur;

    int time, l, r;

    PSegT()
        : null(&_nullNode), _nullNode(&_nullNode, &_nullNode, 0),
          _cur(_pool), time(0), l(0), r(-1)
    {
        // The sentinel's children point back at itself, so descending through
        // an unwritten subtree never dereferences a null pointer.
        roots[0] = null;
    }

    /** Takes the next free pool slot and constructs a node in it. */
    Node *alloc(Node *lc, Node *rc, int val)
    {
        // Fail loudly instead of writing past the end of the pool.
        assert(_cur != _pool + POOL_SIZE);
        return new (_cur++) Node(lc, rc, val);
    }

    /**
     * Returns the root of a copy of subtree `v` (covering [l, r]) in which
     * position `pos` holds `val`. If `pos` lies outside [l, r] nothing is
     * copied and `v` itself is returned, i.e. the subtree is shared.
     */
    Node *insert(Node *v, int l, int r, int pos, int val)
    {
        if (pos < l || pos > r)
        {
            return v;
        }

        if (l == r)
        {
            return alloc(null, null, val);
        }

        int mid = l + (r - l) / 2;
        // Build both children before the parent so the allocation order is
        // fixed rather than left to the compiler's evaluation order.
        Node *left = insert(v->lc, l, mid, pos, val);
        Node *right = insert(v->rc, mid + 1, r, pos, val);
        return alloc(left, right, 0);
    }

    /**
     * Resets the structure to a single empty version 0 over [l, r] and makes
     * the whole pool available again. An empty range (l > r) is allowed.
     */
    void init(int l, int r)
    {
        // PATH_LEN is only a valid bound for ranges of at most 2^19 positions.
        assert(r - l < (1 << (PATH_LEN - 1)));

        this->l = l;
        this->r = r;
        time = 0;

        roots[0] = null;

        _cur = _pool;
    }

    /**
     * Creates version getTime() + 1 as a copy of version `fromTime` with
     * position `pos` set to `val`. A `pos` outside [l, r] still creates a new
     * version, which then shares its root with `fromTime`.
     */
    void update(int fromTime, int pos, int val)
    {
        assert(time < MAX_VERSIONS);
        Node *root = insert(roots[fromTime], l, r, pos, val);
        roots[++time] = root;
    }

    /**
     * Returns the value stored at `pos` in version `fromTime`, or 0 if the
     * position was never written in that version's history or lies outside
     * [l, r].
     */
    int query(int fromTime, int pos)
    {
        if (pos < l || pos > r)
        {
            return 0;
        }

        // Follow the single path towards `pos`; stop early at the sentinel,
        // whose val is 0.
        Node *v = roots[fromTime];
        int lo = l, hi = r;
        while (v != null && lo < hi)
        {
            int mid = lo + (hi - lo) / 2;
            if (pos <= mid)
            {
                v = v->lc;
                hi = mid;
            }
            else
            {
                v = v->rc;
                lo = mid + 1;
            }
        }

        return v->val;
    }

    /** Returns the number of the most recently created version. */
    int getTime()
    {
        return time;
    }
};

struct UFS
{
    PSegT fa, rank;
    int timeFa[MAXM + 1], timeRank[MAXM + 1];
    int time;

    void init(int n)
    {
        fa.init(1, n);
        rank.init(1, n);

        for (int i = 1; i <= n; i++)
        {
            int t = fa.getTime();
            fa.update(t, i, i);
            rank.update(t, i, 1);
        }

        time = 0;
        timeFa[time] = fa.getTime();
        timeRank[time] = rank.getTime();
    }

    int find(int fromTime, int x)
    {
        int tmp = fa.query(timeFa[fromTime], x);
        if (tmp == x) return x;
        else return find(fromTime, tmp);
    }

    int merge(int fromTime, int x, int y)
    {
        time++;

        int a = find(fromTime, x), b = find(fromTime, y);
        int ra = rank.query(timeRank[fromTime], a), rb = rank.query(timeRank[fromTime], b);
        if (ra > rb)
        {
            std::swap(a, b);
        }

        if (ra == rb)
        {
            rank.update(timeRank[fromTime], b, rb + 1);
        }

        fa.update(timeFa[fromTime], a, b);

        timeFa[time] = fa.getTime();
        timeRank[time] = rank.getTime();

        return time;
    }

    int getTime()
    {
        return time;
    }
} ufs;

int main()
{
    int n, m;
    scanf("%d %d", &n, &m);

    static int time[MAXM + 1];
    ufs.init(n);
    time[0] = ufs.getTime();

    int lastAns = 0;
    for (int i = 1; i <= m; i++)
    {
        int c, a;
        scanf("%d %d", &c, &a);
        a ^= lastAns;
        if (c == 1)
        {
            int b;
            scanf("%d", &b);
            b ^= lastAns;

            time[i] = ufs.merge(time[i - 1], a, b);
        }
        else if (c == 2)
        {
            time[i] = time[a];
        }
        else
        {
            int b;
            scanf("%d", &b);
            b ^= lastAns;

            int fa = ufs.find(time[i - 1], a), fb = ufs.find(time[i - 1], b);

            int res = fa == fb;
            printf("%d\n", res);

            lastAns = res;

            time[i] = time[i - 1];
        }
    }

    return 0;
}