设 表示 的二进制表示中 的数量,求满足 的数对 的数量。
链接
题解
数位 DP,设
表示二进制表示的最后 位,之前各位组成的数的差()为 ,是否 或 ,、 的之前所有位是否均达到上界的合法数对数量。
转移时枚举 、 的下一位分别是 或者 即可。
注意,不需要考虑 的情况,因为确定了二进制较高位满足一大一小后,较低位不会使其大小关系更改,即这种状态不可能合法。
代码
#include <cstdio>
#include <cstring>
#include <algorithm>
const int MAXT = 10;
const int MAXN = 300;
const int MAXN_BIT = 997; // 996.5784284662087
const int MOD = 998244353;
const int FLAG = 2;
int n;
bool a[MAXN_BIT];
int mem[MAXN_BIT][(MAXN_BIT + 1) * 2][FLAG][FLAG][FLAG];
// bool calced[MAXN_BIT][(MAXN_BIT + 1) * 2][FLAG][FLAG][FLAG];
bool i[MAXN_BIT], j[MAXN_BIT];
/*
inline void _set(const int n, const bool i, const bool j) {
::i[::n - n] = i;
::j[::n - n] = j;
}
inline void _print() {
for (int i = 0; i < n; i++) putchar(::i[i] ? '1' : '0');
putchar('\n');
for (int i = 0; i < n; i++) putchar(::j[i] ? '1' : '0');
putchar('\n');
putchar('\n');
}
*/
inline int dp(const int n, const int gap, const bool iLessThanJ, const bool limited1, const bool limited2) {
// gap : (A[i] - A[j])
int &ans = mem[n][gap + MAXN_BIT][iLessThanJ][limited1][limited2];
if (ans != -1) return ans;
if (n == 0) {
if (iLessThanJ && gap > 0) ans = 1; // , _print();
else ans = 0;
} else {
ans = 0;
int limit1, limit2;
bool &next = a[::n - n];
if (limited1) limit1 = next;
else limit1 = 1;
if (limited2) limit2 = next;
else limit2 = 1;
// printf("next = %d\n", next);
// 0 0
// _set(n, 0, 0);
ans += dp(n - 1,
gap,
iLessThanJ,
limited1 && next == 0,
limited2 && next == 0
);
ans %= MOD;
// 0 1
if (limit2 == 1) {
// _set(n, 0, 1);
ans += dp(n - 1,
gap - 1,
true,
limited1 && next == 0,
limited2 && next == 1
);
ans %= MOD;
}
// 1 0
if (limit1 == 1 && iLessThanJ) {
// _set(n, 1, 0);
ans += dp(n - 1,
gap + 1,
true,
limited1 && next == 1,
limited2 && next == 0
);
ans %= MOD;
}
// 1 1
if (limit1 == 1 && limit2 == 1) {
// _set(n, 1, 1);
ans += dp(n - 1,
gap,
iLessThanJ,
limited1 && next == 1,
limited2 && next == 1
);
ans %= MOD;
}
}
// if (n == ::n - 1) printf("f[%d][%d][%s][%s][%s] = %d\n", n, gap, iLessThanJ ? "<" : "=", limited1 ? "*" : " ", limited2 ? "*" : " ", ans);
return ans;
}
inline int solve(const char *s) {
memset(mem, 0xff, sizeof(mem));
// memset(calced, 0, sizeof(calced));
n = 0;
int len = strlen(s);
static int num[MAXN];
for (int i = 0; i < len; i++) num[i] = s[i] - '0';
while (1) {
int r = 0;
bool allZero = true;
for (int i = 0; i < len; i++) {
if (num[i] != 0) allZero = false;
r = r * 10 + num[i];
num[i] = r / 2;
r %= 2;
}
if (allZero) break;
a[n++] = r;
}
std::reverse(a, a + n);
if (n == 0) return 0;
int ans = 0;
// 0 0
// _set(n, 0, 0);
ans += dp(
n - 1,
0,
false,
false,
false
);
ans %= MOD;
// 0 1
// _set(n, 0, 1);
ans += dp(
n - 1,
-1,
true,
false,
true
);
ans %= MOD;
// 1 1
// _set(n, 1, 1);
ans += dp(
n - 1,
0,
false,
true,
true
);
ans %= MOD;
return ans;
}
int main() {
int t;
scanf("%d", &t);
while (t--) {
static char s[MAXN + 1];
scanf("%s", s);
printf("%d\n", solve(s));
}
return 0;
}