「省选模拟赛」扔鸡蛋 - DP

N 层楼,第 M 层以下扔鸡蛋会碎,你有 K 个鸡蛋,找出这个 M 需要多少次实验。

题解

f_{i, j} 表示用 i 个鸡蛋做 j 次实验最多能测试出多少层的楼,考虑第一个鸡蛋是否摔碎,即:

f_{i, j} = f_{i - 1, j - 1} + f_{i, j - 1} + 1

第二维大小取 10 ^ 5 即可,当询问 (N, K) 时,在 f_K 中二分查找 N 即可。

注意当 K = 1, 2, 3 时答案非常大,需要特判。

K = 1 时,答案为 N 。 当 K = 2 时,设答案为 x ,有:

N = \frac{x(x + 1)}{2} \\ x = \lceil \frac{\sqrt{8n + 1} - 1}{2} \rceil

K = 3 时,难以推出公式,我们可以再为 1 ~ 3 开一个 f 数组,第二维开到 2 \times 10 ^ 6 即可。

代码

#include <cstdio>
#include <cmath>
#include <algorithm>

const int MAXT = 100000;
const unsigned long long MAXN = 1e18;
const int LIM = 100000;
const int LIM2 = LIM * 20;
const int MAXK = 64;

template <typename T>
inline void read(T &x) {
    x = 0;
    register char ch;
    while (ch = getchar(), !(ch >= '0' && ch <= '9'));
    do x = x * 10 + (ch - '0'); while (ch = getchar(), (ch >= '0' && ch <= '9'));
}

template <typename T>
inline void write(T x) {
    char s[20];
    register int cnt = 0;
    do s[cnt++] = x % 10; while (x /= 10);
    while (cnt--) putchar(s[cnt] + '0');
    putchar('\n');
}

unsigned long long f[MAXK + 1][LIM + 1], f2[3 + 1][LIM2 + 1];
int cnt[MAXK + 1];

int main() {
    freopen("naive.in", "r", stdin);
    // freopen("naive.out", "w", stdout);

    for (int i = 1; i <= MAXK; i++) {
        cnt[i] = LIM;
        for (int j = 1; j <= LIM; j++) {
            f[i][j] = f[i - 1][j - 1] + f[i][j - 1] + 1;
            if (f[i][j] >= MAXN) {
                cnt[i] = j;
                break;
            }
        }
    }

    for (int i = 1; i <= 3; i++) {
        cnt[i] = LIM2;
        for (int j = 1; j <= LIM2; j++) {
            f2[i][j] = f2[i - 1][j - 1] + f2[i][j - 1] + 1;
            if (f2[i][j] >= MAXN) {
                cnt[i] = j;
                break;
            }
        }
    }

    // for (int i = 1; i <= 10; i++) write(cnt[i]);

    int t;
    scanf("%d", &t);
    while (t--) {
        unsigned long long n;
        int k;
        read(n), read(k);

        if (k == 1) write(n);
        else if (k == 2) write((long long)ceil((sqrt(8 * n + 1) - 1) / 2));
        else if (k == 3) {
            int ans = std::lower_bound(f2[k], f2[k] + cnt[k], n) - f2[k];
            printf("%d\n", ans);
        } else {
            int ans = std::lower_bound(f[k], f[k] + cnt[k], n) - f[k];
            printf("%d\n", ans);
        }
    }

    fclose(stdin);
    fclose(stdout);

    return 0;
}