「HDU 632」Rikka with Array - 数位 DP

A(x) 表示 x 的二进制表示中 1 的数量,求满足 1 \leq i \lt j \leq n,\ A(i) \gt A(j) 的数对 [i, j] 的数量。

链接

HDU 5632

题解

数位 DP,设

f[n][gap][iLessThanJ][lim1][lim2]

表示二进制表示的最后 n 位,之前各位组成的数的差( j - i )为 gap ,是否 i \lt j i = j i j 的之前所有位是否均达到上界的合法数对数量。

转移时枚举 i j 的下一位分别是 0 或者 1 即可。

注意,不需要考虑 i \gt j 的情况,因为确定了二进制较高位满足一大一小后,较低位不会使其大小关系更改,即这种状态不可能合法。

代码

#include <cstdio>
#include <cstring>
#include <algorithm>

const int MAXT = 10;
const int MAXN = 300;
const int MAXN_BIT = 997; // 996.5784284662087
const int MOD = 998244353;

const int FLAG = 2;

int n;
bool a[MAXN_BIT];

int mem[MAXN_BIT][(MAXN_BIT + 1) * 2][FLAG][FLAG][FLAG];
// bool calced[MAXN_BIT][(MAXN_BIT + 1) * 2][FLAG][FLAG][FLAG];

bool i[MAXN_BIT], j[MAXN_BIT];

/*
inline void _set(const int n, const bool i, const bool j) {
    ::i[::n - n] = i;
    ::j[::n - n] = j;
}

inline void _print() {
    for (int i = 0; i < n; i++) putchar(::i[i] ? '1' : '0');
    putchar('\n');

    for (int i = 0; i < n; i++) putchar(::j[i] ? '1' : '0');
    putchar('\n');

    putchar('\n');
}
*/

inline int dp(const int n, const int gap, const bool iLessThanJ, const bool limited1, const bool limited2) {
    // gap : (A[i] - A[j])

    int &ans = mem[n][gap + MAXN_BIT][iLessThanJ][limited1][limited2];
    if (ans != -1) return ans;

    if (n == 0) {
        if (iLessThanJ && gap > 0) ans = 1; // , _print();
        else ans = 0;
    } else {
        ans = 0;

        int limit1, limit2;
        bool &next = a[::n - n];

        if (limited1) limit1 = next;
        else limit1 = 1;

        if (limited2) limit2 = next;
        else limit2 = 1;

        // printf("next = %d\n", next);

        // 0 0
        // _set(n, 0, 0);
        ans += dp(n - 1,
                  gap,
                  iLessThanJ,
                  limited1 && next == 0,
                  limited2 && next == 0
        );
        ans %= MOD;

        // 0 1
        if (limit2 == 1) {
            // _set(n, 0, 1);
            ans += dp(n - 1,
                      gap - 1,
                      true,
                      limited1 && next == 0,
                      limited2 && next == 1
            );
            ans %= MOD;
        }


        // 1 0
        if (limit1 == 1 && iLessThanJ) {
            // _set(n, 1, 0);
            ans += dp(n - 1,
                      gap + 1,
                      true,
                      limited1 && next == 1,
                      limited2 && next == 0
            );
            ans %= MOD;
        }


        // 1 1
        if (limit1 == 1 && limit2 == 1) {
            // _set(n, 1, 1);
            ans += dp(n - 1,
                      gap,
                      iLessThanJ,
                      limited1 && next == 1,
                      limited2 && next == 1
            );
            ans %= MOD;
        }
    }

    // if (n == ::n - 1) printf("f[%d][%d][%s][%s][%s] = %d\n", n, gap, iLessThanJ ? "<" : "=", limited1 ? "*" : " ", limited2 ? "*" : " ", ans);
    return ans;
}

inline int solve(const char *s) {
    memset(mem, 0xff, sizeof(mem));
    // memset(calced, 0, sizeof(calced));
    n = 0;

    int len = strlen(s);
    static int num[MAXN];
    for (int i = 0; i < len; i++) num[i] = s[i] - '0';

    while (1) {
        int r = 0;
        bool allZero = true;
        for (int i = 0; i < len; i++) {
            if (num[i] != 0) allZero = false;
            r = r * 10 + num[i];
            num[i] = r / 2;
            r %= 2;
        }

        if (allZero) break;

        a[n++] = r;
    }

    std::reverse(a, a + n);

    if (n == 0) return 0;

    int ans = 0;

    // 0 0
    // _set(n, 0, 0);
    ans += dp(
        n - 1,
        0,
        false,
        false,
        false
    );
    ans %= MOD;

    // 0 1
    // _set(n, 0, 1);
    ans += dp(
        n - 1,
        -1,
        true,
        false,
        true
    );
    ans %= MOD;

    // 1 1
    // _set(n, 1, 1);
    ans += dp(
        n - 1,
        0,
        false,
        true,
        true
    );
    ans %= MOD;

    return ans;
}

int main() {
    int t;
    scanf("%d", &t);
    while (t--) {
        static char s[MAXN + 1];
        scanf("%s", s);
        printf("%d\n", solve(s));
    }

    return 0;
}