poj 3376 Finding Palindromes(扩展kmp+trie)
poj 3376 Finding Palindromes(扩展kmp+trie)
题意:给出n个字符串,问这n个字符串两两链接(一共有n^2中连接方法),组成的所有的字符串中,有多少个回文串。
好题!!!
解题思路:对于两个串连接是否是回文串,我们应该怎样去判断了?假如我们把其中一个翻转,若此时,短的那个串是长的那个的前缀,而长的那个串后面剩余的后缀恰好是个回文串,那这两个串连起来就是个回文串了。比如abc和abacba连接,我们把后一个串翻转,得到abcaba,abc为其前缀,而aba是个回文串,那么连起来就是个回文串了(这个规律找到就好办了)。那我们就把所有的串插入到trie中,然后再用所有的反串去匹配就行了。匹配的过程中,走到任意一个节点,而这个节点有可能是若干个串的结尾,那么此时我们就要判反串匹配位置下面剩余的部分是否回文。如果是的,ans就加上以这个节点为结尾的原串的个数(这个插入的时候就可以统计进去了)。如果走完了,还没走到叶子节点,那么就要看走到的节点下的子树(其实是以前面走过的路径为前缀的字符串剩下的一些后缀)有多少是回文的了(这个先预处理所有的串的后缀有哪些是回文的,然后在插入的时候统计到节点上)。剩下来一个问题就是如何在线性的时间内(或许o(nlogn)也可以吧,但我们有线性的算法,岂不更好?),这里我只是说,用扩展kmp能很合适。具体如何实现,留个小思考给大家(很简单的啦)。。
#include<stdio.h>#include<algorithm>#include<string.h>#define ll __int64#include<vector>using namespace std ;const int maxn = 2222222 ;char vec[maxn] ;int g[maxn] , nxt[maxn] ;bool li[maxn] ;int ok[maxn] , p[maxn] ;void get_p (const char *T){ int len=strlen(T),a=0; int i , k ; p[0]=len; while(a<len-1 && T[a]==T[a+1]) a++; p[1]=a; a=1; for( k=2;k<len;k++){ int fuck=a+p[a]-1,L=p[k-a]; if( (k-1)+L >= fuck){ int j = (fuck-k+1)>0 ? (fuck-k+1) : 0; while(k+j<len && T[k+j]==T[j]) j++; p[k]=j; a=k; } else p[k]=L; }}void match ( char *s , char *s1 ) { int len = strlen ( s ) , len1 = strlen ( s1 ) ; int i = 0 , k , j = 0 , a ; while ( i < len && j < len1 && s[i] == s1[j] ) i ++ , j ++ ; ok[0] = j ; a = 0 ; for ( k = 1 ; k < len ; k ++ ) { int fuck = a + ok[a] - 1 , l = p[k-a] ; if ( k + l - 1 >= fuck ) { int j = ( fuck - k + 1 ) > 0 ? ( fuck - k + 1 ) : 0 ; while ( k + j < len && j < len1 && s[k+j] == s1[j] ) j ++ ; ok[k] = j ; a = k ; } else ok[k] = l ; }}int tot = 0 , c[26][maxn] , cnt[maxn] , val[maxn] ;int new_node () { int i ; for ( i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ; cnt[tot] = val[tot] = 0 ; return tot ++ ;}void insert ( char *s ) { int len = strlen ( s ) , i , now = 0 ; for ( i = 0 ; i < len ; i ++ ) { int k = s[i] - 'a' ; if ( !c[k][now] ) c[k][now] = new_node () ; now = c[k][now] ; if ( i + 1 < len && ok[i+1] == len - i - 1 ) { cnt[now] ++ ; } }cnt[now] ++ ; val[now] ++ ;}ll ans = 0 ;void cal ( int len ) { int j , i , now = 0 ;li[len+1] = 1 ;//printf ( "len = %d\n" , len ) ;//for ( i = 1 ; i <= len ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ; for ( j = 1 ; j <= len ; j ++ ) {//printf ( "j = %d , ans = %I64d\n" , j , ans ) ;if ( li[j] ) now = 0 ; int k = vec[j] - 'a' ; if ( !c[k][now] ) { now = 0 ;//printf ( "nxt[%d] = %d\n" , j , nxt[j] ) ;j = nxt[j] - 1 ;continue ; } now = c[k][now] ; if ( !li[j+1] && g[j+1] ) ans += (ll) val[now] ;//printf ( "j = %d , now = %d\n" , j , now ) ;if ( li[j+1] ) {//if ( j == 10 ) printf ( "cnt[%d] = %d\n" , now , cnt[now] ) ;ans += (ll) cnt[now] ;now = 0 ;} }}char s1[maxn] , s[maxn] ;int main () { int n , i , j , k ; while ( scanf ( "%d" , &n ) != EOF ) { tot = 0 ; new_node () ;int t = 0 ; for ( i = 1 ; i <= n ; i ++ ) { scanf ( "%d%s" , &j , s ) ; strcpy ( s1 , s ) ; int len = strlen ( s ) ; reverse ( s1 , s1 + len ) ; get_p ( s1 ) ; match ( s , s1 ) ; insert ( s ) ; get_p ( s ) ; match ( s1 , s ) ;li[t+1] = 1 ; for ( j = 0 ; j < len ; j ++ ) { if ( ok[j] == len - j ) g[++t] = 1 ; else g[++t] = 0 ; vec[t] = s1[j] ;if ( j ) li[t] = 0 ; } } // for ( i = 1 ; i < t ; i ++ ) printf ( "%d " , cnt[i] ) ; puts ( "" ) ;int last = t + 1 ;//for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , li[i] ) ; puts ( "" ) ;for ( i = t ; i >= 1 ; i -- ) {nxt[i] = last ;if ( li[i] ) last = i ;}//for ( i = 1 ; i <= t ; i ++ ) printf ( "%d " , nxt[i] ) ; puts ( "" ) ; ans = 0 ; cal ( t ) ; printf ( "%I64d\n" , ans ) ; }}