不难发现,答案可以分成两种:
整段的
中间删一点,两端凑一起的
考虑分开计算贡献。
如果 \(s\) 中存在子串等于 \(t\),那么自然,可以删左边的任何一段,或者右边的任何一段。
不妨设子串开始的位置为 \(i\),于是其贡献为 \((1 + 2 + \cdots + i - 1) + (1 + 2 + \cdots +(|s| - i - |t| + 1))\)。
接下来考虑中间删一点,两端凑一起的情况。
令 \(f_i\) 表示 \(s\) 从 \(i\) 开始与 \(t\) 的最长相同前缀的长度,\(g_i\) 表示 \(s\) 从 \(i\) 向前与 \(t\) 的最长相同后缀的长度。
NOTICE:由于需要排除第一种情况,所以对于 \(f, g\),都需要对于 \(|t| - 1\) 取 \(\min\)。
这部分可以通过哈希和二分完成(或者 Z 函数也行)
于是需要考虑 \(f, g\) 如何相互贡献。
不难发现,对于两个端点 \(i \le j\) 可以做出贡献,需要满足:
\(i + |t| - 1 \lt j\)。考虑中间必须删一点,所以必须严格小于,这样才不会重叠或者接在一起。
\(f_i + g_j \ge |t|\)。这样才能凑出完整的目标串。
那么其最终的贡献为 \(f_i + g_j - |t| + 1\)。
于是可以得到表达式:
不难发现可以倒着扫一遍,然后利用树状数组求和即可。
考场上以防万一,我用的双哈希……但好像有点多余。
#include <iostream>
#include <algorithm>
#include <string>
#include <cmath>
using namespace std;
const int N = 4e5 + 7, BASE = 131, mod = 1331;
string s, t;
using hI = unsigned long long;
using hP = unsigned int;
hI sha[N], tha[N];
hI sha2[N], tha2[N];
hI ofs[N], ofs2[N];
hI shash(int l, int r) {
hI sha1 = sha[r] - sha[l - 1] * ofs[r - l + 1];
sha1 += (sha2[r] + mod - sha2[l - 1] * ofs2[r - l + 1] % mod) % mod;
return sha1;
}
hI thash(int l, int r) {
hI tha1 = tha[r] - tha[l - 1] * ofs[r - l + 1];
tha1 += (tha2[r] + mod - tha2[l - 1] * ofs2[r - l + 1] % mod) % mod;
return tha1;
}
int f[N], g[N];
#define lowbit(i) (i & -i)
// 这是倒着的树状数组!
struct BIT {
long long b[N];
void update(int i, int x) {
for (; i; i -= lowbit(i)) b[i] += x;
}
long long query(int i) {
long long r = 0;
for(; i < N; i += lowbit(i)) r += b[i];
return r;
}
} cnt, sum;
long long get(long long x) {
return (1 + x) * x / 2;
}
int main() {
cin >> s >> t;
int n = s.size(), m = t.size();
for (int i = 1; i <= n; ++i) {
sha[i] = sha[i - 1] * BASE + s[i - 1] - 2;
sha2[i] = (sha2[i - 1] * 17 % mod + s[i - 1] - 2) % mod;
}
for(int i = 1; i <= m; ++i) {
tha[i] = tha[i - 1] * BASE + t[i - 1] - 2;
tha2[i] = (tha2[i - 1] * 17 % mod + t[i - 1] - 2) % mod;
}
ofs[0] = ofs2[0] = 1;
for (int i = 1, ie = max(s.size(), t.size()); i <= ie; ++i) {
ofs[i] = ofs[i - 1] * BASE;
ofs2[i] = ofs2[i - 1] * 17 % mod;
}
long long ans = 0;
// 简单倍增
int W = 1 << ((int)log2(t.size()) + 1);
for (int i = 1; i <= n; ++i) {
f[i] = g[i] = -1;
for (int w = W; w; w >>= 1) {
if (i + f[i] + w - 1 <= n && 1 + f[i] + w - 1 <= m
&& shash(i, i + f[i] + w - 1) == thash(1, 1 + f[i] + w - 1))
f[i] += w;
if (i - g[i] - w + 1 > 0 && m - g[i] - w + 1 > 0
&& shash(i - g[i] - w + 1, i) == thash(m - g[i] - w + 1, m))
g[i] += w;
}
if (f[i] >= m) ans += get(i - 1) + get(n - i - m + 1), --f[i];
}
for (int i = n; i > m; --i) {
if (g[i] >= m) --g[i];
cnt.update(g[i], 1);
sum.update(g[i], g[i]);
ans += sum.query(m - f[i - m]) + (f[i - m] - m + 1) * cnt.query(m - f[i - m]);
}
cout << ans << '\n';
}