Codeforces 1042E — Vasya and Magic Matrix

逆推期望

#include<bits/stdc++.h>
using namespace std;
// 64-bit integer shorthand. A type alias (rather than a token-replacing macro)
// respects scope and type rules.
using ll = long long;
// push_back shorthand, kept for compatibility with any existing call sites.
#define pb(x) push_back(x)
// Array bound sized for n, m <= 1000 (presumably the problem limits) plus slack.
// Written as an exact integer instead of the floating literal 1e3+5.
constexpr int maxn = 1005;
// Prime modulus; all arithmetic below is done modulo this value.
constexpr ll mod = 998244353;
// A single matrix cell: its row x, its column y (main fills both 1-based) and
// the value stored there. Ordering compares values only, so after sorting,
// cells with equal values form one contiguous run.
struct node
{
    ll x;
    ll y;
    ll val;

    bool operator<(const node &rhs) const { return rhs.val > val; }
};
// Every cell of the input matrix; main() sorts this array by value.
node a[maxn * maxn];

// Running sums (reduced modulo `mod`) over all cells already finished, i.e.
// all cells whose value is strictly below the group currently being evaluated.
ll sumr;   // sum of row indices
ll sumr2;  // sum of squared row indices
ll sumc;   // sum of column indices
ll sumc2;  // sum of squared column indices
ll sumdp;  // sum of the finished cells' dp values

// dp[r][c]: expected final score, modulo `mod`, for a chip starting at (r, c).
ll dp[maxn][maxn];
// Modular product (a * b) % mod. Operands are not normalised first, so a
// negative operand gives a result in (-mod, 0]; the caller passes such values
// through add() to bring them back into [0, mod). Assumes a * b itself does
// not overflow 64 bits.
ll mul(ll a, ll b)
{
    const ll product = a * b;
    return product % mod;
}
// Binary (square-and-multiply) exponentiation: a^b modulo `mod`.
// Returns 1 when b <= 0.
ll ksm(ll a, ll b)
{
    ll result = 1;
    for (; b > 0; b >>= 1, a = mul(a, a))
    {
        if (b & 1)
            result = mul(result, a);
    }
    return result;
}
// Modular sum: returns a + b reduced into the canonical range [0, mod).
// Operands may be negative (e.g. results of mul() with a negative factor).
ll add(ll a, ll b)
{
    ll sum = (a + b) % mod;
    if (sum < 0)
        sum += mod;
    return sum;
}
// Modular inverse via Fermat's little theorem: a^(mod-2), valid because `mod`
// is prime. The assertion verifies a * a^{-1} == 1 and fails, for example,
// when a is a multiple of mod (no inverse exists).
ll inv(ll a)
{
    const ll inverse = ksm(a, mod - 2);
    assert(mul(a, inverse) == 1);
    return inverse;
}
int main()
{
    ll n,m;
    ll i,j,k;
    ll len;
    scanf("%lld %lld",&n,&m);
    len = 0;
    for(i=1;i<=n;++i)
    {
        for(j=1;j<=m;++j)
        {
           a[len].x = i;
           a[len].y = j;
           scanf("%lld",&a[len].val);
           len ++;
        }
    }
    sort(a,a+len);
    //for(i=0;i<len;++i)
      //  printf("%lld %lld %lld
",a[i].x,a[i].y,a[i].val);
    memset(dp,0,sizeof(dp));
    ll l,r;
    l = 0;
    sumr = sumr2 = sumc2 = sumc = sumdp = 0;
    while(l < n*m)
    {
        r = l;
        while(a[r].val == a[l].val && r < n*m) r ++;
        //cout << l << "  " << r << endl;
        ll il = -1;
        if(l != 0) il = inv(l);

        for(i=l;i<r;++i)
        {
            ll rr,cc;
            rr = a[i].x; cc = a[i].y;
            if(il == -1)
            {
                dp[rr][cc] = 0;
                continue;
            }
            dp[rr][cc] = add(dp[rr][cc],mul(sumdp,il));
            dp[rr][cc] = add(dp[rr][cc],mul(rr,rr));
            dp[rr][cc] = add(dp[rr][cc],mul(cc,cc));
            dp[rr][cc] = add(dp[rr][cc],mul(sumr2,il));
            dp[rr][cc] = add(dp[rr][cc],mul(sumc2,il));
            dp[rr][cc] = add(dp[rr][cc],mul(mul(-2*rr,sumr),il));
            dp[rr][cc] = add(dp[rr][cc],mul(mul(-2*cc,sumc),il));

        }
        for(i = l; i < r; ++i)
        {
            int rr,cc;
            rr = a[i].x; cc = a[i].y;
            sumdp = add(sumdp,dp[rr][cc]);
            sumr2 = add(sumr2,mul(rr,rr));
            sumc2 = add(sumc2,mul(cc,cc));
            sumr = add(sumr,rr);
            sumc = add(sumc,cc);
        }
        l = r;
    }
    ll c,b;
    scanf("%lld %lld",&c,&b);
   // cout << endl;
    cout << dp[c][b] << endl;
}
/*
Sample 1 (expected output: 2)
1 4
1 1 2 1
1 3

Sample 2 (expected output: 665496238, i.e. 8/3 mod 998244353)
2 3
1 5 7
2 3 1
1 2
*/

  

这题是真的痛苦   

思路:把所有格子按 val 从小到大排序,每个格子的期望只由 val 严格小于它的格子转移而来,按顺序递推即可;val 相同的格子之间不能互相跳,所以同一组要全部算完之后再加入前缀和。

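至于为什么要维护 x、x² 等前缀和,把转移式展开就能看出来:设当前格子为 $(r,c)$,比它小的格子共有 $k$ 个,坐标为 $(r_i,c_i)$、期望为 $E_i$,则
$$E(r,c)=\frac{1}{k}\sum_{i=1}^{k}\left(E_i+(r-r_i)^2+(c-c_i)^2\right)=r^2+c^2+\frac{\sum E_i+\sum r_i^2+\sum c_i^2-2r\sum r_i-2c\sum c_i}{k},$$
所以只需维护 $\sum E_i$、$\sum r_i$、$\sum r_i^2$、$\sum c_i$、$\sum c_i^2$ 这五个前缀和。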

原文地址:https://www.cnblogs.com/mltang/p/9692735.html