「BZOJ 1756」小白逛公园 - 线段树

路的一边从南到北依次排着 n 个公园,一开始,小白就根据公园的风景给每个公园打了分。小新为了省事,每次遛狗的时候都会事先规定一个范围,小白只可以选择第 a 个和第 b 个公园之间(包括 ab 两个公园)选择连续的一些公园玩。小白当然希望选出的公园的分数总和尽量高咯。同时,由于一些公园的景观会有所改变,所以,小白的打分也可能会有一些变化。那么,就请你来帮小白选择公园吧。

链接

BZOJ 1756

题解

区间内最大连续和,还带修改,当然是线段树咯!

每个节点维护以下几项信息:

  1. 区间总和;
  2. 区间最大连续和;
  3. 强制包含左端点的最大连续和;
  4. 强制包含右端点的最大连续和。

然后使用动态规划的方式求出每个节点的四个值即可。

查询麻烦一点,如果跨左右子树查询的话,需要维护要查询的区间的以上四项值,然后用相似的方式向上传递。

合并两个区间时,需要注意细节。

还有就是读入 ab 时,有可能 ab 大!

代码

#include <cstdio>
#include <climits>
#include <algorithm>

const int MAXN = 500000;
const int MAXM = 100000;

struct SegmentTree {
    struct Node {
        Node *lchild, *rchild;
        int l, r;
        int value, sum, lsum, rsum, maxSum;

        Node(int l, int r, Node *lchild, Node *rchild) : l(l), r(r), lchild(lchild), rchild(rchild) {}

        void maintain() {
            sum = value;
            if (lchild) sum += lchild->sum;
            if (rchild) sum += rchild->sum;

            lsum = rsum = sum;
            if (lchild) {
                lsum = std::max(lsum, lchild->lsum);
                if (rchild) lsum = std::max(lsum, lchild->sum + rchild->lsum);
            }
            if (rchild) {
                rsum = std::max(rsum, rchild->rsum);
                if (lchild) rsum = std::max(rsum, rchild->sum + lchild->rsum);
            }

            maxSum = sum;
            maxSum = std::max(maxSum, lsum);
            maxSum = std::max(maxSum, rsum);
            if (lchild) maxSum = std::max(maxSum, lchild->maxSum);
            if (rchild) maxSum = std::max(maxSum, rchild->maxSum);
            if (lchild && rchild) maxSum = std::max(maxSum, lchild->rsum + rchild->lsum);
        }

        void query(int l, int r, int &sum, int &lsum, int &rsum, int &maxSum) {
            if (this->l > r || this->r < l) throw;
            else if (this->l >= l && this->r <= r) {
                sum = this->sum;
                lsum = this->lsum;
                rsum = this->rsum;
                maxSum = this->maxSum;
            } else {
                int mid = this->l + ((this->r - this->l) >> 1);
                if (r <= mid) return lchild->query(l, r, sum, lsum, rsum, maxSum);
                if (l >= mid + 1) return rchild->query(l, r, sum, lsum, rsum, maxSum);
                else {
                    int suml, lsuml, rsuml, maxSuml;
                    int sumr, lsumr, rsumr, maxSumr;
                    lchild->query(l, r, suml, lsuml, rsuml, maxSuml);
                    rchild->query(l, r, sumr, lsumr, rsumr, maxSumr);

                    maxSum = sum = suml + sumr;
                    lsum = std::max(lsuml, suml + lsumr);
                    rsum = std::max(rsumr, sumr + rsuml);
                    maxSum = std::max(maxSum, maxSuml);
                    maxSum = std::max(maxSum, maxSumr);
                    maxSum = std::max(maxSum, lsumr + rsuml);
                }
            }
        }

        void update(int pos, int value) {
            if (this->l > pos || this->r < pos) return;
            else if (this->l == pos && this->r == pos) this->value = value, maintain();
            else {
                if (lchild) lchild->update(pos, value);
                if (rchild) rchild->update(pos, value);
                maintain();
            }
        }
    } *root;

    SegmentTree(int l, int r) {
        root = build(l, r);
    }

    Node *build(int l, int r) {
        if (l > r) return NULL;
        else if (l == r) return new Node(l, r, NULL, NULL);
        else return new Node(l, r, build(l, l + ((r - l) >> 1)), build(l + ((r - l) >> 1) + 1, r));
    }

    void update(int pos, int value) {
        root->update(pos, value);
    }

    int query(int l, int r) {
        int sum, lsum, rsum, maxSum;
        root->query(l, r, sum, lsum, rsum, maxSum);
        return maxSum;
    }
};

int main() {
    int n, m;

    scanf("%d %d", &n, &m);

    SegmentTree *segment = new SegmentTree(1, n);
    for (int i = 1; i <= n; i++) {
        int x;
        scanf("%d", &x);
        segment->update(i, x);
    }

    for (int i = 0; i < m; i++) {
        int k;
        scanf("%d", &k);
        if (k == 1) {
            int l, r;
            scanf("%d %d", &l, &r);
            printf("%d\n", segment->query(std::min(l, r), std::max(l, r)));
        } else {
            int p, s;
            scanf("%d %d", &p, &s);
            segment->update(p, s);
        }
    }

    return 0;
}