点分治

1.求<=k点对数,容斥法

/*
求树中距离不超过k的点对数 
暴力枚举两点,lca求的复杂度是O(n^2logn),这样很多次询问都是冗余的
那么选择重心作为根,问题分成两部分,求经过重心的距离<=k的点对+不经过重心的距离<=k的点对 
    先来求第一部分,计算所有点的深度,排序,O(nlogn)可以计算出距离<=k的过重心点对 
    但是这样还不是正确答案,因为还要容斥掉来自同一棵子树的非法点对,那么对这部分再算一次即可
再求第二部分,这部分其实等价于原问题的子问题,所以我们再去重心的每个子树里找重心,和上面一样求
    如果一个点已经被当过重心了,那么给它打个vis标记,之后不再访问 
这样最多递归O(logn) 次,所以总复杂度是O(n*logn*logn)  
*/
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 10005
using namespace std;
struct Edge{int to,nxt,w;}e[N<<1]; 
int head[N],tot,n,k,ans;
void add(int u,int v,int w){
    e[tot].to=v;e[tot].w=w;e[tot].nxt=head[u];head[u]=tot++;
}
int vis[N],size[N],f[N],root,sum;
void getsize(int u,int pre){
    size[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre||vis[v])continue;
        getsize(v,u);
        size[u]+=size[v];
    }
}
void getroot(int u,int pre){
    f[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre||vis[v])continue;
        getroot(v,u);
        f[u]=max(f[u],size[v]);
    }
    f[u]=max(f[u],sum-size[u]);
    if(f[u]<f[root])root=u;
}
int o[N],cnt;
void getdeep(int u,int pre,int dep){
    o[++cnt]=dep;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre||vis[v])continue;
        getdeep(v,u,dep+e[i].w);
    }
}

int calc(int u,int dep){
    cnt=0;
    getdeep(u,u,dep);
    sort(o+1,o+cnt+1);
    int l=1,r=cnt,res=0;
    while(l<r){
        if(o[l]+o[r]<=k)res+=r-l,l++;
        else r--;
    }
    return res;
}
void solve(int u){
    ans+=calc(u,0);vis[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(vis[v])continue;
        ans-=calc(v,e[i].w);
        sum=size[v];root=0;
        getsize(v,0);getroot(v,0);
        solve(root);
    }
}
void init(){
    tot=ans=0;
    memset(head,-1,sizeof head);
    memset(size,0,sizeof size);
    memset(vis,0,sizeof vis);
}
int main(){
    while(cin>>n>>k,n){
        init();
        for(int i=1;i<n;i++){
            int u,v,w;
            scanf("%d%d%d",&u,&v,&w);
            add(u,v,w);
            add(v,u,w);
        }
        f[0]=n;
        sum=n;root=0;
        getsize(1,0);getroot(1,0);
        solve(root);
        cout<<ans<<'
';
    }
}
View Code

2.求=k点对数,容斥法

/*
给定一棵边权树,每次给定一个询问x:长度为x的路径是否存在 
*/
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
using namespace std;
#define N 10005 
struct Eedge{int to,nxt,w;}e[N<<1];
int head[N],tot,n,m;
void add(int u,int v,int w){
    e[tot].to=v;e[tot].w=w;e[tot].nxt=head[u];head[u]=tot++;
}
int ans,x,root,size[N],f[N],sum,vis[N];
void getsize(int u,int pre){
    size[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre || vis[v])continue;
        getsize(v,u);
        size[u]+=size[v];
    }
}
void getroot(int u,int pre){
    f[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre||vis[v])continue;
        getroot(v,u);
        f[u]=max(f[u],size[v]);
    }
    f[u]=max(f[u],sum-size[u]);
    if(f[root]>f[u])root=u;
}
int o[N],cnt;
void getdeep(int u,int pre,int dep){
    o[++cnt]=dep;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(v==pre || vis[v])continue;
        getdeep(v,u,dep+e[i].w);
    }
}
int calc(int u,int dep){
    cnt=0;
    getdeep(u,0,dep);
    sort(o+1,o+1+cnt);
    int res=0,l=1,r=cnt;
    while(l<r){
        if(o[l]+o[r]==x){//这里要特别处理一下 
            if(o[l]==o[r]){
                res+=(r-l+1)*(r-l)/2;
                break;
            }
            int p=l,q=r;
            while(o[p]==o[l])p++;
            while(o[q]==o[r])q--;
            res+=(p-l)*(r-q);
            l=p;r=q;
        }
        else if(o[l]+o[r]<x)l++;
        else r--;
    }
    return res;
}
void solve(int u){
    ans+=calc(u,0);vis[u]=1;
    for(int i=head[u];i!=-1;i=e[i].nxt){
        int v=e[i].to;
        if(vis[v])continue;
        ans-=calc(v,e[i].w);
        sum=size[v];root=0;
        getsize(v,0);getroot(v,0);
        solve(root);
    }
}
void init(){
    tot=0;
    memset(head,-1,sizeof head);
}

