「SDOI2016」储能表 - 二进制

有一个 列的表格,行从 编号,列从 编号。 每个格子都储存着能量。最初,第 行第 列的格子储存着 点能量。所以,整个表格储存的总能量是,

随着时间的推移,格子中的能量会渐渐减少。一个时间单位,每个格子中的能量都会减少 。显然,一个格子的能量减少到 之后就不会再减少了。 也就是说, 个时间单位后,整个表格储存的总能量是,

给出一个表格,求 个时间单位后它储存的总能量。 由于总能量可能较大,输出时对 取模。

链接

COGS 2220
BZOJ 4513

题解

正解是数位 DP …… 这里讲一种乱搞做法 ……

考虑异或的性质:

性质一:对于任意 ,必有

证明:之前两个数的从低到高第 位均为 ,现在均为 ,异或后结果不变。

性质二:对于任意 ,必有

证明:反证法,假设 ,则有 ,即 ,与题设矛盾。

性质三:对于任意 的所有数与 的异或所得结果取遍 的所有数。

证明:显然,有性质 2 可知这些数互不相同,并且二进制最多有 位,不可能大于等于 ,即这 个互不相同的数都在 内。

,为方便阅读, 表示列数, 表示行数。

打表找规律,先从最简单的开始搞。不考虑 ,当 )时,结果为

可以看到,整个矩阵的每一行包含了 的所有数字。直接使用等差数列求和公式计算即可。

稍复杂的情况,设 。当 时,打表结果为

左上角的黑色部分可以直接规约到第一种情况。红色的部分中,参与异或运算的一个数多了一个二进制位,根据性质三,这一部分能取到 的所有数。绿色部分同理。

对于黄色部分,相当于去掉了 的最高位后的一个子问题,递归计算即可。

更复杂的情况,当 (因为 所以不可能存在 )时,打表结果为

由性质三得,左边黑色部分取遍了 ,可用等差数列求和公式直接计算,右边部分都大于等于 ,将它们同时减去 后即为 的情况,递归处理后为每个数加上 即可。

现在考虑 对结果的影响,我们在计算一个等差数列 时,前面所有 的项都会变成 ,后面所有项减去 ,相当于一个以 开始,长度为 的等差数列,代入公式即可。

对于最后一种情况的递归,需要在 中将 的最高位去掉。最后为每个数加上时,对每个数的增量减去 即可。

每次递归时,会去掉 二进制最高位上的 ,其他的计算都可以在常数时间内完成,总时间复杂度为

代码

#include <cstdio>
#include <algorithm>

const int MAXT = 5000;
const long long MAXN = 1e18;
const long long MAXM = 1e18;
const long long MAXK = 1e18;
const int MAXP = 1e9;

long long p = MAXP;

/*
template <typename T>
inline int bitsCount(const T &x) {
    for (int i = sizeof(T) * 8 - 1; i >= 0; i--) if (x & ((T)1 << i)) return i + 1;
    return 0;
}
*/

template <typename T>
inline void bitsPrint(const T &x) {
    for (int i = sizeof(T) * 8 - 1; i >= 0; i--)
        if (x & ((T)1 << i)) putchar('1');
        else putchar('0');
    putchar('\n');
}

template <typename T>
inline T lowbit(const T &x) { return x & -x; }

template <typename T>
inline int log2(T x) {
    int ans = 0;
    while (x >>= 1) ans++;
    return ans;
}

template <typename T>
inline T mul(T x, T y, const T &z = 1) {
    // (x * y) / z, z is 1 or 2;
    if (z == 2) {
        if (x & 1) y >>= 1;
        else if (y & 1) x >>= 1;
        else throw;
    }
    return (x % p) * (y % p) % p;
}

inline long long sumTimes(long long first, long long n, const long long k, const long long t) {
   //  printf("from %lld to %lld = (%lld + %lld) * %lld / 2, and %lld times\n", first, first + n - 1, first, (first + n - 1), n, t);
    first -= k;
    if (first < 1) n -= (1 - first), first = 1;
    // printf("from %lld to %lld = (%lld + %lld) * %lld / 2, and %lld times\n", first, first + n - 1, first, (first + n - 1), n, t);
    if (n <= 0) return 0;
    return mul(mul(first + (first + n - 1), n, 2ll), t);
}

long long solve(long long n, long long m, long long k) {
    // printf("solve(%lld, %lld, %lld)\n", n, m, k);
    if (n == 0 || m == 0) return 0;

    if (k < 0) k = 0;
    if (n < m) std::swap(n, m);

    if (n == m && lowbit(n) == n) {
        return sumTimes(1, n - 1, k, m);
    }

    int N = log2(n), M = log2(m);
    long long centerWidth = (1ll << N), centerHeight = (1ll << M);
    if (N == M) {
        long long rightWidth = n - centerWidth, rightHeight = centerHeight;
        long long bottomWidth = centerWidth, bottomHeight = m - centerHeight;

        long long rightSum = sumTimes(centerWidth, rightHeight, k, rightWidth);
        long long bottomSum = sumTimes(centerHeight, bottomWidth, k, bottomHeight);

        long long sideSum = solve(rightWidth, bottomHeight, k);
        long long centerSum = solve(bottomWidth, rightHeight, k);

        return ((rightSum + bottomSum) % p + (sideSum + centerSum) % p) % p;
    } else {
        long long leftWidth = (1ll << N), leftHeight = m;
        long long rightWidth = n - leftWidth, rightHeight = leftHeight;

        long long leftSum = sumTimes(0, leftWidth, k, leftHeight);
        long long rightSum = solve(rightWidth, rightHeight, k - leftWidth);

        if (leftWidth > k) {
            rightSum += mul(mul(leftWidth - k, m), n - leftWidth);
            rightSum %= p;
        }

        return (leftSum + rightSum) % p;
    }
}

int main() {
    freopen("menci_table.in", "r", stdin);
    freopen("menci_table.out", "w", stdout);

    int t;
    scanf("%d", &t);
    while (t--) {
        long long n, m, k;
        scanf("%lld %lld %lld %lld", &n, &m, &k, &p);
        printf("%lld\n", solve(n, m, k));
    }

    // long long n, m, k;
    // scanf("%lld %lld %lld", &n, &m, &k);
    // bitsPrint(n), bitsPrint(m);

    // printf("lowbit(%lld) = %lld\nbitsCount(%lld) = %d\n2 ^ bitsCount(%lld) = %d\n", n, lowbit(n), n, bitsCount(n), n, 1 << (bitsCount(n)));
    // bitsPrint((1 << bitsCount(n)) - 1);

    /*long long ans = 0;
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < m; j++) {
            int t = std::max((i ^ j) - k, 0ll);
            printf("%3d", t);
            ans += t;
        }
        putchar('\n');
    }*/

    // printf("ans = %lld\n", ans);

    // printf("%lld\n", solve(n, m, k));

    fclose(stdin);
    fclose(stdout);

    return 0;
}