「BZOJ 4499」线性函数 - 线段树

小 C 最近在学习线性函数,线性函数可以表示为 f(x) = kx + b 。现在小 C 面前有 n 个线性函数 f_i(x) = k_ix+b_i ,他对这 n 个线性函数执行 m 次操作,每次可以:

  1. M i K B 代表把第 i 个线性函数改为 f_i(x) = kx + b
  2. Q l r x 返回 f_r(f_{r-1}(\ldots f_l(x))) \bmod 10 ^ 9 + 7

链接

BZOJ 4499

题解

两个线性函数 f_1(x) = k_1 x + b_1 f_2(x) = k_2 x + b_2 可以合并为 f(x) = f_2(f_1(x)) = k_2(k_1 x + b_1) + b_2 = k_1 k_2 x + k_2 b_1 + b_2

用线段树维护整个序列即可。

代码

#include <cstdio>

const int MAXN = 200000;
const long long MOD = 1e9 + 7;

struct LinearFunction {
    long long k, b;

    LinearFunction() {}
    LinearFunction(long long k, long long b) : k(k), b(b) {}

    static LinearFunction merge(const LinearFunction &l, const LinearFunction &r) {
        return LinearFunction(l.k * r.k % MOD, (r.k * l.b % MOD + r.b) % MOD);
    }

    long long operator()(long long x) {
        return (k * x % MOD + b) % MOD;
    }
};

struct SegT {
    int l, r, mid;
    SegT *lc, *rc;
    LinearFunction f;

    SegT(int l, int r, SegT *lc, SegT *rc, const LinearFunction &f) : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), f(f) {}
    SegT(int l, int r, SegT *lc, SegT *rc) : l(l), r(r), mid(l + (r - l) / 2), lc(lc), rc(rc), f(LinearFunction::merge(lc->f, rc->f)) {}

    void update(int pos, const LinearFunction &f) {
        if (l == r) this->f = f;
        else (pos <= mid ? lc : rc)->update(pos, f), this->f = LinearFunction::merge(lc->f, rc->f);
    }

    LinearFunction query(int l, int r) {
        if (l <= this->l && r >= this->r) return f;
        else if (r <= mid) return lc->query(l, r);
        else if (l > mid) return rc->query(l, r);
        else return LinearFunction::merge(lc->query(l, r), rc->query(l, r));
    }

    static SegT *build(int l, int r, LinearFunction *a) {
        if (l == r) return new SegT(l, r, NULL, NULL, a[l]);
        else {
            int mid = l + (r - l) / 2;
            return new SegT(l, r, build(l, mid, a), build(mid + 1, r, a));
        }
    }
} *seg;

int main() {
    int n, m;
    scanf("%d %d", &n, &m);

    static LinearFunction a[MAXN + 1];
    for (int i = 1; i <= n; i++) scanf("%lld %lld", &a[i].k, &a[i].b);

    seg = SegT::build(1, n, a);

    while (m--) {
        char s[2];
        scanf("%s", s);
        if (s[0] == 'M') {
            int i;
            LinearFunction f;
            scanf("%d %lld %lld", &i, &f.k, &f.b);
            seg->update(i, f);
        } else {
            int l, r;
            long long x;
            scanf("%d %d %lld", &l, &r, &x);
            LinearFunction f = seg->query(l, r);
            printf("%lld\n", f(x));
        }
    }

    return 0;
}