<JZOJ5913>林下风气

快乐dp

反正考场写挂

#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#define MOD 19260817
#define LL long long 
template <class T>inline void read(T &X)
{
    X=0;int W=0;char ch=0;
    while(!isdigit(ch))W|=ch=='-',ch=getchar();
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    X=W?-X:X;return;
}
int a[4000],n,k,head[4000],f[4000][2][2];//j最小值k最大值 
int cnt=0,mn,mx,ans;
bool boo[4000];
struct node{int to,next;}edge[8000];
void add(int u,int v)
{
    edge[++cnt].to=v;
    edge[cnt].next=head[u];
    head[u]=cnt;
}
void dfs(int x,int fa)
{
    if(mn==mx && a[x]==mn) f[x][1][1]=1; else
        if(a[x]==mx && a[x]==mn) f[x][1][0]=f[x][0][1]=f[x][1][1]=1; else
            if(a[x]==mx) f[x][0][1]=1; else
                if(a[x]==mn) f[x][1][0]=1; else f[x][0][0]=1;
    for(int i=head[x];i;i=edge[i].next)
    {
        int to=edge[i].to;
        if(to!=fa)
        {
            dfs(to,x);
            f[x][1][1]=f[x][1][1]* ( (LL)f[to][0][0] +f[to][0][1]+f[to][1][0]+f[to][1][1]+1 )%MOD;
            f[x][1][1]=(f[x][1][1]+ (LL)f[x][1][0] * (f[to][0][1]+f[to][1][1]) )%MOD;
            f[x][1][1]=(f[x][1][1]+ (LL)f[x][0][1] * (f[to][1][0]+f[to][1][1]) )%MOD;
            f[x][1][1]=(f[x][1][1]+ (LL)f[x][0][0] *f[to][1][1])%MOD;
            
            f[x][1][0]=(LL)f[x][1][0] * (f[to][1][0]+f[to][0][0]+1)%MOD;
            f[x][1][0]=(f[x][1][0]+ (LL)f[x][0][0] *f[to][1][0])%MOD;
            
            f[x][0][1]=(LL)f[x][0][1] * (f[to][0][1]+f[to][0][0]+1)%MOD;
            f[x][0][1]=(f[x][0][1]+ (LL)f[x][0][0] *f[to][0][1])%MOD;
            
            f[x][0][0]=(LL)f[x][0][0] *(f[to][0][0]+1)%MOD;
        }
    }
    if(a[x]<mn || a[x]>mx) f[x][0][0]=f[x][0][1]=f[x][1][0]=f[x][1][1]=0;
}
int main()
{
//    freopen("lkf.in","r",stdin);
//    freopen("lkf.out","w",stdout);
    read(n),read(k);
    for(int i=1;i<=n;i++)
        read(a[i]),boo[a[i]]=true;
    for(int i=1;i<n;i++)
    {
        int x,y ;
        read(x),read(y);
        add(x,y),add(y,x);
    }
    for(int i=0;i+k<=n;i++)
        if(boo[i] && boo[i+k])//差为k a[i]都小于n! 
        {
            mn=i,mx=i+k;//差为k 设为min&max 
            memset(f,0,sizeof(f));
            dfs(1,0);
            for(int j=1;j<=n;j++) ans=(ans+f[j][1][1])%MOD;
        }
    printf("%d",ans);
    return 0;
}
原文地址:https://www.cnblogs.com/pile8852/p/9818654.html