虽然已经有官方题解,这个算是个稍微详细那么一点的题解吧
题意:
给出一个串S,问最多可以选出多少个子串使得选出的子串两两不同构,同构的定义是,两个字符串A和B,通过一个映射函数f,让B的每个字符c通过映射函数c = f(c)后得到B',如果A == B',则A,B同构。例如ab和bc同构,映射函数是f(a) = b,f(b) = c,f(c) = a
思路:这题其实就是算S有多少个不同构的子串。先看个小例子,aab、bbc、aac,怎么算这3个串中有多少个不同构的串?这题字符只有abc三种,则映射函数只有3! = 6种(即将abc做全排列 后与abc一一对应即可得到映射函数),那我们可以按照映射函数,将每个串通过每个映射后得到的串放在一起然后去重,得到aab、aac、bba、bbc、cca、ccb,6个串,这样,得到的结果会算多6次,因为映射函数6种,而有个特殊情况是单字符的时候,aa、bb、cc,这时候只会算多3次,那答案就是(不同子串数量 + 3 × 单一字符的串数量) / 6
解法:个人用SAM计算不同子串,先将S通过映射函数得到6个串,S1、S2、...、S6,拼接得到T = S1:S2:...:S6,其中':'是除['a', 'z']字符以外的任意字符,接着跑个SAM,那SAM如何算多个串的不同子串个数呢(也就是T中不含':'的不同子串个数),我们知道某个串x是T的子串的话,通过起始状态跑x能够跑完,也就是说T的所有子串都可以通过起始状态跑到,那做个dp,dp[i]表示i状态与i的所有后续状态的路径总数(注意不跑拼接字符':'那条边),记忆化一下,得到不同子串个数。单个字符的串的数量则通过跑单个字符'a'一直跑,跑到最大长度时长度即个数
更新代码:
#include<iostream> #include<cstring> #include<vector> #include<cstdio> #include<algorithm> #define rep(i,e) for(int i=0;i<(e);i++) #define PB push_back #define scd(a) scanf("%d",&a) using namespace std; typedef long long ll; const int N = 1e6+10; int idx; int maxlen[N], minlen[N], trans[N][27], slink[N]; int new_state(int _maxlen, int _minlen, int* _trans, int _slink) { maxlen[idx] = _maxlen; minlen[idx] = _minlen; for(int i = 0; i < 27; i++) { if(_trans == NULL) trans[idx][i] = -1; else trans[idx][i] = _trans[i]; } slink[idx] = _slink; return idx++; } int add_char(char ch, int u) { int c = ch - 'a'; int z = new_state(maxlen[u] + 1, -1, NULL, -1); while(u != -1 && trans[u][c] == -1) { trans[u][c] = z; u = slink[u]; } if(u == -1) { minlen[z] = 1; slink[z] = 0; return z; } int x = trans[u][c]; if(maxlen[u] + 1 == maxlen[x]) { minlen[z] = maxlen[x] + 1; slink[z] = x; return z; } int y = new_state(maxlen[u] + 1, -1, trans[x], slink[x]); minlen[z] = minlen[x] = maxlen[y] + 1; slink[z] = slink[x] = y; while(u != -1 && trans[u][c] == x) { trans[u][c] = y; u = slink[u]; } minlen[y] = maxlen[slink[y]] + 1; return z; } int n; char s[N]; char f[200]; // 存映射 vector<int> ve; void deal(int &st){ rep(k,3) f[ve[k] + 'a'] = k + 'a'; rep(i,n){ st = add_char(f[s[i]], st); } } ll dp[N]; ll dfs(int st){ ll& ret = dp[st]; if(ret!=-1) return ret; ret = st!=0; // 不算状态0 rep(i,26){ if(trans[st][i]!=-1){ ret += dfs(trans[st][i]); } } return ret; } void work(){ idx=0; scanf("%s",s); n = strlen(s); ve.clear(); rep(i,3)ve.PB(i); int sta = new_state(0,0,NULL,-1); do{ deal(sta); sta = add_char(26 + 'a', sta); }while(next_permutation(ve.begin(), ve.end()));//全排列 rep(i,idx) dp[i] = -1; int cnt = -1; // 因为状态0是空串不算进去 for(int st = 0;st!=-1;st = trans[st][0], cnt++); // 计算同字符的串的数量(最大长度) printf("%lld\n", (dfs(0)+cnt*3)/6); } int main() { while(scd(n)==1) work(); }
全部评论
(0) 回帖