「UVa 11174」Stand in a Line - 计数原理 + 乘法逆元

分别属于多个家族的 N N \leq 40000 )个人想要站成一排,但没有人想站在他爹前面,求方案总数。

链接

UVa 11174

题解

题目描述的父子关系组成了一个森林,每个家族对应一棵树,我们可以添加一个虚拟节点,让它成为每个家族祖先的「父亲」,然后求出整棵树的方案数即为答案。

f(i) 表示以 i 为根的子树的方案数,用 s(i) 表示以 i 为根的子树的大小。

首先, i 的几棵子树各自的排列方案是相互独立的,可以使用乘法原理来计算。

还要注意一点, i 的几棵子树的节点是可以穿插排列的,即只需要保证每对父子的相对位置而整体是无序的。如果对所有子树的节点做全排列,那么改变了原有的父子相对位置的排列方案都是无效的,即同一种穿插方式中多次改变父子相对位置只有一次有效。我们可以把同一个子树的节点看做相同的,然后做有重复元素的全排列,就可以得到正确结果,因为每一种排列都对应且只对应一种顺序。

「子树各自的排列方案」和「节点的穿插顺序」是相互独立的,因此要用乘法原理。

c(i) 表示 i 的子节点的集合,则递归计算的公式为:

f(i) = \prod\limits_{j \in c(i)}f(j) * \frac{s(i) - 1}{\prod\limits_{j \in c(i)}{s(i)}}

公式中使用了除法,模意义下的除以一个数等于乘以这个数的乘法逆元,所以要在程序开始时递推预处理出所有数的阶乘,并求出对应的乘法逆元。

代码

#include <cstdio>
#include <cstring>

const long long MOD = 1000000007;
const int MAXN = 40000;

struct Tree {
    Tree *firstChild, *parent, *next;
    int childCount, size;
    long long ans;

    long long solve();
} trees[MAXN + 1];

long long fac[MAXN + 1], facInverse[MAXN + 1];

void exgcd(long long a, long long b, long long &g, long long &x, long long &y) {
    if (b == 0) {
        g = a;
        x = 1, y = 0;
    } else {
        exgcd(b, a % b, g, y, x);
        y -= x * (a / b);
    }
}

long long inverse(long long num) {
    long long g, x, y;
    exgcd(num, MOD, g, x, y);
    return (x % MOD + MOD) % MOD;
}

void makeTable() {
    fac[0] = 1;
    for (int i = 1; i <= MAXN; i++) {
        fac[i] = fac[i - 1] * i % MOD;
        facInverse[i] = inverse(fac[i]);
    }
}

void addRelation(int child, int parent) {
    trees[child].parent = &trees[parent];
    trees[child].next = trees[parent].firstChild;
    trees[parent].firstChild = &trees[child];
    trees[parent].childCount++;
}

void cleanUp(int n) {
    memset(trees, 0, sizeof(Tree) * (n + 1));
}

long long Tree::solve() {
    for (Tree *c = firstChild; c; c = c->next) {
        c->solve();
        size += c->size;
    }
    size++;

    ans = fac[size - 1];
    for (Tree *c = firstChild; c; c = c->next) {
        ans = ans * c->ans % MOD;
        ans = ans * facInverse[c->size] % MOD;
    }
    return ans;
}

int main() {
    makeTable();

    int t;
    scanf("%d", &t);
    for (int i = 0; i < t; i++) {
        int n, m;
        scanf("%d %d", &n, &m);

        for (int i = 0; i < m; i++) {
            int child, parent;
            scanf("%d %d", &child, &parent);

            addRelation(child, parent);
        }

        for (int i = 1; i <= n; i++) {
            if (trees[i].parent == NULL) addRelation(i, 0);
        }

        printf("%lld\n", trees[0].solve());
        cleanUp(n);
    }

    return 0;
}