题目描述:
Give a tree with n vertices,each edge has a length(positive integer
less than 1001). Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if
and only if dist(u,v) not exceed k. Write a program that will count how
many pairs which are valid for a given tree. ## 输入 The
input contains several test cases. The first line of each test case
contains two integers n, k. (n<=10000) The following n-1 lines each
contains three integers u,v,l, which means there is an edge between node
u and v of length l. The last test case is followed by two zeros.
输出
For each test case output the answer on a single line.
思路:
比较普遍的一道点分治题,考虑每一棵树,以重心为根,预处理出每个点的深度,再把每个点扔到一个数组中进行线性计算,算出满足条件的所有点对,方法可以是将其排序,用两个指针从两边往中间推着计算。
不过这时候会有小问题,会多算一种情况,就是他们的LCA不是重心的情况,这时候就需要采用容斥原理的思想,在每个重心的子树中计算一遍上述的操作(注意加上重心到根节点的距离),再在答案中对应地减去,便能得到最终答案!
代码
```c++ #include #include #include using
namespace std; bool mem1; const int N=100005; struct Graph{ int
tot,to[N<<1],nxt[N<<1],len[N<<1],head[N]; void add(int
x,int y,int z){tot++;to[tot]=y;nxt[tot]=head[x];len[tot]=z;head[x]=tot;}
void clear(){tot=0;memset(head,-1,sizeof(head));} }G; bool vis[N]; int
ans,sz[N],mx[N],t_sz,center; int arr[N],dep[N]; int n,k; bool mem2; void
make_dep(int x,int f){ arr[++arr[0]]=dep[x]; for(int
i=G.head[x];i!=-1;i=G.nxt[i]){ int v=G.to[i]; if(v==f||vis[v])continue;
dep[v]=dep[x]+G.len[i]; make_dep(v,x); } } void get_center(int x,int f){
sz[x]=1,mx[x]=0; for(int i=G.head[x];i!=-1;i=G.nxt[i]){ int v=G.to[i];
if(v==f||vis[v])continue; get_center(v,x); sz[x]+=sz[v];
mx[x]=max(mx[x],sz[v]); } mx[x]=max(mx[x],t_sz-sz[x]);
if(!center||mx[x]<mx[center])center=x; } int calc(int x,int dis){
dep[x]=dis,arr[0]=0; make_dep(x,0); sort(arr+1,arr+arr[0]+1); int
j=arr[0],ret=0; for(int i=1;i<=arr[0];i++){
while(j>i&&arr[i]+arr[j]>k)j–; ret+=max(0,j-i); } return
ret; } void solve(int x){ vis[x]=1; ans+=calc(x,0); for(int
i=G.head[x];i!=-1;i=G.nxt[i]){ int v=G.to[i]; if(vis[v])continue;
ans-=calc(v,G.len[i]); center=0,t_sz=sz[v]; get_center(v,x);
solve(center); } } int main(){ while(scanf(“%d%d”,&n,&k)==2){
if(!n&&!k)break; G.clear(); memset(vis,0,sizeof vis); for(int
i=1;i<n;i++){ int x,y,z; scanf(“%d%d%d”,&x,&y,&z);
G.add(x,y,z),G.add(y,x,z); } center=0,t_sz=n,ans=0; get_center(1,0);
solve(center); printf(“%d”,ans); } return 0; }