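// 树上差分：实质是对每条路径 O(1) 打标记，再通过一次 O(n) 的 DFS 自下而上把标记回溯累加到每个节点，是将前缀和推广到树上路径的一种优秀技巧。
// (Tree difference: mark each path in O(1), then one O(n) bottom-up DFS accumulates the marks into every node — prefix sums generalized to tree paths.)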
// NOIP 2015
#include <bits/stdc++.h>
// Shorthands used throughout the file: ll = 64-bit signed integer,
// f / s = the first / second member of a pair, pii = pair of ints.
// NOTE: f and s are object-like macros, so the preprocessor rewrites EVERY
// identifier spelled `f` or `s` in this file to `first` / `second`.
#define ll long long
#define f first
#define s second
#define pii pair<int,int>
// Capacity of the per-vertex and per-query arrays (indexed from 1).
// Presumably sized for n, m <= 3*10^5 — TODO confirm against the problem limits.
const int MAXN=3e5+10;
// Lets the code below use vector/pair/swap/max/make_pair unqualified.
using namespace std;
// Reads the next integer from `in` (stdin by default).
// Characters before the first digit are skipped; a '-' anywhere in that
// skipped run makes the result negative. Reading stops at (and consumes)
// the first non-digit after the number.
// Returns 0 if end-of-file is reached before any digit, instead of looping
// forever as a plain "skip non-digits" loop would.
// `ch` is an int so EOF stays distinct from every byte value, and
// std::isdigit only ever receives EOF or an unsigned-char value.
long long read(FILE *in = stdin){
    long long x = 0;
    int sign = 1;
    int ch = std::fgetc(in);
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::fgetc(in);
    }
    while (ch != EOF && std::isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = std::fgetc(in);
    }
    return sign * x;
}
// Adjacency list: vec[x] holds (neighbour, edge weight) pairs; main stores
// every undirected edge in both directions.
vector<pii>vec[MAXN];
// Per-vertex data filled by dfs1 (main roots the tree at vertex 1):
//   fa[x]  - parent of x (0 for the root),
//   dep[x] - depth (the root gets 1),
//   num[x] - number of vertices in x's subtree,
//   son[x] - child with the largest subtree, or -1 if none (main pre-fills -1),
//   vul[x] - weight of the edge (fa[x], x); stays 0 for the root.
// n = number of vertices, m = number of (u, v) path queries stored in d[].
int fa[MAXN],dep[MAXN],num[MAXN],son[MAXN],vul[MAXN],n,m;
// dis[x]: weighted distance from the root to x (dis[root] = 0).
ll dis[MAXN];
// First pass of the heavy-light decomposition over the subtree rooted at v,
// whose parent is `pre` and whose depth becomes deep+1. For every vertex x
// in that subtree it fills:
//   fa[x]  parent (pre for v itself),
//   dep[x] depth,
//   vul[x] weight of the edge (fa[x], x)   (v's own entry is left untouched),
//   dis[x] weighted distance from the root (dis[v] is taken as already set),
//   num[x] subtree size,
//   son[x] child with the largest subtree; ties go to the earliest child in
//          adjacency order. son[] must be pre-filled with -1, as main does.
// Implemented without recursion: a recursive walk nests one frame per tree
// level, i.e. up to n (~3e5) frames on a path-shaped tree, which can
// overflow the call stack.
void dfs1(int v,int pre,int deep){
    fa[v]=pre;
    dep[v]=deep+1;
    // Breadth-first order: every vertex appears after its parent.
    std::vector<int> order(1,v);
    for(size_t k=0;k<order.size();k++){
        int x=order[k];
        num[x]=1;
        for(size_t i=0;i<vec[x].size();i++){
            int y=vec[x][i].first;
            if(y==fa[x])continue;            // edge back to the parent
            fa[y]=x;
            dep[y]=dep[x]+1;
            vul[y]=vec[x][i].second;
            dis[y]=vec[x][i].second+dis[x];
            order.push_back(y);
        }
    }
    // Reverse BFS order visits every vertex after all of its descendants,
    // so each child's subtree size is final when it is added to its parent.
    for(size_t k=order.size();k-->0;){
        int x=order[k];
        for(size_t i=0;i<vec[x].size();i++){
            int y=vec[x][i].first;
            if(y==fa[x])continue;
            num[x]+=num[y];
            if(son[x]==-1||num[son[x]]<num[y])son[x]=y;
        }
    }
}
// Heavy-light decomposition results, filled by dfs2:
//   p[x]  - 1-based position of x in the heavy-child-first DFS order,
//   fp[i] - inverse of p (the vertex at position i),
//   cnt   - number of positions handed out so far,
//   tp[x] - head (topmost vertex) of the heavy chain containing x.
int p[MAXN],fp[MAXN],cnt,tp[MAXN];
// Second pass of the heavy-light decomposition over the subtree rooted at v,
// where td is the head of the chain v belongs to. Visits vertices in
// preorder with the heavy child (son[]) first and the remaining children in
// adjacency order. It assigns consecutive positions p[x] = ++cnt, records
// fp[p[x]] = x, and sets tp[x] to the chain head. A heavy child inherits its
// parent's head; a light child starts a new chain headed by itself.
// Requires fa[] and son[] from dfs1.
// Uses an explicit stack instead of recursion, which would nest up to n
// frames on a path-shaped tree. Light children are pushed in reverse
// adjacency order and the heavy child last, so it is popped first. The
// resulting visit order is exactly the recursive one.
void dfs2(int v,int td){
    std::vector<std::pair<int,int> > stk;   // (vertex, head of its chain)
    stk.push_back(std::make_pair(v,td));
    while(!stk.empty()){
        int x=stk.back().first;
        int chain=stk.back().second;
        stk.pop_back();
        p[x]=++cnt;
        fp[p[x]]=x;
        tp[x]=chain;
        for(int i=(int)vec[x].size()-1;i>=0;i--){
            int y=vec[x][i].first;
            if(y!=fa[x]&&y!=son[x])stk.push_back(std::make_pair(y,y));
        }
        if(son[x]!=-1)stk.push_back(std::make_pair(son[x],chain));
    }
}
// Lowest common ancestor of u and v, using the heavy-light decomposition
// prepared by dfs1/dfs2 (tp = chain head, fa = parent, dep = depth).
// While the two endpoints lie on different chains, the one whose chain head
// is deeper jumps to the parent of that head. Once both share a chain, the
// shallower endpoint is the answer.
int slove(int u,int v){
    while(tp[u]!=tp[v]){
        // Make u the endpoint whose chain head is at least as deep.
        if(dep[tp[u]]<dep[tp[v]])swap(u,v);
        u=fa[tp[u]];
    }
    return dep[u]>dep[v]?v:u;
}
// One path query: endpoints u and v, their lowest common ancestor lca,
// and len, the weighted length of the tree path between u and v
// (all filled in main).
struct node{
    int lca;
    int u;
    int v;
    ll len;
};
// d[1..m]: the path queries read in main.
node d[MAXN];
// P[x]: tree-difference marks used by check() and accumulated by dfs3;
// check() leaves every entry back at 0 when it returns.
int P[MAXN];
// Bottom-up accumulation of the difference marks in P over the subtree
// rooted at v (whose parent is `pre`). Afterwards, each P[x] in that subtree
// equals the sum of the original P values over x's subtree. P[pre] is not
// touched.
// check() marks a path with +1 at both endpoints and -2 at their LCA, so
// after this pass P[x] is the number of marked paths that use the edge
// (x, parent of x).
// Implemented without recursion: a recursive walk nests up to n frames on a
// path-shaped tree, and this function is re-run on every binary-search step.
// The buffers are static so their capacity is reused across those calls.
void dfs3(int v,int pre){
    static std::vector<int> order,parent;
    order.assign(1,v);
    parent.assign(1,pre);
    // Breadth-first order: every vertex appears after its parent.
    for(size_t k=0;k<order.size();k++){
        int x=order[k];
        int px=parent[k];
        for(size_t i=0;i<vec[x].size();i++){
            int y=vec[x][i].first;
            if(y!=px){
                order.push_back(y);
                parent.push_back(x);
            }
        }
    }
    // Reverse order: a vertex's total is final before it is pushed up.
    // Index 0 is v itself, which contributes nothing to `pre`.
    for(size_t k=order.size()-1;k>0;k--)P[parent[k]]+=P[order[k]];
}
// Binary-search predicate: true if every query path can be brought down to
// length <= t by discounting the weight of a single edge.
// Only paths longer than t matter. Each is marked with tree differences
// (+1 at both endpoints, -2 at the LCA). After dfs3 accumulates the marks,
// P[x] counts how many of those paths use the edge (x, fa[x]).
// The heaviest edge shared by all of them (weight vul[x]) is discounted, and
// t is accepted iff (longest path) - (that weight) <= t.
// P is reset to all zeros before returning.
bool check(ll t){
    int over=0;      // number of paths longer than t
    ll longest=0;    // length of the longest such path
    for(int i=1;i<=m;i++){
        if(d[i].len<=t)continue;
        ++over;
        if(d[i].len>longest)longest=d[i].len;
        ++P[d[i].u];
        ++P[d[i].v];
        P[d[i].lca]-=2;
    }
    if(over==0)return 1;
    dfs3(1,0);
    int best=0;      // heaviest edge lying on every long path
    for(int i=1;i<=n;i++){
        if(P[i]==over&&vul[i]>best)best=vul[i];
        P[i]=0;
    }
    return longest-best<=t;
}
int main(){
n=read();m=read();
int u,v,t;
for(int i=1;i<=n;i++)son[i]=-1;
for(int i=1;i<n;i++)u=read(),v=read(),t=read(),vec[u].push_back(make_pair(v,t)),vec[v].push_back(make_pair(u,t));
dfs1(1,0,0);dfs2(1,1);
ll r=0;
for(int i=1;i<=m;i++)d[i].u=read(),d[i].v=read(),d[i].lca=slove(d[i].u,d[i].v),d[i].len=dis[d[i].u]+dis[d[i].v]-2*dis[d[i].lca],r=max(r,d[i].len);
ll l=0,ans=0;
while(l<=r){
ll mid=(l+r)>>1;
if(check(mid))ans=mid,r=mid-1;
else l=mid+1;
}
printf("%lld\n",ans);
return 0;
}
// 原文 (original post): https://www.cnblogs.com/wang9897/p/9130792.html