题意
一个$n \times n$的矩形,每个格子里可以填$[1,k]$内的整数,求保证每行每列的最小值为$1$的方案数
数据范围:$1 \le n \le 250,1 \le k \le 10^9$
题解
直接计算难以解决,考虑容斥原理
枚举有$i$行$j$列的最小值$> 1$,选择方案有$$\sum_{i=0}^n \sum_{j=0}^n (-1)^{i+j} C_n^i C_n^j$$
乘上每个格子的选择
- $$(k-1)^{ni+nj-ij}$$这$i$行$j$列必须$> 1$
- $$k^{n^2-ni-nj+ij}$$其余格子可以任意选
最终为$$\sum_{i=0}^n \sum_{j=0}^n (-1)^{i+j} C_n^i C_n^j (k-1)^{ni+nj-ij} k^{n^2-ni-nj+ij}$$
复杂度$O(n^2 \log{n})$
代码:
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int mod=1e9+7;
const int N=250+10;
int n,k;
int fac[N],invf[N];
int ans;
inline int Pow(int a,int b){
int res=1;
for(;b;b>>=1){
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod;
}
return res;
}
inline int inv(int x){
return Pow(x,mod-2);
}
inline int C(int n,int m){
return 1ll*fac[n]*invf[m]%mod*invf[n-m]%mod;
}
int main(){
scanf("%d%d",&n,&k);
fac[0]=1;
for(int i=1;i<=n;i++)
fac[i]=1ll*fac[i-1]*i%mod;
invf[n]=inv(fac[n]);
for(int i=n-1;~i;i--)
invf[i]=1ll*invf[i+1]*(i+1)%mod;
for(int i=0;i<=n;i++)
for(int j=0;j<=n;j++){
int p=n*i+n*j-i*j;
ans=(1ll*ans+1ll*Pow(-1,i+j)*C(n,i)%mod*C(n,j)%mod*Pow(k-1,p)%mod*Pow(k,n*n-p)%mod)%mod;
}
((ans%=mod)+=mod)%=mod;
printf("%d",ans);
return 0;
}
优化:
对$$(k-1)^{ni+nj-ij} k^{n^2-ni-nj+ij}$$变形得
$$ \begin{aligned} &(k-1)^{(n-j)i} (k-1)^{nj} k^{(n-i)(n-j)} \\ =&[(k-1)^i k^{n-i}]^{n-j}[(k-1)^n]^j \end{aligned} $$
与前面的$(-1)^j C_n^j$一同,可以应用二项式定理,得$$[(k-1)^i k^{n-i}-(k-1)^n]^n$$
最终为$$\sum_{i=0}^n (-1)^i C_n^i [(k-1)^i k^{n-i}-(k-1)^n]^n$$
复杂度$O(n \log{n})$
代码:
#include<iostream>
#include<cstdio>
using namespace std;
typedef long long LL;
const int mod=1e9+7;
const int N=250+10;
int n,k;
int fac[N],invf[N];
int ans;
inline int Pow(int a,int b){
int res=1;
for(;b;b>>=1){
if(b&1)
res=1ll*res*a%mod;
a=1ll*a*a%mod;
}
return res;
}
inline int inv(int x){
return Pow(x,mod-2);
}
inline int C(int n,int m){
return 1ll*fac[n]*invf[m]%mod*invf[n-m]%mod;
}
int main(){
scanf("%d%d",&n,&k);
fac[0]=1;
for(int i=1;i<=n;i++)
fac[i]=1ll*fac[i-1]*i%mod;
invf[n]=inv(fac[n]);
for(int i=n-1;~i;i--)
invf[i]=1ll*invf[i+1]*(i+1)%mod;
for(int i=0;i<=n;i++)
ans=(1ll*ans+1ll*Pow(-1,i)*C(n,i)%mod*Pow(((1ll*Pow(k-1,i)*Pow(k,n-i)%mod-Pow(k-1,n))%mod+mod)%mod,n)%mod)%mod;
((ans%=mod)+=mod)%=mod;
printf("%d",ans);
return 0;
}