竞赛讨论区 > D：快速数论变换（NTT）+ 倍增，时间复杂度 $O(P \log^2 P)$
头像
thisislike_fan
发布于 04-08 12:36
+ 关注

D：快速数论变换（NTT）+ 倍增，时间复杂度 $O(P \log^2 P)$

#pragma GCC optimize("O2")
#pragma GCC optimize("O3")
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("avx,avx2,fma")
#pragma GCC target("sse4,popcnt,abm,mmx")

#include <bits/stdc++.h>

// Debug helper: prints `x=<value of x>` on its own line (endl also flushes).
#define out(x) cout << #x << '=' << (x) << endl
// Debug helper: prints `x=<value>,y=<value>` on its own line.
#define out2(x, y) cout << #x << '=' << (x) << ',' << #y << '=' << (y) << endl 
// Prints "No" and returns from the enclosing function (usable only in a void function).
#define no do { cout << "No" << endl; return; } while(0)
// Prints "Yes" and returns from the enclosing function (usable only in a void function).
#define yes do { cout << "Yes" << endl; return; } while (0)
// Lowest set bit of x, computed as x & -x (e.g. lowbit(12) == 4).
#define lowbit(x) ((x) & -(x))

using namespace std;

// 64-bit signed integer shorthand.
using ll = long long;

// "Infinity" sentinels: every byte is 0x3f, so 2 * inf (and 2 * infi) still fits
// without overflow, and the pattern can be filled byte-wise.
constexpr ll inf = 0x3f3f3f3f3f3f3f3fLL;
constexpr int infi = 0x3f3f3f3f;

// Debug printers for common STL containers.
//
// All four overloads are declared up front on purpose. Inside a template, an
// unqualified `os << elem` only finds operator<< overloads declared before the
// template's definition (or via ADL). ADL cannot find these global overloads,
// because the arguments are std:: types. Without the declarations, nested types
// such as set<pair<int,int>> or map<int, vector<int>> fail to compile.
template <typename T>
std::ostream &operator<<(std::ostream &os, const std::set<T> &obj);
template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::map<T1, T2> &obj);
template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::pair<T1, T2> &obj);
template <typename T>
std::ostream &operator<<(std::ostream &os, const std::vector<T> &obj);

// Prints `set(e1, e2, ...)`.
template <typename T>
std::ostream &operator<<(std::ostream &os, const std::set<T> &obj) {
    os << "set(";
    const char *sep = "";
    for (const auto &elem : obj) {
        os << sep << elem;
        sep = ", ";
    }
    return os << ")";
}

// Prints `map(k1: v1, k2: v2, ...)`.
template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::map<T1, T2> &obj) {
    os << "map(";
    const char *sep = "";
    for (const auto &kv : obj) {
        os << sep << kv.first << ": " << kv.second;
        sep = ", ";
    }
    return os << ")";
}

// Prints `<first, second>`.
template <typename T1, typename T2>
std::ostream &operator<<(std::ostream &os, const std::pair<T1, T2> &obj) {
    return os << "<" << obj.first << ", " << obj.second << ">";
}

// Prints `vector(e1, e2, ...)`.
template <typename T>
std::ostream &operator<<(std::ostream &os, const std::vector<T> &obj) {
    os << "vector(";
    const char *sep = "";
    for (const auto &elem : obj) {
        os << sep << elem;
        sep = ", ";
    }
    return os << ")";
}

// Capacity parameter: every NTT buffer below holds 4 * maxn coefficients.
constexpr int maxn = 2000;

// NTT modulus 998244353 = 119 * 2^23 + 1, its primitive root 3, and 3^{-1} mod P.
constexpr int P = 998244353;
constexpr int G = 3;
constexpr int Gi = 332748118;

// Shared NTT state (zero-initialised at namespace scope): transform length set
// by init(), bit-reversal table, and the two operand buffers used by ntt1/ntt2.
int limit;
int r[maxn * 4];
int a[maxn * 4];
int b[maxn * 4];
 
