POJ

题目链接:传送门 (POJ是真的烦)

题目思路:

对于dsu on tree 直接暴力统计深度--即u到根节点的距离(树状数组维护桶,也可以用 排序双指针--但单步容斥来得到合法答案),在子树中查询的查询 k - (deep[u] - dis)+ dis +1 ,其中 dis 为子树根的深度,+1是由于树状数组的起点为1,而深度存在为0,故存入树状时整体平移一个单位,deep[u]-dis ,即u到子树根的距离,+dis 是由于查询的子树中的 deep[v] = dis + v到子树根的距离 ,要查询的是v到子树根的距离 但树状存的是deep[v] (deep[v] 和 v到子树根的距离 是一一对应的),因此查询的时候要加一个dis。

代码:

  1 #include<functional>
  2 #include<algorithm>
  3 #include<cmath>
  4 #include<cstdio>
  5 #include<cctype>
  6 #include<cstring>
  7 #include<vector>
  8 using namespace std;
  9 typedef long long LL;
 10 typedef unsigned long long uLL;
 11 typedef pair<int,int> pii;
 12 typedef pair<LL,LL> pLL;
 13 typedef pair<double,double> pdd;
 14 const int N=2e4+5;
 15 const int M=1e7+5;
 16 const int inf=0x3f3f3f3f;
 17 const LL mod=998244353;
 18 const double eps=1e-8;
 19 const long double pi=acos(-1.0L);
 20 #define ls (i<<1)
 21 #define rs (i<<1|1)
 22 #define fi first
 23 #define se second
 24 #define pb push_back
 25 #define eb emplace_back
 26 #define mk make_pair
 27 #define mem(a,b) memset(a,b,sizeof(a))
 28 LL read()
 29 {
 30     LL x=0,t=1;
 31     char ch;
 32     while(!isdigit(ch=getchar())) if(ch=='-') t=-1;
 33     while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); }
 34     return x*t;
 35 }
 36 int son[N],sz[N],c[M],deep[N],ans,n,k,res;
 37 int m=1e7;
 38 vector<pii> e[N];
 39 inline int lowbit(int x)
 40 {
 41     return x&(-x);
 42 }
 43 void update(int x,int y)
 44 {
 45     for(int i=x;i<=m;i+=lowbit(i)) c[i]+=y;
 46 }
 47 int query(int x)
 48 {
 49     int tmp=0;
 50     for(int i=x;i;i-=lowbit(i)) tmp+=c[i];
 51     return tmp;
 52 }
 53 void dfs(int u,int pre)
 54 {
 55     sz[u]=1;
 56     son[u]=0;
 57     for(int i=0;i<e[u].size();i++)
 58     {
 59         pii x=e[u][i];
 60         int v=x.fi;
 61         if(v==pre) continue;
 62         deep[v]=deep[u]+x.se;
 63         dfs(v,u);
 64         sz[u]+=sz[v];
 65         if(sz[son[u]]<sz[v]) son[u]=v;
 66     }
 67 }
 68 void cal(int u,int pre,int dis)
 69 {
 70     if(k+1-deep[u]+2*dis<=0) return ;
 71     res+=query(k+1-deep[u]+2*dis);
 72     for(int i=0;i<e[u].size();i++)
 73     {
 74         pii x=e[u][i];
 75         int v=x.fi;
 76         if(v==pre) continue;
 77         cal(v,u,dis);
 78     }
 79 }
 80 void doit(int u,int pre,int x)
 81 {
 82     if(deep[u]+1>m) return ;
 83     update(deep[u]+1,x);
 84     for(int i=0;i<e[u].size();i++)
 85     {
 86         pii t=e[u][i];
 87         int v=t.fi;
 88         if(v==pre) continue;
 89         doit(v,u,x);
 90     }
 91 }
 92 void dfs2(int u,int pre,int flag)
 93 {
 94     for(int i=0;i<e[u].size();i++)
 95     {
 96         pii x=e[u][i];
 97         int v=x.fi;
 98         if(v==pre||v==son[u]) continue;
 99         dfs2(v,u,1);
100     }
101     if(son[u]) dfs2(son[u],u,0);
102     res=query(k+1+deep[u]);
103     update(deep[u]+1,1);
104     for(int i=0;i<e[u].size();i++)
105     {
106         pii x=e[u][i];
107         int v=x.fi;
108         if(v==pre||v==son[u]) continue;
109         cal(v,u,deep[u]);
110         doit(v,u,1);
111     }
112     ans+=res;
113     if(flag) doit(u,pre,-1);
114 }
115 int main()
116 {
117     while(scanf("%d%d",&n,&k)==2&&(n||k))
118     {
119         for(int i=1;i<=n;i++) e[i].clear();
120         for(int i=1;i<n;i++)
121         {
122             int x=read(),y=read(),z=read();
123             e[x].pb(mk(y,z));
124             e[y].pb(mk(x,z));
125         }
126 
127         ans=res=0;
128         dfs(1,0);
129         dfs2(1,0,1);
130         printf("%d
",ans);
131     }
132     return 0;
133 }
View Code

对于点分治,直接排序双指针 统计合法答案即可(由于 dsu on tree 用了一次桶 ,这里就用排序双指针实现,这两种实现方式效率是差不多的,不过排序双指针更好理解和实现)

代码:

  1 #include<functional>
  2 #include<algorithm>
  3 #include<cmath>
  4 #include<cstdio>
  5 #include<cctype>
  6 #include<cstring>
  7 #include<vector>
  8 using namespace std;
  9 typedef long long LL;
 10 typedef unsigned long long uLL;
 11 typedef pair<int,int> pii;
 12 typedef pair<LL,LL> pLL;
 13 typedef pair<double,double> pdd;
 14 const int N=1e4+5;
 15 const int M=1e7+5;
 16 const int inf=0x3f3f3f3f;
 17 const LL mod=998244353;
 18 const double eps=1e-8;
 19 const long double pi=acos(-1.0L);
 20 #define ls (i<<1)
 21 #define rs (i<<1|1)
 22 #define fi first
 23 #define se second
 24 #define pb push_back
 25 //#define eb emplace_back
 26 #define mk make_pair
 27 #define mem(a,b) memset(a,b,sizeof(a))
 28 LL read()
 29 {
 30     LL x=0,t=1;
 31     char ch;
 32     while(!isdigit(ch=getchar())) if(ch=='-') t=-1;
 33     while(isdigit(ch)){ x=10*x+ch-'0'; ch=getchar(); }
 34     return x*t;
 35 
 36 }
 37 int n,k;
 38 vector<pii> e[N];
 39 int res,rt,ans,vis[N],a[N],cnt,sz[N];
 40 void dfs(int u,int pre,int tot)
 41 {
 42     int ma=0;
 43     sz[u]=1;
 44     for(int i=0;i<e[u].size();i++)
 45     {
 46         pii x=e[u][i];
 47         int v=x.fi,w=x.se;
 48         if(v==pre||vis[v]) continue;
 49         dfs(v,u,tot);
 50         sz[u]+=sz[v];
 51         ma=max(ma,sz[v]);
 52     }
 53     ma=max(ma,tot-sz[u]);
 54     if(ma<res) res=ma,rt=u;
 55 }
 56 void doit(int u,int pre,int dis)
 57 {
 58     a[++cnt]=dis;
 59     for(int i=0;i<e[u].size();i++)
 60     {
 61         pii x=e[u][i];
 62         if(!vis[x.fi]&&x.fi!=pre) doit(x.fi,u,dis+x.se);
 63     }
 64 }
 65 int cal(int u,int pre,int dis)
 66 {
 67     cnt=0;
 68     int tmp=0;
 69     doit(u,pre,dis);
 70     sort(a+1,a+cnt+1);
 71     //for(int i=1;i<=cnt;i++) printf("%d%c",a[i],i==cnt?'
':' ');
 72     int l=1,r=cnt;
 73     for(;l<r;l++)
 74     {
 75         while(l<r&&a[l]+a[r]>k) r--;
 76         tmp+=r-l;
 77     }
 78     //printf("tmp = %d
",tmp);
 79     return tmp;
 80 }
 81 void solve(int u)
 82 {
 83     //printf("u = %d
",u);
 84     ans+=cal(u,0,0);
 85     for(int i=0;i<e[u].size();i++)
 86     {
 87         pii x=e[u][i];
 88         int v=x.fi,w=x.se;
 89         if(vis[v]) continue;
 90         ans-=cal(v,u,w);
 91     }
 92     vis[u]=1;
 93     for(int i=0;i<e[u].size();i++)
 94     {
 95         pii x=e[u][i];
 96         int v=x.fi,w=x.se;
 97         if(vis[v]) continue;
 98         res=inf;
 99         dfs(v,u,sz[v]);
100         solve(rt);
101     }
102 }
103 void init()
104 {
105     ans=0;
106     for(int i=1;i<=n;i++) vis[i]=0;
107     for(int i=1;i<=n;i++) e[i].clear();
108 }
109 int main()
110 {
111     while(~scanf("%d%d",&n,&k)&&(n||k))
112     {
113         init();
114         for(int i=1;i<n;i++)
115         {
116             int x=read(),y=read(),z=read();
117             e[x].pb(mk(y,z));
118             e[y].pb(mk(x,z));
119         }
120         res=inf;
121         dfs(1,0,n);
122         solve(rt);
123         printf("%d
",ans);
124     }
125     return 0;
126 }
View Code
原文地址:https://www.cnblogs.com/DeepJay/p/13971419.html