int main(){
    while(cin>>n,n){
        init();
        for(int u=1;u<=n;u++){
            int v,w;
            while(scanf("%d",&v),v){
                scanf("%d",&w);
                add(u,v,w);add(v,u,w);
            }
        }
            
        while(scanf("%d",&x),x){
            ans=0;
            memset(vis,0,sizeof vis);
            memset(size,0,sizeof size);
            memset(f,0,sizeof f);
            f[0]=n;
            sum=n;root=0;
            getsize(1,0);getroot(1,0);
            solve(root);
            if(ans)cout<<"AYE
";
            else cout<<"NAY
";    
        }
        puts(".");
    }
}
View Code

3.求%1e6+3=k的对数,开桶记录

#include<bits/stdc++.h>
#pragma comment(linker,"/STACK:102400000,102400000")
#include<vector> 
using namespace std;
#define N 1000005
#define mod 1000003
#define ll long long
#define INF 0x3f3f3f3f
int inv[N];
void init(){
    inv[1]=1;
    for(int i=2;i<mod;i++)
        inv[i]=1ll*(mod-mod/i)*inv[mod%i]%mod,inv[i]=(inv[i]+2*mod)%mod;
}
vector<int>G[N];
int n,k;
ll a[N];
pair<int,int>ans;
int f[N],sum,root,size[N],cnt,vis[N];
struct Node{ll id,val;}o[N];
ll flag[N],tag,id[N];
void update(int a,int b){
    if(a>b)swap(a,b);
    ans=min(ans,make_pair(a,b));
}
void getsize(int u,int pre){
    size[u]=1;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==pre||vis[v])continue;
        getsize(v,u);
        size[u]+=size[v];
    }
}
void getroot(int u,int pre){
    f[u]=1;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(vis[v] || v==pre)continue;
        getroot(v,u);
        f[u]=max(f[u],size[v]);
    }
    f[u]=max(f[u],sum-size[u]);
    if(f[root]>f[u])root=u;
}
void getdeep(int u,int pre,ll dep){
    o[++cnt].val=dep;o[cnt].id=u;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(v==pre||vis[v])continue;
        getdeep(v,u,dep*a[v]%mod);
    }
}
void solve(int u){
    ++tag;vis[u]=1;
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(vis[v])continue;
        cnt=0;
        getdeep(v,0,a[v]);
        for(int j=1;j<=cnt;j++){
            Node cur=o[j];
            if(cur.val*a[u]%mod==k)
                update(u,cur.id);
            ll tmp=1ll*k*inv[cur.val*a[u]%mod]%mod;
            if(flag[tmp]==tag)
                update(id[tmp],cur.id);
        }
        for(int j=1;j<=cnt;j++){
            Node cur=o[j];
            if(flag[cur.val]!=tag||id[cur.val]>cur.id)
                flag[cur.val]=tag,id[cur.val]=cur.id;
        }
    }
    for(int i=0;i<G[u].size();i++){
        int v=G[u][i];
        if(vis[v])continue; 
        sum=size[v];root=0;
        getsize(v,0);getroot(v,0);
        solve(root);
    }
}
void clear(){
    tag=0;
    for(int i=1;i<=n;i++)G[i].clear();
    memset(vis,0,sizeof vis);
    ans.first=ans.second=INF;
    memset(flag,0,sizeof flag);
    memset(id,0,sizeof id);
}
int main(){
    init();
    while(cin>>n>>k){
        clear();
        for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
        for(int i=1;i<n;i++){
            int u,v;scanf("%d%d",&u,&v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        
        f[0]=n;
        sum=n;root=0;
        getsize(1,0);getroot(1,0);
        solve(root);
        if(ans.first==INF)
            puts("No solution");
        else cout<<ans.first<<" "<<ans.second<<'
'; 
    }
}
/*
5 1000001
1000002 1 1 1 2
1 2
2 3
3 4
4 5

*/
View Code
原文地址:https://www.cnblogs.com/zsben991126/p/11755976.html