题目大意就是有 n(n≤30) 个点的无向完全图,有 m(m≤2n×(n−1)) 条道路上没有怪兽,其他道路都有怪兽。
一个人一开始在 1 号点,每次会随机选择一条路走并把这条路上的怪兽全部杀完。
问期望走多少步才能让这 n 个点之间都存在没有怪兽的路径。
按照国际惯例,多组数据,T≤100。
一看网上的题解,为什么都是 O(2n) 的啊,为什么暴力都能过啊。
当然先声明一下,下面说的复杂度多不太准确,比如说 O(2n),实际上可能说的是O(2n×n) 或是 O(2n×n2),但考虑到这道题对于这种算法的复杂度数据范围放的很宽,且n 很小,这里就忽略不计了,反正这是实现的问题。
然后开始自闭。
结果发现 udebug 上那份 std
好像是可以过 30 0
这组数据的。
于是就大概搞了个复杂度似乎有点真的算法。
先说一下那个 O(2n) 的做法吧,我们先考虑把联通的点缩起来,我们姑且称之为团,这样问题就转化成了期望需要走多少次才能遍历所有的团。
然后考虑令 fi,j 表示现在在第 i 个团,选取团的状态为 j,直接dp
即可,如果直接计搜然后用 hash
或是 map
存状态,这道题就直接过掉了。
然后我们发现其实我们只关心每个团的大小而不关心每个团的具体标号,于是我们可以用 fi,j 表示当前在第 i 个点,还有 j(j 是一个 multiset
,当然也可以把这个multiset
哈希掉)这些团的期望,然后转移。
转移的话大概就是枚举下一步到那个团,直接计搜下去即可,具体可以看代码。
复杂度?我们发现状态大概是把 n 划分成多个数字之和的方案数,也就是 划分数 ,n=30 时才5000。
大概跑的是比较快的,极限数据(100个 30 0
)大概只需要跑0.3 秒(本机)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
| #include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 305;
int n, m;
struct ZT { static const int mod = 19260817;
multiset<int> st; int hsh, xx; inline void init() { hsh = xx; for (auto x : st) { hsh = ((LL) hsh * hsh % mod * hsh % mod + x) % mod; } }
ZT () { hsh = 0, xx = 0; st.clear(); } };
int siz[maxn]; int fa[maxn];
inline int getfa(int x) { return fa[x] == x ? x : fa[x] = getfa(fa[x]); }
inline void merge(int x, int y) { int fax = getfa(x); int fay = getfa(y); if (fax != fay) { if (fax > fay) { swap(fax, fay); } siz[fax] += siz[fay]; fa[fay] = fax; } }
map<int, double> mp;
inline double dfs(const ZT& now) { if (now.st.empty()) { return 0; } if (mp.count(now.hsh)) { return mp[now.hsh]; } double ans = 0; for (auto x : now.st) { ZT tmp = now; auto it = tmp.st.find(x); tmp.xx += x; tmp.st.erase(it); tmp.init(); ans += dfs(tmp) * x; } ans /= n - 1; ans++; ans /= 1. - (double) (now.xx - 1) / (double) (n - 1); return mp[now.hsh] = ans; }
inline double solve() { mp.clear(); read(n), read(m); for (int i = 1; i <= n; ++i) { fa[i] = i; siz[i] = 1; } for (int i = 1; i <= m; ++i) { int x, y; read(x), read(y); merge(x, y); } ZT fir; fir.xx = siz[1]; for (int i = 2; i <= n; ++i) { if (fa[i] == i) { fir.st.insert(siz[i]); } } fir.init(); return dfs(fir); }
int main() { int T; read(T); for (int i = 1; i <= T; ++i) { printf("Case %d: %.10f\n", i, solve()); } return 0; }
|