「BZOJ 3156」防御准备 - 斜率优化 DP

我们定义战线为一条长度为 n 的序列,在这条战线上共设有 n 个检查点,从左到右依次标号为 1 n 。一个战线为合法战线当且仅当任意一个检查点可以通过安全检查。对于第 i 个点,通过安全检查的方法有两种,第一种是放置一个守卫塔,这将花费 c(i) 的费用,第二种方式是放置一个木偶,放置木偶的花费等于这个检查点右侧的第一个守卫塔到它的距离。第 n 个点只能放置守卫塔。求最小的战线花费值。

链接

BZOJ 3156

题解

将整个序列反转,放置木偶的花费等于这个检查点左侧的第一个守卫塔到它的距离,一号检查点必须放置守卫塔。

f(i) 表示前 i 个检查点通过检查的最小代价,枚举 j ,在 j + 1 位置放置一个守卫塔,之后一直到 i 的位置全部放置木偶。

f(i) = \min\limits_{j = 0} ^ {i - 1} \{ f(j) + c(j + 1) + \frac{(i - j - 1)(i - j)}{2} \}

斜率优化,考虑两个决策 a b a > b ),若 a b 优,则有

\begin{align*} f(a) + c(a + 1) + \frac{(i - a - 1)(i - a)}{2} &< f(b) + c(b + 1) + \frac{(i - b - 1)(i - b)}{2} \\ f(a) + c(a + 1) + \frac{i ^ 2}{2} + \frac{a ^ 2}{2} - ia - \frac{i}{2} + \frac{a}{2} &< f(b) + c(b + 1) + \frac{i ^ 2}{2} + \frac{b ^ 2}{2} - ib - \frac{i}{2} + \frac{b}{2} \\ { (f(a) + c(a + 1) + \frac{a ^ 2}{2} + \frac{a}{2}) - (f(b) + c(b + 1) + \frac{b ^ 2}{2} + \frac{b}{2}) \over a - b } &< i \\ \end{align*}

维护决策点,使斜率递增,最优决策取队首。时间复杂度为 O(n)

代码

#include <cstdio>
#include <climits>
#include <algorithm>

const int MAXN = 1000000;

int n;
long long c[MAXN + 1], f[MAXN + 1];

inline void prepare() {
    std::reverse(c + 1, c + n + 1);
}

/*
inline void force() {
    std::fill(f, f + n + 1, LLONG_MAX);
    f[0] = 0;
    for (int i = 1; i <= n; i++) {
        int _j = -1;
        for (int j = 0; j < i; j++) {
            if (f[i] > f[j] + c[j + 1] + (i - j - 1) * (i - j) / 2) {
                f[i] = f[j] + c[j + 1] + (i - j - 1) * (i - j) / 2;
                _j = j;
            }
        }

        printf("%d --> %d\n", i, _j);
    }
}
*/

template <typename T> inline T sqr(const T &x) { return x * x; }

inline double x(const int i) { return f[i] + c[i + 1] + sqr(static_cast<long long>(i)) * 0.5 + i * 0.5; }

inline double slope(const int a, const int b) {
    return double(x(a) - x(b))
         / double(a - b);
}

inline void dp() {
    std::fill(f, f + n + 1, LLONG_MAX);
    f[0] = 0;

    static int q[MAXN];
    int *l = q, *r = q;
    *r = 0;

    for (int i = 1; i <= n; i++) {
        while (l < r && slope(*(l + 1), *l) < i) l++;

        int &j = *l;
        f[i] = f[j] + c[j + 1] + (i - j - 1) * (i - j) / 2;
        // printf("%d --> %d\n", i, j);

        if (i < n) {
            while (l < r && slope(*r, *(r - 1)) > slope(i, *r)) r--;
            *++r = i;
        }
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) scanf("%lld", &c[i]);

    prepare();

    // force();
    dp();
    printf("%lld\n", f[n]);

    return 0;
}