「HDU 3949」XOR - 线性基

给一组正整数,求其第 $ k $ 小的子集异或和。

链接

HDU 3949Virtual Judge

题解

首先,求出这个集合的线性基 $ B $,选择线性基的一个非空子集共有 $ 2 ^ {|B|} - 1 $ 种方案。如果 $ |B| < n $,则说明至少有一个没有被插入到线性基中的数可以被线性基中的数表示出来,选择线性基中的一些数与这个数,可以得到其异或和为 $ 0 $,这样有 $ 2 ^ {|B|} $ 种方案。

然后,考虑给出线性基,求选择若干(至少一个)数可以组成的第 $ k $ 小($ k $ 从 $ 0 $ 开始,第 $ 0 $ 小为 $ 0 $)的数。

将 $ k $ 表示为一个长度为 $ |B| $ 的二进制数(范围为 $ [0, |B|) $;如不足,可在高位补 $ 0 $)。 $ k $ 的二进制排列符合以下性质:

  1. 高位上的 $ 1 $ 比低位上的 $ 1 $ 更能使 $ k $ 更大;
  2. 低位上的 $ 1 $ 一定会使 $ k $ 更大。

线性基的 $ |B| $ 个元素控制了异或后结果的 $ |B| $ 个二进制位,而二进制数的规律恰好与从线性基中选数的两条规律相对应:

  1. 选择「控制较高位上的 $ 1 $ 的元素」更能使异或和更大;
  2. 选择「控制较高位上的 $ 1 $ 的元素」后,再选择「控制更低位上 $ 1 $ 的元素」一定会使异或和更大。

解法就比较显然了 —— 枚举 $ k $ 所有为 $ 1 $ 的二进制位,如果第 $ i $ 位为 $ 1 $,则将线性基中控制的第 $ i $ 小的二进制位的元素异或到答案中。

代码

#include <cstdio>
#include <vector>

const int MAXN = 100000;
const int MAXL = 50;

struct LinearBasis {
    std::vector<long long> v;
    int n;

    void build(long long *x, int n) {
        this->n = n;

        std::vector<long long> a(MAXL + 1);

        for (int i = 1; i <= n; i++) {
            long long t = x[i];

            for (int j = MAXL; j >= 0; j--) {
                if ((t & (1ll << j)) == 0) continue;

                if (a[j]) t ^= a[j];
                else {
                    for (int k = 0; k < j; k++) if (t & (1ll << k)) t ^= a[k];
                    for (int k = j + 1; k <= MAXL; k++) if (a[k] & (1ll << j)) a[k] ^= t;

                    a[j] = t;

                    break;
                }
            }

            /*
            printf("insert(%d):\n", t);
            print();
            */
        }

        v.clear();
        for (int i = 0; i <= MAXL; i++) if (a[i]) v.push_back(a[i]);
    }

    long long query(long long k) {
        if (int(v.size()) != n) {
            // 可能是 0
            k--;
        }

        // 如果 k 超过的所有不同异或和的数量
        if (k > (1ll << v.size()) - 1) return -1;

        long long ans = 0;
        for (size_t i = 0; i < v.size(); i++) {
            if (k & (1ll << i)) {
                ans ^= v[i];
            }
        }

        return ans;
    }

    /*
    void print() {
        for (int i = 0; i <= MAXL; i++) {
            for (int j = 0; j <= MAXL; j++) printf(j == MAXL ? "%d\n" : "%d", (a[i] & (1 << j)) ? 1 : 0);
        }
    }
    */
} lb;

int main() {
    int n;
    scanf("%d", &n);

    static long long a[MAXN + 1];
    for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);

    lb.build(a, n);

    int m;
    scanf("%d", &m);
    while (m--) {
        long long k;
        scanf("%lld", &k);
        printf("%lld\n", lb.query(k));
    }

    return 0;
}