// Computes a^k mod P by square-and-multiply.
// Intended for k >= 0 (the loop runs until k becomes 0). The result is in
// [0, P) when a is non-negative.
inline int fastpow(int a, int k) {
    long long base = a;
    long long acc = 1;
    for (; k != 0; k >>= 1) {
        if (k & 1) {
            acc = acc * base % P;
        }
        base = base * base % P;
    }
    return static_cast<int>(acc);
}
 
// In-place iterative number-theoretic transform of A[0..limit) modulo P.
// type == 1 uses root G (forward transform); any other value uses Gi (inverse,
// without the final division by limit, which callers apply themselves).
// Requires init() to have filled `limit` and the bit-reversal table `r`.
inline void NTT(int *A, int type) {
    // Reorder into bit-reversed index order; each pair is swapped only once.
    for (int i = 0; i < limit; ++i) {
        const int j = r[i];
        if (i < j) std::swap(A[i], A[j]);
    }
    const int root = (type == 1) ? G : Gi;
    // Butterfly passes: merge blocks of size `half` into blocks of size `span`.
    for (int half = 1; half < limit; half <<= 1) {
        const int span = half << 1;
        const int step = fastpow(root, (P - 1) / span);
        for (int start = 0; start < limit; start += span) {
            int *lo = A + start;
            int *hi = lo + half;
            long long w = 1;  // current twiddle factor step^k
            for (int k = 0; k < half; ++k) {
                const int x = lo[k];
                const int y = static_cast<int>(w * hi[k] % P);
                lo[k] = (x + y) % P;
                int diff = (x - y) % P;
                if (diff < 0) diff += P;
                hi[k] = diff;
                w = w * step % P;
            }
        }
    }
}
 
// Cyclic convolution of the global buffers: a <- a * b (mod P, length limit).
// b is left in transformed form. Requires init() for the current length.
void ntt1() {
    // Move both operands to the evaluation domain.
    NTT(a, 1);
    NTT(b, 1);
    // Point-wise multiplication.
    for (int i = 0; i < limit; ++i) {
        a[i] = static_cast<int>(static_cast<long long>(a[i]) * b[i] % P);
    }
    // Back to coefficients, then divide by the transform length via its inverse.
    NTT(a, -1);
    const long long scale = fastpow(limit, P - 2);
    for (int i = 0; i < limit; ++i) {
        a[i] = static_cast<int>(a[i] * scale % P);
    }
}

// Cyclic self-convolution of the global buffer: a <- a * a (mod P, length limit).
// Requires init() for the current length.
void ntt2() {
    NTT(a, 1);
    // Square each evaluation point.
    for (int i = 0; i < limit; ++i) {
        const long long v = a[i];
        a[i] = static_cast<int>(v * v % P);
    }
    NTT(a, -1);
    // Undo the length factor left by the inverse transform.
    const long long scale = fastpow(limit, P - 2);
    for (int *it = a; it != a + limit; ++it) {
        *it = static_cast<int>(*it * scale % P);
    }
}
 
 
// Prepares global NTT state for multiplying a degree-N by a degree-M polynomial.
// Sets `limit` to the smallest power of two strictly greater than N + M, and
// fills r[0..limit) with the bit-reversal permutation on log2(limit) bits.
//
// When N + M == 0, limit stays 1 and L == 0. The previous version evaluated
// `(i & 1) << (L - 1)`, a shift by -1, which is undefined behaviour. r[0] is
// always 0, so it is set directly and the recurrence starts at i = 1.
//
// NOTE(review): r (and the operand buffers) hold maxn * 4 entries; callers
// must keep N + M small enough that limit fits.
void init(int N, int M) {
    int L = 0;
    limit = 1;
    while (limit <= N + M) {
        limit <<= 1;
        ++L;
    }
    r[0] = 0;
    for (int i = 1; i < limit; ++i) {
        // Reverse i's low L bits using the already-computed reversal of i >> 1.
        r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
    }
}

