「省选模拟赛」完美理论 - 最大权闭合图

有两棵点集相同的树,顶点编号为 $1 \sim n$,每个点都有一个权值。你需要选择点集的一个子集,使得这个子集在两棵树上都是一个连通块。你要选出权值和最大的子集,只需要输出最大的权值和。

题解

首先,我们可以枚举一个点,把它同时作为两棵树的根,并要求这个点在选择的连通块中。这样限制就转化为:选择了某个点,就必须选择它在两棵树中的父节点。

于是这就是经典的最大权闭合图问题,用最小割模型解决即可:权值为正的点由源点向它连容量为其权值的边,其余的点向汇点连容量为其权值相反数的边,每个点向它在两棵树中的父节点各连一条容量为无穷大的边,答案为正权值之和减去最小割。对每个枚举的根分别求解,取最大值。

40分代码(暴搜 + 树形DP)

#include <cstdio>
#include <algorithm>
#include <queue>
#include <utility>

// Upper bound on the number of vertices; sizes the arrays N, a, f and E below.
const int MAXN = 100;
// NOTE(review): presumably the maximum number of test cases; not referenced in this program.
const int MAXT = 50;

struct Node;
struct Edge;

// A vertex of one of the two trees; N[k][v] is vertex v of tree k.
struct Node {
    Edge *e;          // head of this vertex's adjacency list (linked through Edge::next)
    int id, depth;    // id: vertex index, assigned in cleanUp(); depth: BFS level set by bfsDepth(), 0 = not reached yet
    bool visited;     // scratch mark used by the BFS in check()

    // Memoised result of dp(); the slot index is dp()'s `flag` argument.
    struct TreeDPAnswer {
        bool calced;  // true once dp() has started filling `val` for this slot
        int val;
    } ans[2];
} N[2][MAXN];

// One directed adjacency-list entry; addEdge() creates one per direction.
struct Edge {
    Node *t;       // vertex this entry leads to
    Edge *next;    // next entry in the source vertex's list, or null at the end

    // Builds an entry s -> t whose `next` is s's current list head.
    // The caller must store the new entry into s->e to complete the insertion.
    Edge(Node *s, Node *t) : t(t), next(s->e) {}
};

inline void addEdge(const int i, const int u, const int v) {
    N[i][u].e = new Edge(&N[i][u], &N[i][v]);
    N[i][v].e = new Edge(&N[i][v], &N[i][u]);
}

// n: vertex count of the current test; a[v]: weight of vertex v (0-based);
// f[v]: non-zero while v belongs to the subset being tried by search();
// ans: best sum found so far by the brute force.
int n, a[MAXN], f[MAXN], ans;
// Edges of each tree stored as (smaller endpoint, larger endpoint); main() sorts
// both lists and compares them to detect that the two trees are identical.
std::pair<int, int> E[2][MAXN - 1];

inline int check() {
    int start = -1, tot = 0, sum = 0;
    for (int i = 0; i < n; i++) if (f[i]) start = i, tot++, N[0][i].visited = N[1][i].visited = false, sum += a[i];

    for (int i = 0; i < 2; i++) {
        std::queue<Node *> q;
        q.push(&N[i][start]);
        N[i][start].visited = true;

        int cnt = 1;
        while (!q.empty()) {
            Node *v = q.front();
            q.pop();

            for (Edge *e = v->e; e; e = e->next) if (f[e->t->id] && !e->t->visited) e->t->visited = true, q.push(e->t), cnt++;
        }

        // printf("cnt = %d, tot = %d, sum = %d\n", cnt, tot, sum);
        if (cnt < tot) return -1;
    }

    return sum;
}

// Brute force over all subsets: decides membership of vertices i, i+1, ...
// in f[] and, once every vertex is decided, folds check() into the global ans.
inline void search(int i = 0) {
    if (i != n) {
        // Try vertex i included first, then excluded (f[i] is left at 0).
        for (int pick = 1; pick >= 0; pick--) {
            f[i] = pick;
            search(i + 1);
        }
        return;
    }

    const int value = check();
    ans = std::max(ans, value);
}

inline void bfsDepth() {
    std::queue<Node *> q;
    q.push(&N[0][0]);
    N[0][0].depth = 1;
    while (!q.empty()) {
        Node *v = q.front();
        q.pop();

        for (Edge *e = v->e; e; e = e->next) if (!e->t->depth) e->t->depth = v->depth + 1, q.push(e->t);
    }
}

