「省选模拟赛」小奇的集合 - 矩阵乘法

有一个大小为 的可重集 ,小奇每次操作可以加入一个数 均属于 ),求 次操作后它可获得的 的和的最大值(数据保证这个值为非负数)。

题解

每次要取的都是最大的和次大的,考虑最大的和次大的都是非负数正数的情况,使用矩阵乘法( 为最大值, 为次大值, 为当前总和):

对于次大值为负的情况,先模拟,直到为正再用矩阵乘法。

代码

#include <cstdio>
#include <climits>

const int MAXN = 100000;
const int MAXK = 1000000000;
const int MOD = 10000007;

template<typename T>
struct Matrix {
    T a[3][3];

    Matrix(T x1 = 1, T x2 = 0, T x3 = 0, T x4 = 0, T x5 = 1, T x6 = 0, T x7 = 0, T x8 = 0, T x9 = 1) {
        a[0][0] = x1, a[0][1] = x2, a[0][2] = x3;
        a[1][0] = x4, a[1][1] = x5, a[1][2] = x6;
        a[2][0] = x7, a[2][1] = x8, a[2][2] = x9;
    }

    const T &operator()(int i, int j) const {
        return a[i][j];
    }

    T &operator()(int i, int j) {
        return a[i][j];
    }

    void print() {
        for (int i = 0; i < 3; i++) {
            printf("[ ");
            for (int j = 0; j < 3; j++) {
                printf("%lld ", a[i][j]);
            }
            printf("] ");
        }
        putchar('\n');
    }
};

template <typename T>
Matrix<T> operator*(const Matrix<T> &a, const Matrix<T> &b) {
    Matrix<T> r(
        0, 0, 0,
        0, 0, 0,
        0, 0, 0
    );
    for (int i = 0; i < 3; i++) {
        for (int j = 0; j < 3; j++) {
            for (int k = 0; k < 3; k++) {
                r(i, j) = (r(i, j) + (a(i, k) * b(k, j)) % MOD) % MOD;
            }
        }
    }
    // r.print();
    return r;
}

template <typename T>
Matrix<T> pow(Matrix<T> m, int x) {
    Matrix<T> ans;
    /*for (int i = 0; i < x - 1; i++) ans = ans * m;
    return ans;*/

    for (; x; x >>= 1, m = m * m){
        // printf("x = %d\n", x);
        if (x & 1) ans = ans * m;
    }
    return ans;
}

int max = INT_MIN, max2 = INT_MIN, sum;

int main() {
    // freopen("set.in", "r", stdin);
    // freopen("set.out", "w", stdout);

    int n, k;
    scanf("%d %d", &n, &k);

    for (int i = 0; i < n; i++) {
        int num;
        scanf("%d", &num);

        sum = (sum + num) % MOD;
        if (num > max) max2 = max, max = num;
        else if (num > max2) max2 = num;
    }

    while (max2 < 0) {
        k--;
        int num = max + max2;
        sum = (sum + num) % MOD;
        if (num > max) max2 = max, max = num;
        else if (num > max2) max2 = num;
    }

    Matrix<long long> base(
        1, 1, 0,
        1, 0, 0,
        1, 1, 1
    );

    Matrix<long long> init(
        max % MOD, 0, 0,
        max2 % MOD, 0, 0,
        sum % MOD, 0, 0
    );

    Matrix<long long> result = pow(base, k) * init;

    // printf("%d\n", (sum + (result(2, 0) * max) % MOD + (result(2, 1) * max2) % MOD) % MOD);
    printf("%d\n", (result(2, 0) + MOD) % MOD);

    fclose(stdin);
    fclose(stdout);

    return 0;
}