[BZOJ3772]精神污染
试题描述
输入
输出
所求的概率,以最简分数形式输出。
输入示例
5 3 1 2 2 3 3 4 2 5 3 5 2 5 1 4
输出示例
1/3
数据规模及约定
100%的数据满足:N,M<=100000
题解
先膜拜一下 popoqqq 的题解。
然后我再试图自己讲讲。。。
一个很重要东西:一棵树上,一条链 A 被另一条链 B 包含当且仅当 A 的两个端点都在 B 上。
那么对于一条链,要知道它是否被包含,只需要关心它的两个端点就行了。
对于一条链 (a, b),我们不妨把 b 称为 a 的对面儿。查询一条链 C 包含了哪些链的方法就是把所有在链 C 上的节点拎出来,分别统计它们的对面儿在不在链 C 上,而这个在链 C 上的“对面儿”们的个数就是 C 包含的链的条数。
所以我们不妨建立一个树形的主席树,每个节点 u 在其父节点的基础之上把 u 的“对面儿”们的信息加进来。
这个信息,具体指什么信息呢?我们发现我们需要回答有几个点在一条链上,所以可以对整棵树搞一个括号序列,那么这个信息,存节点在括号序列对应的位置就好了,注意一个节点在括号序列中对应两个位置,左括号的位置 +1,右括号的位置 -1。
注意括号序列只能处理自己到祖先这样的链,所以对于一条 (a, b) 的链,在询问它包含几条链时把它拆成 (a, lca(a, b)) 和 (b, lca(a, b)) 这两条链就好了,小心 lca(a, b) 这个点有可能被重复统计,写代码时细心一点。
#include <iostream>
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <stack>
#include <vector>
#include <queue>
#include <cstring>
#include <string>
#include <map>
#include <set>
using namespace std;
const int BufferSize = 1 << 16;
char buffer[BufferSize], *Head, *Tail;
inline char Getchar() {
	if(Head == Tail) {
		int l = fread(buffer, 1, BufferSize, stdin);
		Tail = (Head = buffer) + l;
	}
	return *Head++;
}
int read() {
	int x = 0, f = 1; char c = Getchar();
	while(!isdigit(c)){ if(c == ‘-‘) f = -1; c = Getchar(); }
	while(isdigit(c)){ x = x * 10 + c - ‘0‘; c = Getchar(); }
	return x * f;
}
#define maxn 100010
#define maxm 200010
#define maxnode 4000010
#define maxlog 17
#define LL long long
int ToT, sumv[maxnode], lc[maxnode], rc[maxnode];
void update(int& y, int x, int l, int r, int p, int v) {
	sumv[y = ++ToT] = sumv[x] + v;
	if(l == r) return ;
	int mid = l + r >> 1; lc[y] = lc[x]; rc[y] = rc[x];
	if(p <= mid) update(lc[y], lc[x], l, mid, p, v);
	else update(rc[y], rc[x], mid + 1, r, p, v);
	return ;
}
int query(int o, int l, int r, int ql, int qr) {
	if(!o) return 0;
	if(ql <= l && r <= qr) return sumv[o];
	int mid = l + r >> 1, ans = 0;
	if(ql <= mid) ans += query(lc[o], l, mid, ql, qr);
	if(qr > mid) ans += query(rc[o], mid + 1, r, ql, qr);
	return ans;
}
int n, m, head[maxn], next[maxm], to[maxm];
void AddEdge(int a, int b) {
	to[++m] = b; next[m] = head[a]; head[a] = m;
	swap(a, b);
	to[++m] = b; next[m] = head[a]; head[a] = m;
	return ;
}
struct Path {
	int a, b;
	Path() {}
	Path(int _, int __): a(_), b(__) {}
} ps[maxn];
int pm, phead[maxn], pnxt[maxn], pto[maxn];
void AddPath(int a, int b) {
	ps[++pm] = Path(a, b);
	pto[pm] = b; pnxt[pm] = phead[a]; phead[a] = pm;
	return ;
}
int fa[maxlog][maxn], dep[maxn], rt[maxn], dl[maxn], dr[maxn], clo;
void build(int u) {
	dl[u] = ++clo;
	for(int i = 1; i < maxlog; i++) fa[i][u] = fa[i-1][fa[i-1][u]];
	for(int e = head[u]; e; e = next[e]) if(to[e] != fa[0][u]) {
		dep[to[e]] = dep[u] + 1;
		fa[0][to[e]] = u;
		build(to[e]);
	}
	dr[u] = ++clo;
	return ;
}
int lca(int a, int b) {
	if(dep[a] < dep[b]) swap(a, b);
	for(int i = maxlog - 1; i >= 0; i--) if(dep[a] - dep[b] >= (1 << i)) a = fa[i][a];
	for(int i = maxlog - 1; i >= 0; i--) if(fa[i][a] != fa[i][b]) a = fa[i][a], b = fa[i][b];
	return a == b ? a : fa[0][b];
}
void build2(int u) {
	rt[u] = rt[fa[0][u]];
	for(int e = phead[u]; e; e = pnxt[e]) {
		update(rt[u], rt[u], 1, clo, dl[pto[e]], 1),
		update(rt[u], rt[u], 1, clo, dr[pto[e]], -1);
//		if(ToT % 10000 == 0) printf("ToT: %d\n", ToT);
	}
	for(int e = head[u]; e; e = next[e]) if(to[e] != fa[0][u]) build2(to[e]);
	return ;
}
LL gcd(LL a, LL b) { return b ? gcd(b, a % b) : a; }
int main() {
//	freopen("data.in", "r", stdin);
	n = read(); int tm = read();
	for(int i = 1; i < n; i++) {
		int a = read(), b = read();
		AddEdge(a, b);
	}
	while(tm--) {
		int a = read(), b = read();
		AddPath(a, b);
	}
	
	build(1); build2(1);
	LL ans = 0;
	for(int i = 1; i <= pm; i++) {
		int a = ps[i].a, b = ps[i].b, c = lca(a, b);
//		printf("lca(%d, %d) = %d\n", a, b, c);
		ans += (LL)query(rt[a], 1, clo, dl[c], dl[a]) - query(rt[fa[0][c]], 1, clo, dl[c], dl[a]);
		ans += (LL)query(rt[a], 1, clo, dl[c] + 1, dl[b]) - query(rt[fa[0][c]], 1, clo, dl[c] + 1, dl[b]);
		ans += (LL)query(rt[b], 1, clo, dl[c], dl[a]) - query(rt[c], 1, clo, dl[c], dl[a]);
		ans += (LL)query(rt[b], 1, clo, dl[c] + 1, dl[b]) - query(rt[c], 1, clo, dl[c] + 1, dl[b]);
		ans--;
	}
	
	LL bns = (LL)pm * (pm - 1) >> 1, g = gcd(ans, bns);
	ans /= g; bns /= g;
	printf("%lld/%lld\n", ans, bns);
	
	return 0;
}
原文:http://www.cnblogs.com/xiao-ju-ruo-xjr/p/6358722.html