【BZOJ3213】抛硬币(ZJOI2013)-期望DP+KMP+高精度

测试地址:抛硬币
做法:本题需要用到期望DP+KMP+高精度。
很容易想到,先用KMP求出信息,然后从一个点i,要么成功匹配第i+1个字符到达点i+1,要么匹配失败到达某个点fail(i+1)。于是令f(i)为生成出长度为i的前缀所需要的期望步数,有状态转移方程:
f(i)=f(i1)+1+(1pi)(f(i)f(fail(i)))
其中pi为抛硬币得到第i个字符的概率,而f(i)f(fail(i))就表示从fail(i)走到i的期望步数。移项后得到:
f(i)=f(i1)f(fail(i))+1pi
然后因为题目要求精确解,而答案中分数的分子和分母可能非常大,所以要用高精度(真是毒瘤……)。而在约分时,如果懒得写更相减损术,可以用2100进行试除,因为不难发现题目中所有涉及到的数都是由100以内的数凑出来的。
(据说有一种截然不同的DP思路,和这个DP可能可以得到不同的结果,但是这两种DP都能AC,到底哪种是对的呢……)
以下是本人代码:

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll w=100000000ll;
char s[1010];
int n,nxt[1010],fail[1010];

struct hd
{
    int siz;
    ll s[2010];

    void pushup()
    {
        for(int i=0;i<siz;i++)
            if (s[i]>=w)
            {
                s[i+1]+=s[i]/w;
                s[i]%=w;
                if (i==siz-1) siz++;
            }
    }

    void pushdown()
    {
        for(int i=siz-1;i>=1;i--)
        {
            if (s[i]) break;
            siz--;
        }
    }

    void output()
    {
        printf("%lld",s[siz-1]);
        for(int i=siz-2;i>=0;i--)
            printf("%08lld",s[i]);
    }
};

hd operator + (hd a,hd b)
{
    hd s;
    s=a;
    s.siz=max(a.siz,b.siz);
    for(int i=0;i<s.siz;i++)
        s.s[i]=a.s[i]+b.s[i];
    s.pushup();
    return s;
}

hd operator - (hd a,hd b)
{
    hd s;
    s=a;
    s.siz=max(a.siz,b.siz);
    for(int i=0;i<s.siz;i++)
        s.s[i]=a.s[i]-b.s[i];
    for(int i=0;i<s.siz;i++)
        if (s.s[i]<0) s.s[i]+=w,s.s[i+1]--;
    s.pushdown();
    return s;
}

hd operator * (hd a,hd b)
{
    hd s;
    memset(s.s,0,sizeof(s.s));
    for(int i=0;i<a.siz;i++)
        for(int j=0;j<b.siz;j++)
            s.s[i+j]+=a.s[i]*b.s[j];
    s.siz=a.siz+b.siz;
    s.pushup();
    s.pushdown();
    return s;
}

hd operator * (hd a,ll b)
{
    hd s;
    s=a;
    for(int i=0;i<s.siz;i++)
        s.s[i]*=b;
    s.pushup();
    return s;
}

hd operator / (hd a,ll b)
{
    hd s;
    memset(s.s,0,sizeof(s.s));
    s.siz=a.siz;
    ll now=0;
    for(int i=a.siz-1;i>=0;i--)
    {
        now=now*w+a.s[i];
        if (now>=b)
        {
            s.s[i]=now/b;
            now%=b;
        }
    }
    s.pushdown();
    return s;
}

int operator % (hd a,ll b)
{
    ll now=0;
    for(int i=a.siz-1;i>=0;i--)
    {
        now=now*w+a.s[i];
        if (now>=b) now%=b;
    }
    return now;
}

struct fraction
{
    hd a,b;

    void simplify()
    {
        for(ll i=2;i<=100;i++)
            while (a%i==0&&b%i==0)
            {
                a=a/i;
                b=b/i;
            }
    }
}p[2],f[2010];

fraction operator + (fraction a,fraction b)
{
    fraction s;
    s.a=a.a*b.b+a.b*b.a,s.b=a.b*b.b;
    s.simplify();
    return s;
}

fraction operator - (fraction a,fraction b)
{
    fraction s;
    s.a=a.a*b.b-a.b*b.a,s.b=a.b*b.b;
    s.simplify();
    return s;
}

fraction operator * (fraction a,fraction b)
{
    a.a=a.a*b.a;
    a.b=a.b*b.b;
    a.simplify();
    return a;
}

fraction operator / (fraction a,fraction b)
{
    a.a=a.a*b.b;
    a.b=a.b*b.a;
    a.simplify();
    return a;
}

fraction operator + (fraction a,ll b)
{
    a.a=a.a+(a.b*b);
    a.simplify();
    return a;
}

void kmp()
{
    int now=0;
    nxt[0]=0;
    fail[1]=0;
    for(int i=1;i<=n;i++)
    {
        now=nxt[i-1];
        while(now&&s[now+1]!=s[i]) now=nxt[now];
        if (i>1&&s[now+1]==s[i]) nxt[i]=now+1,now++;
        else nxt[i]=0;
        if (i<n)
        {
            while(now&&s[now+1]==s[i+1]) now=nxt[now];
            if (s[now+1]!=s[i+1]) fail[i+1]=now+1;
            else fail[i+1]=0;
        }
    }
}

int main()
{
    p[0].a.siz=p[0].b.siz=1;
    scanf("%d%d",&p[0].a.s[0],&p[0].b.s[0]);
    p[0].a.pushup(),p[0].b.pushup();
    p[1].b=p[0].b;
    p[1].a=p[0].b-p[0].a;

    s[0]='#';
    scanf("%s",s+1);
    n=strlen(s)-1;
    kmp();

    memset(f[0].a.s,0,sizeof(f[0].a.s));
    memset(f[0].b.s,0,sizeof(f[0].b.s));
    f[0].a.siz=f[0].b.siz=1;
    f[0].b.s[0]=1;
    for(int i=1;i<=n;i++)
    {
        bool t=(s[i]=='T');
        f[i]=(f[i-1]-(p[t^1]*f[fail[i]])+1)/p[t];
        f[i].simplify();
    }

    f[n].a.output();
    printf("/");
    f[n].b.output();

    return 0;
}
原文地址:https://www.cnblogs.com/Maxwei-wzj/p/9793338.html