CF1228E Another Filling the Grid

@Pelom  October 11, 2021

题意

一个$n \times n$的矩形,每个格子里可以填$[1,k]$内的整数,求保证每行每列的最小值为$1$的方案数

数据范围:$1 \le n \le 250,1 \le k \le 10^9$

题解

直接计算难以解决,考虑容斥原理
枚举有$i$行$j$列的最小值$> 1$,选择方案有$$\sum_{i=0}^n \sum_{j=0}^n (-1)^{i+j} C_n^i C_n^j$$
乘上每个格子的选择

  • $$(k-1)^{ni+nj-ij}$$这$i$行$j$列必须$> 1$
  • $$k^{n^2-ni-nj+ij}$$其余格子可以任意选

最终为$$\sum_{i=0}^n \sum_{j=0}^n (-1)^{i+j} C_n^i C_n^j (k-1)^{ni+nj-ij} k^{n^2-ni-nj+ij}$$

复杂度$O(n^2 \log{n})$

代码:

#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int mod=1e9+7;
const int N=250+10;
int n,k;
int fac[N],invf[N];
int ans;
inline int Pow(int a,int b){
    int res=1;
    for(;b;b>>=1){
        if(b&1)
            res=1ll*res*a%mod;
        a=1ll*a*a%mod;
    }
    return res;
}
inline int inv(int x){
    return Pow(x,mod-2);
}
inline int C(int n,int m){
    return 1ll*fac[n]*invf[m]%mod*invf[n-m]%mod;
}
int main(){
    scanf("%d%d",&n,&k);
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=1ll*fac[i-1]*i%mod;
    invf[n]=inv(fac[n]);
    for(int i=n-1;~i;i--)
        invf[i]=1ll*invf[i+1]*(i+1)%mod;
    for(int i=0;i<=n;i++)
        for(int j=0;j<=n;j++){
            int p=n*i+n*j-i*j;
            ans=(1ll*ans+1ll*Pow(-1,i+j)*C(n,i)%mod*C(n,j)%mod*Pow(k-1,p)%mod*Pow(k,n*n-p)%mod)%mod;
        }
    ((ans%=mod)+=mod)%=mod;
    printf("%d",ans);
    return 0;
}

优化:

对$$(k-1)^{ni+nj-ij} k^{n^2-ni-nj+ij}$$变形得

$$ \begin{aligned} &(k-1)^{(n-j)i} (k-1)^{nj} k^{(n-i)(n-j)} \\ =&[(k-1)^i k^{n-i}]^{n-j}[(k-1)^n]^j \end{aligned} $$

与前面的$(-1)^j C_n^j$一同,可以应用二项式定理,得$$[(k-1)^i k^{n-i}-(k-1)^n]^n$$
最终为$$\sum_{i=0}^n (-1)^i C_n^i [(k-1)^i k^{n-i}-(k-1)^n]^n$$

复杂度$O(n \log{n})$

代码:

#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int mod=1e9+7;
const int N=250+10;
int n,k;
int fac[N],invf[N];
int ans;
inline int Pow(int a,int b){
    int res=1;
    for(;b;b>>=1){
        if(b&1)
            res=1ll*res*a%mod;
        a=1ll*a*a%mod;
    }
    return res;
}
inline int inv(int x){
    return Pow(x,mod-2);
}
inline int C(int n,int m){
    return 1ll*fac[n]*invf[m]%mod*invf[n-m]%mod;
}
int main(){
    scanf("%d%d",&n,&k);
    fac[0]=1;
    for(int i=1;i<=n;i++)
        fac[i]=1ll*fac[i-1]*i%mod;
    invf[n]=inv(fac[n]);
    for(int i=n-1;~i;i--)
        invf[i]=1ll*invf[i+1]*(i+1)%mod;
    for(int i=0;i<=n;i++)
        ans=(1ll*ans+1ll*Pow(-1,i)*C(n,i)%mod*Pow(((1ll*Pow(k-1,i)*Pow(k,n-i)%mod-Pow(k-1,n))%mod+mod)%mod,n)%mod)%mod;
    ((ans%=mod)+=mod)%=mod;
    printf("%d",ans);
    return 0;
}

添加新评论