GYM100376E.LinearMapReduce(线段树维护矩阵乘法)

题意:

给出n个2×2矩阵。每次询问给出一个初始向量(k, v)(即1×2的初始矩阵)以及下标a、b,求该初始向量从第a个矩阵依次乘到第b个矩阵的结果(a>b时按下标递减的顺序相乘),答案对1e9+7取模。

题解:

用线段树维护矩阵乘法。矩阵乘法不满足交换律,只满足结合律,这回刻到DNA里了。(询问包含两个方向:a<b时按下标递增的顺序相乘,a>b时按下标递减的顺序相乘,所以要开两棵线段树,分别按这两种顺序合并左右儿子。)

#include<bits/stdc++.h>
using namespace std;
// Capacity bound for the 1-indexed matrix list (a[] and the segment tree).
constexpr int maxn = 100000 + 100;
// Modulus applied to every entry of every matrix product.
constexpr int mod = 1000000007;
using ll = long long;
// 2x2 matrix stored 1-indexed: only m[1..2][1..2] is used by the code below;
// row 0 and column 0 are padding.
struct matrix {
    ll m[2 + 1][2 + 1];
};
// Copies the 2x2 payload (indices [1..2][1..2]) of y into x.
// Row/column 0 of x is not modified. Not called in the visible code.
void ccpy (matrix &x,matrix y) {
    for (int row = 1; row <= 2; ++row) {
        for (int col = 1; col <= 2; ++col) {
            x.m[row][col] = y.m[row][col];
        }
    }
}
// Computes ans = a * b on the 1-indexed 2x2 payload, reducing modulo `mod`
// after every partial product. Row/column 0 of ans is left untouched.
// a and b are taken by value, so ans may alias either operand.
void mul (matrix &ans,matrix a,matrix b) {
    for (int row = 1; row <= 2; ++row) {
        for (int col = 1; col <= 2; ++col) {
            ll acc = 0;
            for (int mid = 1; mid <= 2; ++mid) {
                acc += a.m[row][mid] * b.m[mid][col] % mod;
                acc %= mod;
            }
            ans.m[row][col] = acc;
        }
    }
}
// Segment-tree node covering positions [l, r]; `sum` holds the ordered product
// of the stored leaf matrices in that range.
// segTree[i][0] and segTree[i][1] are the same node position in two trees whose
// children are combined left*right and right*left respectively (see build()).
struct node {
    int l;
    int r;
    matrix sum;
} segTree[maxn<<2][2];
ll K;              // not referenced in the visible code
ll V;              // not referenced in the visible code
int n;             // number of input matrices
int q;             // number of queries
ll a[maxn][3][3];  // a[i][r][c]: entry (r, c) of the i-th input matrix, all 1-indexed
void build (int i,int l,int r,int f) {
    segTree[i][f].l=l;
    segTree[i][f].r=r;
    if (l==r) {
        segTree[i][f].sum.m[1][1]=a[l][1][1];
        segTree[i][f].sum.m[2][1]=a[l][1][2];
        segTree[i][f].sum.m[1][2]=a[l][2][1];
        segTree[i][f].sum.m[2][2]=a[l][2][2];
        return;
    }
    int mid=(l+r)>>1;
    build(i<<1,l,mid,f);
    build(i<<1|1,mid+1,r,f);
    if (f==0)
        mul(segTree[i][f].sum,segTree[i<<1][f].sum,segTree[i<<1|1][f].sum);
    else
        mul(segTree[i][f].sum,segTree[i<<1|1][f].sum,segTree[i<<1][f].sum);
}
// Returns the ordered product of the stored leaf matrices over positions
// [l, r] in tree f, starting at node i (call with i == 1 for the root).
// f == 0 composes left-to-right (T_l * ... * T_r); f == 1 composes
// right-to-left (T_r * ... * T_l), matching how build() combined children.
// Precondition (not checked): build() has run for tree f and [l, r] lies
// inside the built range with l <= r.
matrix query (int i,int l,int r,int f) {
    const node &cur = segTree[i][f];
    // Node range fully inside [l, r]: its precomputed product is the answer.
    if (cur.l >= l && cur.r <= r) return cur.sum;
    const int mid = (cur.l + cur.r) >> 1;
    // The query lies entirely inside one child.
    if (r <= mid) return query(i << 1, l, r, f);
    if (l > mid) return query(i << 1 | 1, l, r, f);
    // The query straddles mid: combine both halves in this tree's order,
    // because matrix multiplication is not commutative.
    const matrix left = query(i << 1, l, r, f);
    const matrix right = query(i << 1 | 1, l, r, f);
    matrix ans{};  // value-initialised so no indeterminate entries are copied out
    if (f == 0) mul(ans, left, right);
    else        mul(ans, right, left);
    return ans;
}
int main () {
    scanf("%d%d",&n,&q);
    for (int i=1;i<=n;i++) for (int j=1;j<=2;j++) for (int k=1;k<=2;k++) scanf("%lld",&a[i][j][k]);
    build(1,1,n,0);
    build(1,1,n,1);
    while (q--) {
        ll k,v,a,b;
        scanf("%lld%lld%lld%lld",&k,&v,&a,&b);
        matrix ans;
        ans.m[1][1]=k;
        ans.m[2][1]=0;
        ans.m[1][2]=v;
        ans.m[2][2]=0;
        matrix q1;
        if (a<b)
            q1=query(1,a,b,0);
        else
            q1=query(1,b,a,1);
        //printf("%lld %lld
%lld %lld

",q1.m[1][1],q1.m[1][2],q1.m[2][1],q1.m[2][2]);
        
        //printf("%lld %lld
%lld %lld

",ans.m[1][1],ans.m[1][2],ans.m[2][1],ans.m[2][2]);
        matrix tt;
        mul(tt,ans,q1);

        printf("%lld %lld
",tt.m[1][1],tt.m[1][2]);
    }
}
原文地址:https://www.cnblogs.com/zhanglichen/p/14585035.html