// Memoised tree DP over the depths computed by bfsDepth(); the children of v
// are the neighbours whose depth is v->depth + 1.
//
// flag == true : max(0, a[v] + sum of the positive dp(child, true) values).
// flag == false: maximum of dp(v, true) and dp(child, false) over all children,
//                never less than 0.
int dp(Node *v, const bool flag) {
    Node::TreeDPAnswer &slot = v->ans[flag];
    if (slot.calced) return slot.val;
    slot.calced = true;

    int &best = slot.val;
    best = 0;

    if (!flag) {
        for (Edge *it = v->e; it; it = it->next) {
            Node *const child = it->t;
            if (child->depth != v->depth + 1) continue;
            best = std::max(best, dp(child, false));
        }
        best = std::max(best, dp(v, true));
    } else {
        int gained = 0;
        for (Edge *it = v->e; it; it = it->next) {
            Node *const child = it->t;
            if (child->depth != v->depth + 1) continue;
            const int sub = dp(child, true);
            if (sub > 0) gained += sub;  // only attach subtrees that help
        }
        gained += a[v->id];
        best = std::max(best, gained);
    }

    return best;
}

// Answers the current test case: tree DP on tree 0 when useTreeDP is set,
// otherwise exhaustive subset search.
inline int solve(const bool useTreeDP) {
    if (useTreeDP) {
        bfsDepth();
        return dp(&N[0][0], false);
    }

    ans = 0;
    search();
    return ans;
}

inline void cleanUp() {
    for (int i = 0; i < 2; i++) for (int j = 0; j < n; j++) {
        N[i][j].id = j;
        Edge *next;
        for (Edge *&e = N[i][j].e; e; next = e->next, delete e, e = next);

        N[i][j].depth = 0;
        N[i][j].ans[0].calced = N[i][j].ans[1].calced = false;
        N[i][j].ans[0].val = N[i][j].ans[1].val = 0;
    }
    std::fill(a, a + n, 0);
}

// Entry point of the partial-score solution: reads test cases from theory.in
// and writes one answer per line to theory.out.
int main() {
    freopen("theory.in", "r", stdin);
    freopen("theory.out", "w", stdout);

    int t;
    scanf("%d", &t);
    while (t-- != 0) {
        scanf("%d", &n);
        cleanUp();  // uses the freshly read n

        for (int i = 0; i < n; i++) scanf("%d", &a[i]);

        for (int tree = 0; tree < 2; tree++) {
            int u, v;
            for (int j = 0; j < n - 1; j++) {
                scanf("%d %d", &u, &v);
                u--;
                v--;  // input is 1-based
                E[tree][j] = std::make_pair(std::min(u, v), std::max(u, v));
                addEdge(tree, u, v);
            }
            // Canonical order so that the two edge lists can be compared.
            std::sort(E[tree], E[tree] + n - 1);
        }

        // The tree DP is only used when both trees have the same edge set.
        const bool sameTrees = std::equal(E[0], E[0] + n - 1, E[1]);
        printf("%d\n", solve(sameTrees));
    }

    fclose(stdin);
    fclose(stdout);

    return 0;
}

代码

#include <cstdio>
#include <climits>
#include <algorithm>
#include <queue>
#include <vector>

// Upper bound on the number of vertices; sizes a, T and (with two extra slots) N below.
const int MAXN = 100;
// NOTE(review): presumably the maximum number of test cases; not referenced in this program.
const int MAXT = 50;

struct Node;
struct Edge;

// A vertex of the flow network. solve() uses index 0 as the source, n + 1 as
// the sink and i + 1 for tree vertex i.
struct Node {
    Edge *e, *c;    // e: head of the arc list; c: current-arc cursor, reset to e by Dinic::makeLevelGraph
    int l;          // level in the level graph; 0 means not reached
} N[MAXN + 2];

// A directed arc of the flow network.
struct Edge {
    Node *t;           // vertex the arc points to
    int f, c;          // f: current flow; c: capacity
    Edge *next, *r;    // next: next arc leaving the same vertex; r: paired reverse arc, set by addEdge()

    // Creates an arc s -> t with capacity c and zero flow whose `next` is s's
    // current list head. `r` is not initialised here.
    Edge(Node *s, Node *t, const int c) : t(t), f(0), c(c), next(s->e) {}
};

struct Dinic {
    bool makeLevelGraph(Node *s, Node *t, const int n) {
        for (int i = 0; i < n; i++) N[i].c = N[i].e, N[i].l = 0;

        std::queue<Node *> q;
        q.push(s);
        s->l = 1;

        while (!q.empty()) {
            Node *v = q.front();
            q.pop();

            for (Edge *e = v->e; e; e = e->next) {
                if (e->f < e->c && e->t->l == 0) {
                    e->t->l = v->l + 1;
                    if (e->t == t) return true;
                    else q.push(e->t);
                }
            }
        }

        return false;
    }

