top of page
Click here to go to the home page of AskTheCode.

Even Odd Partition Codechef July Long Challenge Solution | AskTheCode

Team ATC

Updated: Jul 15, 2021

Codechef July Long Challenge 2021 Solution | Even Odd Partition solution in C++ | AskTheCode

 

Problem:

Let f(n) be the number of ways to partition the array [1,2,3,…,n] into contiguous sub-arrays such that every pair of adjacent sub-arrays in the partition have sums of different parity.


  • What is a contiguous sub-array? A contiguous sub-array of an array A is an array that can be obtained by deleting some (possibly none) elements from the front of A and some (possibly none) elements from the end of A. The sub-arrays of a 1-indexed array A of size n are given by [A_i, A_{i+1}, ..., A_j] for each pair of integers (i, j) such that 1 ≤ i ≤ j ≤ n.

  • What is a partition of an array into contiguous sub-arrays? A partition of the array A into contiguous sub-arrays is a set of sub-arrays of A, {S_1, S_2, S_3, ..., S_k}, such that every element of the array belongs to exactly one of the sub-arrays S_1, S_2, ..., S_k.

  • Which partitions are counted in f(n)? Consider a partition {S_1, S_2, …, S_k} of the array [1, 2, …, n], where the S_i are in sorted order, i.e. every element in S_i is smaller than every element in S_j for all i < j. Then f(n) is the number of such partitions with sum(S_i) ≢ sum(S_{i+1}) (mod 2) for all 1 ≤ i < k, where sum(S_i) is the sum of all elements in S_i.

Let S_0(n) = f(n) and S_{k+1}(n) = S_k(1) + S_k(2) + S_k(3) + ⋯ + S_k(n) for k ≥ 0.


Given n and k, find S_k(n) mod 998244353.

 

Input:

  • The first line contains a single integer T, the number of test cases.

  • The first and only line of each test case contains two integers n, k.


Output:

  • For each test case, print on a separate line the value of S_k(n) mod 998244353.

 

Sample Input:

12
1 0
2 0
3 0
4 0
5 0
2 1
2 2
3 3
4 4
5 5
1000000000000000000 200
1000000000000000000 2773

Sample Output:

1
2
2
3
6
3
4
14
51
191
13413678
697825985
 

EXPLANATION:

We first find f(n) for n=1,2,3,4,5.


When n=1, there is only one partition, [1]. Since it consists of a single sub-array, it has no adjacent sub-arrays, so it is counted in f(1); therefore f(1)=1.


When n=2, there are 2 partitions possible, [1,2] and [1],[2]. The first partition has no adjacent sub arrays and so is counted in f(2).


The second partition has sums 1 and 2, which differ in parity, so it is also counted in f(2); therefore f(2)=2.


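When n=3, there are 4 partitions possible: [1,2,3]; [1],[2,3]; [1,2],[3]; and [1],[2],[3]. Only [1,2,3] (no adjacent sub-arrays) and [1],[2],[3] (sums 1, 2, 3, alternating in parity) satisfy the condition, so f(3)=2.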

 

Code:

#include <bits/stdc++.h>
using namespace std;

typedef long long ll;
#define Mx 8200

// Modular exponentiation by repeated squaring.
//
// Returns x raised to the power y, reduced modulo `mod`. Every product is
// formed in 64-bit arithmetic before reduction. The result starts at 1 and is
// only reduced when a multiplication happens, so y == 0 yields 1 regardless
// of `mod`.
int powmod(int x, int y, int mod) {
    int result = 1;
    for (; y != 0; y >>= 1) {
        // Fold the current square into the result when this exponent bit is set.
        if (y & 1) {
            result = static_cast<long long>(result) * x % mod;
        }
        x = static_cast<long long>(x) * x % mod;
    }
    return result;
}

// Modulus used for all arithmetic in this file; the problem asks for answers
// modulo 998244353.
const int mod = 998244353;

// Number-theoretic transform over integers modulo `mod`, plus a truncated
// polynomial product built on top of it. init() must be called once before
// dft()/multiply() are used, because they read the rev[] and nw[] tables.
namespace ntt {
    // NOTE(review): presumably a primitive FN-th root of unity modulo `mod`;
    // init() relies on root^FN == 1 (it hard-codes nw[FN] = 1) — confirm.
    const int root = 679720472;
    // FN = 2^L is the largest transform length supported by the tables below.
    const int L = 14, FN = (1 << L);
    // rev: L-bit bit-reversal of each index; nw: nw[i] = root^i mod `mod`;
    // g, h: scratch buffers used by multiply().
    int rev[FN], nw[FN + 1], g[FN], h[FN];
    // Fills the bit-reversal table rev[] and the power table nw[].
    void init() {
        rev[0] = 0;
        // rev[i] is rev[i / 2] shifted down one bit, with i's lowest bit moved to the top.
        for (int i = 1; i < FN; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
        nw[0] = nw[FN] = 1;
        for (int i = 1; i < FN; i ++) nw[i] = 1ll * nw[i - 1] * root % mod;
    }
    // In-place transform of a[0..n).
    //   a  — data to transform; overwritten with the result.
    //   n  — transform length; the code computes FN / n and FN / size, so n is
    //        expected to be a power of two no larger than FN.
    //   fg — zero: forward transform. Non-zero: inverse transform (the power
    //        table is walked backwards from nw[FN], and the output is scaled
    //        by n^(mod-2), i.e. the inverse of n assuming `mod` is prime).
    void dft(int *a, int n, int fg) {
        // rev[] holds L-bit reversals; shifting right by d gives log2(n)-bit reversals.
        int d = __builtin_ctz(FN / n);
        for (int i = 0; i < n; i ++) if (i < (rev[i] >> d)) swap(a[i], a[rev[i] >> d]);
        for (int size = 2; size <= n; size <<= 1) {
            // Stride through nw[] so consecutive butterflies use consecutive powers of root^(FN/size).
            int step = FN / size;
            if (fg) step = -step;
            for (int i = 0; i < n; i += size) {
                int *u = a + i, *v = a + (i + (size >> 1)), *w = fg ? nw + FN : nw;
                for (int k = (size >> 1); k --;) {
                    int tmp = 1ll * (*v) * (*w) % mod;
                    // Adding mod before subtracting keeps the difference non-negative.
                    *v = (*u + mod - tmp) % mod, *u = (*u + tmp) % mod;
                    u ++, v ++, w += step;
                }
            }
        }
        if (fg) {
            // rev_n is long long so the product below is formed in 64 bits.
            long long rev_n = powmod(n, mod - 2, mod);
            for (int i = 0; i < n; i ++) a[i] = (a[i] * rev_n) % mod;
        }
    }
    // Replaces a[0..nn) with the first nn coefficients of the product of the
    // polynomials a[0..nn) and b[0..nn); b is not modified.
    // nn must be positive (__builtin_clz(0) is undefined), and the padded
    // length n computed below must not exceed FN, the size of g and h.
    void multiply(int *a, int *b, int nn) {
        // n = 2^(bit width of nn + 1), which is strictly greater than 2 * nn.
        int n = 1 << (33 - __builtin_clz(nn));
        for (int i = 0; i < nn; i++) g[i] = a[i];
        for (int i = nn; i < n; i++) g[i] = 0;
        dft(g, n, 0);
        for (int i = 0; i < nn; i++) h[i] = b[i];
        for (int i = nn; i < n; i++) h[i] = 0;
        dft(h, n, 0);
        // Pointwise product of the two transforms, then inverse transform.
        for (int i = 0; i < n; i ++) g[i] = 1ll * g[i] * h[i] % mod;
        dft(g, n, 1);
        // Only the low nn coefficients are kept; higher ones are discarded.
        for (int i = 0; i < nn; i++) a[i] = g[i];
    }
}
using ntt:: multiply;
using ntt:: init ;

// K:  the k of the current test case, read by solve(); multi() and modexp()
//     also use it as the number of coefficients they process.
// up: assigned in solve(); not read anywhere in the visible code.
int K, up;

// Operand type for multi() and modexp().
struct Data{
    // a: coefficient array; the visible code touches a[0..K).
    // b: four arrays; the visible code touches b[0..4)[0..K).
    // c: small matrix; only the 4x4 corner c[0..4)[0..4) is used by the visible code.
    int a[3050], b[4][3050], c[5][5];
};

// Overwrites A with the combination of A and B defined below; all values are
// reduced modulo `mod`:
//   a     <- first K coefficients of A.a * B.a
//   b[k]  <- first K coefficients of B.b[k] * A.a, plus, per position,
//            the sum over j of A.b[j][pos] * B.c[j][k]
//   c     <- 4x4 matrix product A.c * B.c
// B is taken by value, so calling multi(x, x) is safe. The result is built in
// a temporary and copied into A only at the end.
void multi(Data &A, Data B) {
    Data prod;

    // a-part: truncated polynomial product.
    for (int t = 0; t < K; ++t) prod.a[t] = A.a[t];
    multiply(prod.a, B.a, K);

    // b-part, first term: each row of B.b multiplied by A.a.
    for (int row = 0; row < 4; ++row) {
        for (int t = 0; t < K; ++t) prod.b[row][t] = B.b[row][t];
        multiply(prod.b[row], A.a, K);
    }

    // b-part, second term: per position, add (row vector of A.b) * B.c.
    for (int pos = K - 1; pos >= 0; --pos) {
        for (int src = 0; src < 4; ++src) {
            const long long lhs = A.b[src][pos];
            for (int dst = 0; dst < 4; ++dst) {
                prod.b[dst][pos] = (lhs * B.c[src][dst] + prod.b[dst][pos]) % mod;
            }
        }
    }

    // c-part: ordinary 4x4 matrix product.
    for (int r = 0; r < 4; ++r) {
        for (int col = 0; col < 4; ++col) {
            int cell = 0;
            for (int mid = 0; mid < 4; ++mid) {
                cell = (1ll * A.c[r][mid] * B.c[mid][col] + cell) % mod;
            }
            prod.c[r][col] = cell;
        }
    }

    // Publish the result into A.
    for (int t = 0; t < K; ++t) {
        A.a[t] = prod.a[t];
        for (int row = 0; row < 4; ++row) A.b[row][t] = prod.b[row][t];
    }
    for (int r = 0; r < 4; ++r) {
        for (int col = 0; col < 4; ++col) A.c[r][col] = prod.c[r][col];
    }
}

// rlt: result of modexp(); x: the base that modexp() repeatedly squares;
// y: not referenced by the visible code.
Data rlt, x, y;

// Sets rlt to the n-th power (under multi()) of a fixed base element.
//
// rlt starts as the identity for multi(): a = [1, 0, 0, ...], b = 0,
// c = 4x4 identity. The base x is a[j] = j + 1, b[0][j] = b[1][j] = 1,
// b[2][j] = b[3][j] = j + 1, and a fixed 4x4 matrix c. Only the first K
// entries of each array are initialised, so K must be set beforehand.
void modexp(ll n) {
    for (int j = 0; j < K; ++j) {
        rlt.a[j] = 0;
        x.a[j] = j + 1;
        for (int r = 0; r < 4; ++r) rlt.b[r][j] = 0;
        x.b[0][j] = 1;
        x.b[1][j] = 1;
        x.b[2][j] = j + 1;
        x.b[3][j] = j + 1;
    }
    for (int r = 0; r < 4; ++r) {
        for (int col = 0; col < 4; ++col) {
            rlt.c[r][col] = (r == col) ? 1 : 0;
            x.c[r][col] = 0;
        }
    }
    rlt.a[0] = 1;

    // Non-zero entries of the base 4x4 matrix.
    x.c[0][2] = 1;
    x.c[0][3] = 1;
    x.c[1][0] = 2;
    x.c[1][1] = 1;
    x.c[1][3] = 1;
    x.c[2][0] = 1;
    x.c[2][1] = 1;
    x.c[3][1] = 1;

    // Binary exponentiation: multiply in x for each set bit, squaring x each round.
    for (; n != 0; n >>= 1) {
        if (n & 1) multi(rlt, x);
        multi(x, x);
    }
}

// Work arrays for solve():
//   A — square matrix assembled from rlt; solve() fills its top-left (K+4)x(K+4) part.
//   B — not referenced by the visible code.
//   C — C[i] is the sum of row i of A, modulo `mod`.
//   S — table for the n < 5 case: S[i][j] holds S_j(i). It has exactly 5 rows,
//       so valid first indices are 0..4.
int A[3050][3050], B[3050], C[3050], S[5][3050];

// Returns a value modulo `mod` derived from the ((n - 3) / 2)-th power of a
// fixed 4x4 matrix: with r0 and r1 the sums of rows 0 and 1 of that power,
// the result is r0 + r1 for odd n and r0 + 2 * r1 for even n.
// solve() prints this value directly when k == 0 (i.e. as f(n)).
// Intended for n >= 5, which is how solve() calls it: a negative exponent
// would never reach zero under the right shift below.
int presolve(ll n) {
    int power[5][5], base[5][5], scratch[5][5];
    ll exponent = (n - 3) / 2;

    // power = identity, base = the fixed transition matrix.
    for (int r = 0; r < 4; ++r) {
        for (int c = 0; c < 4; ++c) {
            power[r][c] = (r == c) ? 1 : 0;
            base[r][c] = 0;
        }
    }
    base[0][2] = 1;
    base[0][3] = 1;
    base[1][0] = 2;
    base[1][1] = 1;
    base[1][3] = 1;
    base[2][0] = 1;
    base[2][1] = 1;
    base[3][1] = 1;

    // dst = lhs * rhs (4x4, modulo `mod`). The product is built in `scratch`
    // first, so dst may alias either operand.
    auto mul_into = [&scratch](int (&lhs)[5][5], int (&rhs)[5][5], int (&dst)[5][5]) {
        for (int r = 0; r < 4; ++r) {
            for (int c = 0; c < 4; ++c) scratch[r][c] = 0;
        }
        for (int r = 0; r < 4; ++r) {
            for (int m = 0; m < 4; ++m) {
                for (int c = 0; c < 4; ++c) {
                    scratch[r][c] += 1ll * lhs[r][m] * rhs[m][c] % mod;
                    if (scratch[r][c] >= mod) scratch[r][c] -= mod;
                }
            }
        }
        for (int r = 0; r < 4; ++r) {
            for (int c = 0; c < 4; ++c) dst[r][c] = scratch[r][c];
        }
    };

    // Binary exponentiation.
    for (; exponent != 0; exponent >>= 1) {
        if (exponent & 1) mul_into(power, base, power);
        mul_into(base, base, base);
    }

    // Collapse each row into column 0 (row sums).
    for (int r = 0; r < 4; ++r) {
        for (int c = 1; c < 4; ++c) {
            power[r][0] = (power[r][0] + power[r][c]) % mod;
        }
    }

    const int weight = (n % 2 == 0) ? 2 : 1;
    return static_cast<int>((power[0][0] + 1ll * weight * power[1][0]) % mod);
}

// Handles one test case: reads "n k" from stdin (k goes into the global K)
// and prints S_k(n) modulo `mod` on its own line.
//
// Three paths:
//   * n < 5: fill the small table S directly from the prefix-sum definition.
//   * k == 0: print presolve(n).
//   * otherwise: raise the base element to the power (n - 1) / 2 with
//     modexp(), lay the result out as a (K+4)x(K+4) matrix A, take its row
//     sums C, and combine them depending on the parity of n.
// The fixed array sizes (3050) bound the k this code can handle.
void solve() {
    ll n;
    // Stop on malformed or truncated input rather than using an indeterminate n.
    if (scanf("%lld%d", &n, &K) != 2) return;
    if (n < 5) {
        // Base values f(1..4). S has rows 0..4 only, so row 5 must not be
        // written; this branch never reads it because n <= 4 here.
        S[1][0] = 1, S[2][0] = 2, S[3][0] = 2, S[4][0] = 3;
        // S[i][j] = S[1][j-1] + ... + S[i][j-1], i.e. S_j(i) from S_{j-1}.
        for (int i = 1; i <= n; i++) for (int j = 1; j <= K; j++) {
            S[i][j] = 0;
            for (int k = 1; k <= i; k++) {
                S[i][j] += S[k][j - 1];
                if (S[i][j] >= mod) S[i][j] -= mod;
            }
        }
        printf("%d\n", S[n][K]);
        return;
    }
    int ans = presolve(n);
    if (K == 0) {
        printf("%d\n", ans);
        return;
    }

    // `up` is not read anywhere in the visible code; the assignment is kept
    // so the global's value is unchanged for any other user.
    up = 1 << (33 - __builtin_clz(K));
    modexp((n - 1) / 2);
    // Top-left K x K block: upper-triangular, constant along diagonals (from rlt.a).
    for (int i = 0; i < K; i++) for (int j = 0; j < K; j++) {
        if (i > j) A[i][j] = 0;
        else A[i][j] = rlt.a[j - i];
    }
    // Bottom-left 4 x K block: zero.
    for (int i = K; i < K + 4; i++) for (int j = 0; j < K; j++) A[i][j] = 0;
    // Top-right K x 4 block: rlt.b with the position index reversed.
    for (int i = 0; i < K; i++) for (int j = K; j < K + 4; j++) A[i][j] = rlt.b[j - K][K - i - 1];
    // Bottom-right 4 x 4 block: rlt.c.
    for (int i = K; i < K + 4; i++) for (int j = K; j < K + 4; j++) A[i][j] = rlt.c[i - K][j - K];
    // C[i] = sum of row i of A, modulo `mod`.
    for (int i = 0; i < K + 4; i++) {
        C[i] = 0;
        for (int j = 0; j < K + 4; j++) {
            C[i] += A[i][j];
            if (C[i] >= mod) C[i] -= mod;
        }
    }

    // Odd n: the answer is the first row sum. Even n: presolve(n) plus the
    // first K row sums.
    if (n & 1) printf("%d\n", C[0]);
    else {
        for (int i = 0; i < K; i++) {
            ans += C[i];
            if (ans >= mod) ans -= mod;
        }
        printf("%d\n", ans);
    }
}

// Entry point: builds the NTT tables once, reads the number of test cases,
// and runs solve() for each.
int main() {
    init();
    int tc = 0;
    // If the count cannot be read there is nothing to process; without this
    // check tc would be used uninitialised.
    if (scanf("%d", &tc) != 1) return 0;
    // `> 0` makes a negative count run zero iterations instead of counting down.
    while (tc-- > 0) solve();
    return 0;
}

Recent Posts

See All

Comments


bottom of page