一个长度为n的字符串,Ti代表其以i开始的后缀,求
![Rendered by QuickLaTeX.com \[\sum_{1 \leq i < j \leq n}len(T_i)+len(T_j)-2*lcp(T_i,T_j)\]](https://nocriz.com/wp-content/ql-cache/quicklatex.com-064778456d28a0726dae60ef8c4a0ce8_l3.png)
其中len为长度,lcp为最长公共前缀。
这道题是后缀数组的一道经典题。
首先将答案转化为下面的格式:
![Rendered by QuickLaTeX.com \[\sum_{1 \leq i < j \leq n}i+j -\sum_{1 \leq i < j \leq n}2*lcp(T_i,T_j)\]](https://nocriz.com/wp-content/ql-cache/quicklatex.com-a9f571e7b4064afebe86a2f0cfd743ab_l3.png)
进行了这一步转化之后,我们只需要求出字符串每对后缀的最长前缀即可。那么,这个又该如何求呢?我们知道,两对后缀的lcp就是其在高度数组对应区间中的最小值。这个结论在很多介绍后缀数组的博客上都有证明,这里不做说明。这样的话,我们只需要统计出高度数组中每个数在哪一段区间中是最小值即可。这个问题可以使用单调栈得以解决。
这样又带来了一个去重的问题,就是一个区间内的最小值可能有若干个。这个问题有一个简便的解法:只需将每个数所做为最小值的区间左端点设为严格大于它的第一个数,而右端点是不严格大于它的第一个数就行了。
下附代码
#include <cstdio>
#include <cstring>
#include <iostream>
#define N 1000010
#define rep(i,l,r) for(int i=l;i<=r;i++)
#define rrep(i,r,l) for(int i=r;i>=l;i--)
using namespace std;
typedef long long ll;
int t1[N],t2[N],sa[N],h[N],rk[N],c[10*N],a[N],m,n;
char s[N];
void calcsa(int n,int m){
int *x = t1,*y = t2,p = 0,f = 0;
memset(c,0,4*m+40);
rep(i,1,n)c[x[i]=a[i]]++;
rep(i,1,m)c[i]+=c[i-1];
rrep(i,n,1)sa[c[x[i]]--]=i;
for(int i=1;i<=n&&p<=n;i*=2){
p=0;
rep(j,n-i+1,n)y[++p]=j;
rep(j,1,n)if(sa[j]>i)y[++p]=sa[j]-i;
memset(c,0,4*m+40);
rep(j,1,n)c[x[y[j]]]++;
rep(j,1,m)c[j]+=c[j-1];
rrep(j,n,1)sa[c[x[y[j]]]--]=y[j];
swap(x,y);x[sa[1]]=1;p=2;
rep(j,2,n)x[sa[j]]=y[sa[j]]==y[sa[j-1]]&&y[sa[j]+i]==y[sa[j-1]+i]?p-1:p++;
m=p;
}
rep(i,1,n)rk[sa[i]]=i;
rep(i,1,n){
int j=sa[rk[i]-1];
if(f)f--;while(a[i+f]==a[j+f])f++;
h[rk[i]]=f;
}
}
int st[N],stn = 0,l[N],r[N];
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
cin>>s;
int len=strlen(s);
for(int i=0;i<len;i++)a[++n]=s[i]-' ';
calcsa(n,10000);
ll ans = 0;
for(int i=1;i<=n;i++) ans+=1ll*i*(i-1)+1ll*i*(i-1)/2;
h[1] = h[n+1] = -1;
st[0] = 1;stn=1;
for(int i=2;i<=n;i++){
while(h[st[stn-1]]>h[i])stn--;
l[i] = st[stn-1];
st[stn] = i;
stn++;
}
st[0] = n+1;stn=1;
for(int i=n;i>=2;i--){
while(h[st[stn-1]]>=h[i])stn--;
r[i] = st[stn-1];
st[stn] = i;
stn++;
}
for(int i=2;i<=n;i++)ans-=2ll*h[i]*(r[i]-i)*(i-l[i]);
cout<<ans<<endl;
return 0;
}
发表回复