题目链接:洛谷改了名字和题面的描述我也不知道为什么
题目描述:
前置芝士:
线段树合并
题解:
首先a是固定的。
然后我们来讨论b的位置。因为c是a和b公共的后代,所以a和b的最近公共祖先一定是a或b(即它们在一条链上)。
先假设b是a的祖先,则c只需是a的后代即可,很简单。
若a是b的祖先,设dep[a]代表节点a的深度,则$1 \leq dep[b]-dep[a] \leq k$,b需要满足$dep[a]+1 \leq dep[b] \leq dep[a]+k$
我们只用找出所有这样的b的后代数量的和即可。
看到区间查询问题很容易就可以想到用线段树。每个节点都维护一棵线段树,树上节点[l,r]代表所有深度在[l,r]之间的节点的后代的数量的和。
暴力枚举每个后代维护线段树的复杂度是$O(n^2\log{n})$的,显然是不行的。
我们发现x的线段树和它的子节点y的线段树有一部分是相同的,其实y的线段树应是x的一部分,这样我们就可以用线段树合并维护了。复杂度是$O(n\log{n})$。
1 #include <iostream> 2 #include <cstdio> 3 using namespace std; 4 typedef long long ll; 5 6 const ll N = 600010; 7 ll lson[N * 20], rson[N * 20], val[N * 20], num; 8 struct node{ 9 ll pre, to; 10 }edge[N << 1]; 11 ll head[N], tot; 12 ll n, q; 13 ll dep[N], sz[N], rt[N]; 14 void ins(ll &cur, ll l, ll r, ll pos, ll v) { 15 if (!cur) cur = ++num; 16 if (l == r) { 17 val[cur] += v; 18 return; 19 } 20 ll mid = (l + r) >> 1; 21 if (pos <= mid) ins(lson[cur], l, mid, pos, v); 22 else ins(rson[cur], mid + 1, r, pos, v); 23 val[cur] = val[lson[cur]] + val[rson[cur]]; 24 } 25 void add(ll u, ll v) { 26 edge[++tot] = node{head[u], v}; 27 head[u] = tot; 28 } 29 void dfs(ll x, ll fa) { 30 sz[x] = 1; 31 for (ll i = head[x]; i; i = edge[i].pre) { 32 ll y = edge[i].to; 33 if (y == fa) continue; 34 dep[y] = dep[x] + 1; 35 dfs(y, x); 36 sz[x] += sz[y]; 37 } 38 } 39 ll merge(ll u, ll v) { 40 if (!v) return u; 41 if (!u) return v; 42 int x = ++num; 43 val[x] = val[u] + val[v]; 44 lson[x] = merge(lson[u], lson[v]); 45 rson[x] = merge(rson[u], rson[v]); 46 return x; 47 } 48 void solve(ll x, ll fa) { 49 for (ll i = head[x]; i; i = edge[i].pre) { 50 ll y = edge[i].to; 51 if (y == fa) continue; 52 solve(y, x); 53 rt[x] = merge(rt[x], rt[y]); 54 } 55 } 56 ll ask(ll cur, ll l, ll r, ll x, ll y) { 57 if (!cur) return 0; 58 if (x <= l && r <= y) return val[cur]; 59 ll mid = (l + r) >> 1; 60 ll ret = 0; 61 if (x <= mid) ret += ask(lson[cur], l, mid, x, y); 62 if (y > mid) ret += ask(rson[cur], mid + 1, r, x, y); 63 return ret; 64 } 65 int main() { 66 scanf("%lld%lld", &n, &q); 67 for (ll i = 1, u, v; i < n; i++) { 68 scanf("%lld%lld", &u, &v); 69 add(u, v); 70 add(v, u); 71 } 72 dfs(1, 0); 73 for (ll i = 1; i <= n; i++) { 74 ins(rt[i], 0, n, dep[i], sz[i] - 1); 75 } 76 solve(1, 0); 77 while (q--) { 78 ll p, k; 79 scanf("%lld%lld", &p, &k); 80 printf("%lld\n", min(dep[p], k) * (sz[p] - 1) + ask(rt[p], 0, n, dep[p] + 1, dep[p] + k)); 81 } 82 return 0; 83 }
原文:https://www.cnblogs.com/zcr-blog/p/12714931.html