CF1073E Segment Sum 自闭了

CF1073E Segment Sum

题意翻译

给定(K,L,R),求(L)~(R)之间最多不包含超过(K)个数码的数的和。

(K<=10,L,R<=1e18)


我 写 晕 了

我 自 闭 了

根本不知道自己在写什么????

告辞。。。


错误Code:

#include <cstdio>
#define ll long long
const ll mod=998244353;
ll dp[2][20][1<<10],cnt[2][20][1<<10],po[20],l,r;
int k;
void init()
{
    po[0]=1;
    for(int i=1;i<=18;i++) po[i]=po[i-1]*10%mod;
    cnt[0][0][0]=1;
    for(int i=1;i<=18;i++)
        for(int s=0;s<1<<10;s++)
        {
            for(int j=1;j<=9;j++)
                if(s>>j&1)
                {
                    int t=(1<<j)^s;
                    (cnt[1][i][s]+=(cnt[0][i-1][s]+cnt[1][i-1][s]
                                   +cnt[0][i-1][t]+cnt[1][i-1][t]))%=mod;

                }
            int t=s^1;
            if(s&1)
                (cnt[0][i][s]+=(cnt[0][i-1][s]+cnt[1][i-1][s]
                               +cnt[0][i-1][t]+cnt[1][i-1][t]))%=mod;
        }
    for(int i=1;i<=18;i++)
        for(int s=0;s<1<<10;s++)
        {
            for(int j=1;j<=9;j++)
                if(s>>j&1)
                {
                    int t=(1<<j)^s;
                    ll tmp0=(cnt[0][i-1][s]+cnt[1][i-1][s]
                            +cnt[0][i-1][t]+cnt[1][i-1][t])%mod;
                    ll tmp1=(dp[0][i-1][s]+dp[1][i-1][s]
                            +dp[0][i-1][t]+dp[1][i-1][t])%mod;
                    (dp[1][i][s]+=(tmp1+j*po[i-1]%mod*tmp0))%=mod;

                }
            if(s&1)
            {
                int t=s^1;
                ll tmp1=(dp[0][i-1][s]+dp[1][i-1][s]
                        +dp[0][i-1][t]+dp[1][i-1][t])%mod;
                dp[0][i][s]=tmp1;
            }
        }
}
int bit[20];
ll solve(ll d)
{
    int rt=0;
    for(ll i=d;i;i/=10) bit[++rt]=i%10;
    int cho=0;//高位已选状态
    ll res=0,ans=0;//高位剩余
    for(int i=rt;i;i--)
    {
        for(int s=0;s<1<<10;s++)
        {
            int ct=0;
            for(int t=s;t;t>>=1) ct+=t&1;
            if(ct>k+1) continue;
            if(s&1)
            {
                int t=s^1;
                ll tmp1=(dp[0][i-1][s]+dp[1][i-1][s]
                        +dp[0][i-1][t]+dp[1][i-1][t])%mod;
                ans+=tmp1;
            }
            if(ct>k) continue;
            for(int j=1;j<bit[i]+(i==1);j++)//高位选填
                if((s>>j&1)&&((s|cho)==s))
                {
                    int t=(1<<j)^s;
                    ll tmp0=(cnt[0][i-1][s]+cnt[1][i-1][s]
                            +cnt[0][i-1][t]+cnt[1][i-1][t])%mod;
                    ll tmp1=(dp[0][i-1][s]+dp[1][i-1][s]
                            +dp[0][i-1][t]+dp[1][i-1][t])%mod;
                    (ans+=tmp1+tmp0*(j+res)%mod*po[i-1]%mod)%=mod;
                }
        }
        res=(res+bit[i])*10%mod;
        cho|=1<<bit[i];
    }
    return ans;
}
int main()
{
    scanf("%lld%lld%d",&l,&r,&k);
    init();
    printf("%lld
",((solve(r)-solve(l-1))%mod+mod)%mod);
    return 0;
}

updata2019.2.9

写了好几个月,终于肝出来了

Code:

#include <cstdio>
#include <cstring>
#define ll long long
const int mod=998244353;
inline int add(int a,int b){return a+b>=mod?a+b-mod:a+b;}
#define mul(a,b) (1ll*(a)*(b)%mod)
int po[20],bit[20],len,k;
struct node
{
    int val,cnt;
    node(){}
    node(int v,int c){val=v,cnt=c;}
    node friend operator +(node a,node b){return node(add(a.val,b.val),add(a.cnt,b.cnt));}
}dp[20][1<<10];
node dfs(int pos,int sta,int lead,int lim)//前导0和最高位限制
{
	int cnt=0;
	for(int i=0;i<10;i++) cnt+=sta>>i&1;
	if(cnt>pos) return node(0,0);
	if(!pos) return node(0,1);
	if(!lim&&!lead&&~dp[pos][sta].val) return dp[pos][sta];
	node ret=node(0,0),bee;
	if(lead) ret=ret+dfs(pos-1,sta,lead,lim&&!bit[pos]);
	else if(sta&1) ret=ret+dfs(pos-1,sta,lead,lim&&!bit[pos])+dfs(pos-1,sta^1,lead,lim&&!bit[pos]);
	for(int i=1,up=lim?bit[pos]:9;i<=up;i++)
		if(sta>>i&1)
		{
		    bee=dfs(pos-1,sta,0,lim&&i==up)+dfs(pos-1,sta^(1<<i),0,lim&&i==up);
		    ret=ret+bee;
            ret.val=add(ret.val,mul(bee.cnt,mul(i,po[pos-1])));
		}
	return !lim&&!lead?dp[pos][sta]=ret:ret;
}
int cal(ll x)
{
	len=0;while(x) bit[++len]=x%10,x/=10;
	memset(dp,-1,sizeof dp);int ans=0;
	for(int s=0;s<1<<10;s++)
	{
		int cnt=0;
		for(int i=0;i<10;i++) cnt+=s>>i&1;
		if(cnt<=k) ans=add(ans,dfs(len,s,1,1).val);
	}
	return ans;
}
int main()
{
	ll l,r;
	scanf("%lld%lld%d",&l,&r,&k);
	po[0]=1;for(int i=1;i<=18;i++) po[i]=mul(po[i-1],10);
	printf("%d
",add(cal(r),mod-cal(l-1)));
	return 0;
}

原文地址:https://www.cnblogs.com/butterflydew/p/9883665.html