既然开始水博客了,就多水点吧……
你有一颗仙人掌(每两个环至多有一个交点)形状的电路,一条边的电阻是1。你需要求出所有点对之间的电阻的和,对1e9+7取模。
第一眼看到“仙人掌”,只能想到某些仙人掌数据结构。再仔细一看,才发现挺妙的……
OCR实在是太强大了,我直接复制了出题人的题解,反正就是这么做的。注意dfs树的性质,所有不是树边的边都是返祖边,可以写起来很简单。据说fft可能tle,但是我还是一遍过了,可能是kactl的模版实在太快了吧,牛逼。。
出题人的题解如下
考虑两点之间的电阻,首先你把路径拿出来,可以看成若干条树边和环串联起来,也就是相加。所以考虑也就是对于每个树边和环分别考虑对答案的贡献。
每个树边的贡献为通过树边的点对,也就是两边大小相乘。对于环边,假设两个点下面的子树大小分别 为 和 环长为l,距离为 那么等效电阻为 与 并联,也就是 共有 对,可以将式子展开之后使用前缀和之类的方法计算,使用FFT可能会TLE。。
代码如下
#include <bits/stdc++.h>
using namespace std;
template <typename A, typename B>string to_string(pair<A, B> p);template <typename A, typename B, typename C>string to_string(tuple<A, B, C> p);template <typename A, typename B, typename C, typename D>string to_string(tuple<A, B, C, D> p);string to_string(const string& s) { return '"' + s + '"';}string to_string(const char* s) { return to_string((string) s);}string to_string(bool b) { return (b ? "true" : "false");}string to_string(vector<bool> v) { bool first = true; string res = "{"; for (int i = 0; i < static_cast<int>(v.size()); i++) { if (!first) { res += ", "; } first = false; res += to_string(v[i]); } res += "}"; return res;}template <size_t N>string to_string(bitset<N> v) { string res = ""; for (size_t i = 0; i < N; i++) { res += static_cast<char>('0' + v[i]); } return res;}template <typename A>string to_string(A v) { bool first = true; string res = "{"; for (const auto &x : v) { if (!first) { res += ", "; } first = false; res += to_string(x); } res += "}"; return res;}template <typename A, typename B>string to_string(pair<A, B> p) { return "(" + to_string(p.first) + ", " + to_string(p.second) + ")";}template <typename A, typename B, typename C>string to_string(tuple<A, B, C> p) { return "(" + to_string(get<0>(p)) + ", " + to_string(get<1>(p)) + ", " + to_string(get<2>(p)) + ")";}template <typename A, typename B, typename C, typename D>string to_string(tuple<A, B, C, D> p) { return "(" + to_string(get<0>(p)) + ", " + to_string(get<1>(p)) + ", " + to_string(get<2>(p)) + ", " + to_string(get<3>(p)) + ")";}void debug_out() { cerr << endl; }template <typename Head, typename... Tail>void debug_out(Head H, Tail... T) { cerr << " " << to_string(H); debug_out(T...);}
#ifdef LOCAL
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#else
#define debug(...) 42
#endif
#define set0(x) memset(x,0,sizeof(x))
#define F first
#define S second
#define PB push_back
#define MP make_pair
#define rep(i, a, b) for(int i = a; i < (b); ++i)
#define trav(a, x) for(auto& a : x)
#define all(x) x.begin(), x.end()
#define sz(x) (int)(x).size()
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<ll,ll> pll;
template<typename T> void read(T &x){
x = 0;char ch = getchar();ll f = 1;
while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}
template<typename T, typename... Args> void read(T &first, Args& ... args) {
read(first);
read(args...);
}
const long double M_PIl = acosl(-1);
typedef complex<double> C;
typedef vector<double> vd;
typedef vector<int> vi;
void fft(vector<C>& a) {
int n = sz(a), L = 31 - __builtin_clz(n);
static vector<complex<long double>> R(2, 1);
static vector<C> rt(2, 1); // (^ 10% faster if double)
for (static int k = 2; k < n; k *= 2) {
R.resize(n); rt.resize(n);
auto x = polar(1.0L, M_PIl / k); // M_PI, lower-case L
rep(i,k,2*k) rt[i] = R[i] = i&1 ? R[i/2] * x : R[i/2];
}
vi rev(n);
rep(i,0,n) rev[i] = (rev[i / 2] | (i & 1) << L) / 2;
rep(i,0,n) if (i < rev[i]) swap(a[i], a[rev[i]]);
for (int k = 1; k < n; k *= 2)
for (int i = 0; i < n; i += 2 * k) rep(j,0,k) {
// C z = rt[j+k] * a[i+j+k]; // (25% faster if hand-rolled) /// include-line
auto x = (double *)&rt[j+k], y = (double *)&a[i+j+k]; /// exclude-line
C z(x[0]*y[0] - x[1]*y[1], x[0]*y[1] + x[1]*y[0]); /// exclude-line
a[i + j + k] = a[i + j] - z;
a[i + j] += z;
}
}
vd conv(const vd& a, const vd& b) {
if (a.empty() || b.empty()) return {};
vd res(sz(a) + sz(b) - 1);
int L = 32 - __builtin_clz(sz(res)), n = 1 << L;
vector<C> in(n), out(n);
copy(all(a), begin(in));
rep(i,0,sz(b)) in[i].imag(b[i]);
fft(in);
trav(x, in) x *= x;
rep(i,0,n) out[i] = in[-i & (n - 1)] - conj(in[i]);
fft(out);
rep(i,0,sz(res)) res[i] = imag(out[i]) / (4 * n);
return res;
}
const int mod = 1000000007;
inline int mul(int x,int y){return 1ll*x*y%mod;}
inline int add(int x,int y){return x+y>=mod?x+y-mod:x+y;}
inline int sub(int x,int y){return x-y<0?x-y+mod:x-y;}
inline int sq(int x){return 1ll*x*x%mod;}
int mpow(int a,int b){return b == 0 ? 1 : ( b&1 ? mul(a,sq(mpow(a,b/2))) : sq(mpow(a,b/2)));}
int T,n,m,u,v,vis[200020],depth[200020],sz[200020],fa[200020],ok[200020];
vector<int> G[200020];
int ans = 0;
void dfs(int num){
debug(num);
vis[num] = 1;
sz[num] = 1;
for(auto ct:G[num]){
if(!vis[ct]){
depth[ct] = depth[num]+1;
fa[ct] = num;
dfs(ct);
sz[num]+=sz[ct];
}else{
if(depth[ct] == depth[num]-1 || depth[ct] > depth[num]+1)continue;
debug(num,ct);
int p = num;
while(p!=ct){
debug(p);
ok[p] = 0;
p = fa[p];
}
}
}
}
void dfs2(int num){
debug(num);
if(ok[num] && num!=1){
ans = add(ans,mul(sz[num],n-sz[num]));
}
for(auto ct:G[num]){
if(depth[ct] == depth[num]+1)dfs2(ct);
else{
if(depth[ct] == depth[num]-1 || depth[ct] > depth[num]+1)continue;
debug(num,ct,depth[num],depth[ct]);
int circsize = depth[num]-depth[ct]+1;
vector<int> val;
val.PB(sz[num]);
int v1 = fa[num],v2 = num;
while(v2!=ct){
val.PB(sz[v1]-sz[v2]);
v2 = v1;
v1 = fa[v1];
}
val[val.size()-1]+=n-sz[ct];
vd a,b;
for(auto ct:val)a.PB(ct);
b = a;
reverse(all(b));
a = conv(a,b);
int cv = 0;
for(int i=0;i<a.size();i++){
ll c = ((ll)(a[i]+0.5))%mod;
int d = abs(i-(circsize-1));
cv = add(cv,mul(c,mul(d,circsize-d)));
}
cv = mul(cv,(mod+1)/2);
cv = mul(cv,mpow(circsize,mod-2));
ans = add(ans,cv);
}
}
}
int main() {
read(T);
while(T--){
ans = 0;
read(n,m);
fa[1] = 1;
for(int i=1;i<=n;i++){
vis[i] = 0;
depth[i] = 0;
ok[i] = 1;
G[i].clear();
}
for(int i=1;i<=m;i++){
read(u,v);
G[u].PB(v);
G[v].PB(u);
}
debug(1);
dfs(1);
debug(1);
dfs2(1);
cout<<ans<<endl;
}
return 0;
}
发表回复