「BZOJ 1706」乳牛接力跑 - 矩阵乘法

给一个图,求从 $s$ 点到 $t$ 点恰好经过 $n$ 步的最短路。

链接

BZOJ 1586
COGS 1470

题解

矩阵乘法可以求出恰好经过 $n$ 步的路径数,将矩阵乘法中的求和改为取较小值,相乘改为相加即可求出恰好经过 $n$ 步的最短路。

注意这时的单位矩阵,对角线上为 $0$,其它全部为正无穷。

代码

#include <cstdio>
#include <climits>
#include <algorithm>

// Not referenced by any code visible in this file.
const int MAXN = 1000;
// Capacity of the edge list read by main(); the endpoint array there holds 2 * MAXM entries.
const int MAXM = 100;
// Side length of Matrix. Equals 2 * MAXM, the number of endpoint slots main() collects
// before compressing vertex labels into matrix indices.
const int MAXN_SMALL = 200;
// Not referenced by any code visible in this file.
const int MAXK = 1000000;

// Fixed-size MAXN_SMALL x MAXN_SMALL table of int distances.
// INT_MAX is the "positive infinity" sentinel meaning "no value yet".
struct Matrix {
    int a[MAXN_SMALL][MAXN_SMALL];

    // Fills every cell with INT_MAX. When `unit` is true the main diagonal is
    // set to 0 instead, giving the identity for the min/plus product.
    Matrix(const bool unit = false) {
        for (int row = 0; row < MAXN_SMALL; row++) {
            for (int col = 0; col < MAXN_SMALL; col++) {
                a[row][col] = (unit && row == col) ? 0 : INT_MAX;
            }
        }
    }

    // Unchecked element access: m(i, j) is a[i][j].
    int &operator()(const int i, const int j) {
        return a[i][j];
    }

    const int &operator()(const int i, const int j) const {
        return a[i][j];
    }
};

// Min/plus matrix product: res(i, j) = min over k of a(i, k) + b(k, j).
// Cells equal to INT_MAX are treated as infinity and never take part in a sum.
//
// The sum is formed in long long so that two large finite entries cannot
// overflow int (signed overflow is undefined behaviour). A sum is stored only
// when it is smaller than the current cell, which is at most INT_MAX, so the
// stored value always fits in int; sums of INT_MAX or more stay "infinite".
// NOTE(review): assumes entries are not hugely negative (sum >= INT_MIN).
Matrix operator*(const Matrix &a, const Matrix &b) {
    Matrix res(false);
    // i-k-j order: a(i, k) is loop-invariant for the inner loop, so an
    // infinite left operand skips the whole row of b at once.
    for (int i = 0; i < MAXN_SMALL; i++) {
        for (int k = 0; k < MAXN_SMALL; k++) {
            const int left = a(i, k);
            if (left == INT_MAX) continue;
            for (int j = 0; j < MAXN_SMALL; j++) {
                const int right = b(k, j);
                if (right == INT_MAX) continue;
                const long long sum = static_cast<long long>(left) + right;
                if (sum < res(i, j)) res(i, j) = static_cast<int>(sum);
            }
        }
    }
    return res;
}

// Raises `a` to the n-th power under the min/plus product (operator*) by
// binary exponentiation. Returns the unit matrix (0 on the diagonal, INT_MAX
// elsewhere) when n <= 0.
//
// `a` is squared only while higher bits of n remain: the squaring after the
// final bit would be discarded, costs a full matrix product, and produces the
// largest intermediate sums. Testing n > 0 also guarantees termination for a
// negative n, where an arithmetic right shift would never reach zero.
Matrix pow(Matrix a, int n) {
    Matrix res(true);
    while (n > 0) {
        if (n & 1) res = res * a;
        n >>= 1;
        if (n > 0) a = a * a;
    }
    return res;
}

// Reads "relays.in": a header `n m s t` followed by m lines `w u v`
// (weight, then the two endpoints of an undirected edge). Writes to
// "relays.out" the minimum total weight of a walk from s to t that uses
// exactly n edges, or INT_MAX when no such walk exists.
//
// Returns 1 without producing an answer if the files cannot be opened or the
// input is malformed or exceeds the fixed array capacities.
int main() {
    if (!freopen("relays.in", "r", stdin)) return 1;
    if (!freopen("relays.out", "w", stdout)) return 1;

    int n, m, s, t;
    static int set[MAXM * 2];
    if (scanf("%d %d %d %d", &n, &m, &s, &t) != 4) return 1;
    // E and set have fixed capacity, and a negative step count is meaningless.
    if (n < 0 || m < 0 || m > MAXM) return 1;

    static struct {
        int u, v, w;
    } E[MAXM];

    for (int i = 0; i < m; i++) {
        if (scanf("%d %d %d", &E[i].w, &E[i].u, &E[i].v) != 3) return 1;
        set[i] = E[i].u;
        set[i + m] = E[i].v;
    }

    // Coordinate compression: vertex labels are replaced by their rank among
    // the distinct endpoints, so at most 2 * m matrix indices are used.
    std::sort(set, set + m * 2);
    int *end = std::unique(set, set + m * 2);

    Matrix g(false);
    for (int i = 0; i < m; i++) {
        const int u = static_cast<int>(std::lower_bound(set, end, E[i].u) - set);
        const int v = static_cast<int>(std::lower_bound(set, end, E[i].v) - set);
        // g is kept symmetric, so g(u, v) is also the current g(v, u).
        // With parallel edges only the shortest one may be kept.
        g(u, v) = g(v, u) = std::min(g(u, v), E[i].w);
    }

    // s and t must be real endpoints; lower_bound alone would silently map a
    // missing label onto a neighbouring vertex (or one past the last index).
    int *ps = std::lower_bound(set, end, s);
    int *pt = std::lower_bound(set, end, t);
    const bool known = ps != end && *ps == s && pt != end && *pt == t;

    int answer = INT_MAX;
    if (known) {
        Matrix res = pow(g, n);
        answer = res(static_cast<int>(ps - set), static_cast<int>(pt - set));
    } else if (n == 0 && s == t) {
        // The empty walk needs no edges, even at a vertex no edge touches.
        answer = 0;
    }

    printf("%d\n", answer);

    fclose(stdin);
    fclose(stdout);

    return 0;
}