    int findPath(Node *s, Node *t, const int limit = INT_MAX) {
        if (s == t) return limit;
        for (Edge *&e = s->c; e; e = e->next) {
            if (e->f < e->c && e->t->l == s->l + 1) {
                int f = findPath(e->t, t, std::min(e->c - e->f, limit));
                if (f > 0) {
                    e->f += f, e->r->f -= f;
                    return f;
                }
            }
        }

        return 0;
    }

    int operator()(const int s, const int t, const int n) {
        int ans = 0;
        while (makeLevelGraph(&N[s], &N[t], n)) {
            int f;
            while ((f = findPath(&N[s], &N[t])) > 0) ans += f;
        }

        return ans;
    }
} dinic;

inline void addEdge(const int s, const int t, const int c) {
    N[s].e = new Edge(&N[s], &N[t], c);
    N[t].e = new Edge(&N[t], &N[s], 0);
    N[s].e->r = N[t].e, N[t].e->r = N[s].e;
}

struct TreeNode;
struct TreeEdge;

// n: number of vertices in the current test case; a[v]: weight of vertex v (0-based).
int n, a[MAXN];

// A vertex of one of the two input trees; T[k][v] is vertex v of tree k.
struct TreeNode {
    TreeEdge *e;     // head of the adjacency list
    int id;          // vertex index, assigned in main()
    bool visited;    // BFS mark, re-initialised by solve() before each traversal

    // Weight of this vertex, looked up in the global array a.
    int w() const { return a[id]; }
} T[2][MAXN];

// One directed adjacency-list entry of a tree; addTreeEdge() creates one per direction.
struct TreeEdge {
    TreeNode *t;       // vertex this entry leads to
    TreeEdge *next;    // next entry in the source vertex's list, or null at the end

    // Builds an entry s -> t whose `next` is s's current list head.
    // The caller must store the new entry into s->e to complete the insertion.
    TreeEdge(TreeNode *s, TreeNode *t) : t(t), next(s->e) {}
};

inline void addTreeEdge(const int x, const int u, const int v) {
    T[x][u].e = new TreeEdge(&T[x][u], &T[x][v]);
    T[x][v].e = new TreeEdge(&T[x][v], &T[x][u]);
}

inline void cleanUpNetwork() {
    for (int i = 0; i < n + 2; i++) {
        Edge *next;
        for (Edge *&e = N[i].e; e; next = e->next, delete e, e = next);
    }
}

inline int solve(const int root) {
    const int s = 0, t = n + 1;
    int sum = 0;

    for (int i = 0; i < n; i++) {
        if (a[i] > 0) addEdge(s, i + 1, a[i]), sum += a[i];
        else addEdge(i + 1, t, -a[i]);
    }

    for (int i = 0; i < 2; i++) {
        for (int j = 0; j < n; j++) T[i][j].visited = (j == root) ? true : false;

        std::queue<TreeNode *> q;
        q.push(&T[i][root]);

        while (!q.empty()) {
            TreeNode *v = q.front();
            q.pop();

            for (TreeEdge *e = v->e; e; e = e->next) {
                if (!e->t->visited) {
                    e->t->visited = true;
                    addEdge(e->t->id + 1, v->id + 1, INT_MAX);
                    q.push(e->t);
                }
            }
        }
    }

    int ans = sum - dinic(s, t, n + 2);
    cleanUpNetwork();

    return ans;
}

inline void cleanUp() {
    for (int i = 0; i < 2; i++) for (int j = 0; j < n; j++) {
        TreeEdge *next;
        for (TreeEdge *&e = T[i][j].e; e; next = e->next, delete e, e = next);
    }
}

// Entry point of the full solution: reads test cases from theory.in and writes
// one answer per line to theory.out.
//
// For each test case, every vertex is tried as the common root of both trees
// and the best result of solve() is kept; 0 is the lower bound.
int main() {
    freopen("theory.in", "r", stdin);
    // Output goes to theory.out, pairing with the input redirection above and
    // the fclose(stdout) at the end; this call had been left commented out.
    freopen("theory.out", "w", stdout);

    int t;
    scanf("%d", &t);
    while (t--) {
        scanf("%d", &n);
        for (int i = 0; i < n; i++) scanf("%d", &a[i]), T[0][i].id = T[1][i].id = i;

        // n - 1 edges for each of the two trees; input vertices are 1-based.
        for (int i = 0; i < 2; i++) for (int j = 0, u, v; j < n - 1; j++) {
            scanf("%d %d", &u, &v), u--, v--;
            addTreeEdge(i, u, v);
        }

        int ans = 0;
        for (int i = 0; i < n; i++) ans = std::max(ans, solve(i));
        printf("%d\n", ans);

        // Free the tree edges before the next test case overwrites n.
        cleanUp();
    }

    fclose(stdin);
    fclose(stdout);

    return 0;
}