// poly[i][x] != 0 iff residue x (mod p) is reachable as a sum of at most 2^i
// input values, repetition allowed (poly[0][0] == 1 stands for "pick nothing").
// Filled by solve().
// Namespace scope gives guaranteed zero-initialisation, which solve() relies on.
// The array is also large (presumably too big for the stack).
int poly[20][maxn * 4]{};
// curr[x] != 0 iff residue x is reachable with between 1 and `ans` values in solve().
int curr[maxn * 4]{};

// Reads n, p and n values. Prints the minimum number m of values (the same
// value may be chosen repeatedly) whose sum is divisible by p.
//
// Method (O(p log^2 p)):
//   - poly[i] holds the residues reachable with at most 2^i values. It is built
//     by squaring poly[i-1] with NTT and folding exponents modulo p.
//   - Binary lifting then finds the largest k such that no 1..k values sum to
//     0 (mod p). The answer is k + 1.
// With repetition allowed, p copies of any value always work, so m <= p and the
// bits 0..__lg(p-1) suffice.
//
// Changes vs. the previous version:
//   * residues are normalised into [0, p), so negative inputs no longer index
//     out of bounds;
//   * init() is called after the input loop, so p == 1 (every value divisible)
//     returns before the degenerate init(0, 0);
//   * the input variable no longer shadows the global NTT buffer `a`.
// NOTE(review): assumes n >= 1 and 2p - 2 < 4 * maxn (buffer capacity).
void solve() {
    int n, p;
    cin >> n >> p;

    // poly[0]: residues reachable with at most one value (index 0 = pick nothing).
    // curr:    residues reachable with exactly one value.
    poly[0][0] = 1;
    for (int i = 1; i <= n; i++) {
        int value;
        cin >> value;
        const int res = ((value % p) + p) % p;
        if (res == 0) {
            // A single value is already divisible by p.
            cout << 1 << '\n';
            return;
        }
        poly[0][res] = 1;
        curr[res] = 1;
    }

    // Products of two degree-(p-1) polynomials have degree <= 2p - 2.
    init(p - 1, p - 1);
    const int top = __lg(p - 1);

    // poly[i] = poly[i-1]^2 folded mod x^p - 1. Each product coefficient counts
    // pairs (at most p, far below P), so "> 0" exactly means "reachable".
    for (int i = 1; i <= top; i++) {
        memcpy(a, poly[i - 1], sizeof(int) * limit);
        ntt2();
        for (int j = 0; j <= (p - 1) * 2; j++) {
            poly[i][j % p] |= a[j] > 0;
        }
    }

    // Invariant: curr = residues reachable with 1..ans values, and 0 is not among them.
    int ans = 1;
    for (int i = top; i >= 0; i--) {
        memcpy(a, poly[i], sizeof(int) * limit);
        memcpy(b, curr, sizeof(int) * limit);
        ntt1();
        // Degree <= 2p - 2, so only exponents 0 and p fold onto residue 0.
        if (a[0] || a[p]) continue;
        // poly[i] contains 0, so the product already includes curr; OR-ing it in
        // makes curr = residues reachable with 1..ans + 2^i values.
        for (int j = 0; j <= (p - 1) * 2; j++) {
            curr[j % p] |= a[j] > 0;
        }
        ans += 1 << i;
    }
    cout << ans + 1 << '\n';
}

int main() {
    // Decouple iostreams from C stdio and untie cin from cout so reads do not
    // force a flush.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    // Number of test cases; single test by default.
    int t = 1;
    // cin >> t;  // enable for multi-test input

    for (; t--;) {
        solve();
    }
    return 0;
}

全部评论

(0) 回帖
加载中...
话题 回帖

等你来战

查看全部

热门推荐