[nowcoder5666D]Quadratic Form

首先猜测$sum_{i=1}^{n}b_{i}x_{i}$取到最小值时存在$x_{i}$满足$sum_{i=1}^{n}sum_{j=1}^{n}A_{i,j}x_{i}x_{j}=1$,否则一定可以适当调整某一个$x_{i}$
因此可以使用拉格朗日乘数法,构造函数$F(x_{1},x_{2},...,x_{n},lambda)=sum_{i=1}^{n}b_{i}x_{i}+lambda(sum_{i=1}^{n}sum_{j=1}^{n}A_{i,j}x_{i}x_{j}-1)$,要保证$x_{i}$和$lambda$的偏导数为0,由此可以列出$n+1$个方程:$,egin{cases}forall 1le ile n,b_{i}+2lambdasum_{j=1}^{n}A_{i,j}x_{j}=0\ sum_{i=1}^{n}sum_{j=1}^{n}A_{i,j}x_{i}x_{j}=1 end{cases}$
把A、b和x看成矩阵来表示,即得到$egin{cases}b+2lambda Ax=[0 0...0]\ x^{T}Ax=[1]end{cases}$,分别化简即得$egin{cases}x=-frac{1}{2lambda}A^{-1}b\ x^{T}x=A^{-1}end{cases}$,同时由于$bx^{T}=xb^{T}$,则答案$(xb^{T})^{2}=(bx^{T})(xb^{T})=b(x^{T}x)b^{T}=bA^{-1}b^{T}$,矩阵求逆+矩阵乘法即可,复杂度$o(n^{3})$
 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 205
 4 #define mod 998244353
 5 int n,ans,a[N][N],b[N][N],c[N],d[N];
 6 int ksm(int n,int m){
 7     if (!m)return 1;
 8     int s=ksm(n,m>>1);
 9     s=1LL*s*s%mod;
10     if (m&1)s=1LL*s*n%mod;
11     return s;
12 }
13 int main(){
14     while (scanf("%d",&n)!=EOF){
15         for(int i=1;i<=n;i++)
16             for(int j=1;j<=n;j++)scanf("%d",&a[i][j]);
17         for(int i=1;i<=n;i++)
18             for(int j=1;j<=n;j++)b[i][j]=(i==j);
19         for(int i=1;i<=n;i++){
20             for(int j=i;j<=n;j++)
21                 if (a[j][i]){
22                     for(int k=i;k<=n;k++)swap(a[i][k],a[j][k]);
23                     break;
24                 }
25             int t=ksm(a[i][i],mod-2);
26             for(int j=1;j<=n;j++){
27                 a[i][j]=1LL*t*a[i][j]%mod;
28                 b[i][j]=1LL*t*b[i][j]%mod;
29             }
30             for(int j=i+1;j<=n;j++){
31                 int t=a[j][i];
32                 for(int k=1;k<=n;k++){
33                     a[j][k]=(a[j][k]-1LL*t*a[i][k]%mod+mod)%mod;
34                     b[j][k]=(b[j][k]-1LL*t*b[i][k]%mod+mod)%mod;
35                 }
36             }
37         }
38         for(int i=n;i;i--)
39             for(int j=1;j<i;j++)
40                 for(int k=1;k<=n;k++)b[j][k]=(b[j][k]-1LL*b[i][k]*a[j][i]%mod+mod)%mod;
41         for(int i=1;i<=n;i++)scanf("%d",&c[i]);
42         int ans=0;
43         for(int i=1;i<=n;i++)
44             for(int j=1;j<=n;j++)ans=(ans+1LL*c[i]*c[j]%mod*b[j][i]%mod+mod)%mod;
45         printf("%d
",ans);
46     }
47 }
View Code
原文地址:https://www.cnblogs.com/PYWBKTDA/p/13322856.html