Hướng dẫn giải của duong3982oj Contest 04 - Miku đắt tiền


Chỉ dùng lời giải này khi không có ý tưởng, và đừng copy-paste code từ lời giải này. Hãy tôn trọng người ra đề và người viết lời giải.
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.

Tác giả: duong3982

Subtask 1: ~n \le 6~.

Với subtask này, vì ~n \le 6~, nên số lượng số cần xét sẽ ~\le 10^6~. Ta chỉ cần duyệt qua mọi số và kiểm tra.

Độ phức tạp: ~O~ (~10^n \times n~).

Subtask 2: ~n \le 1000~.

Quy hoạch động: Gọi ~f[i][j][k]~ là có bao nhiêu số có ~i~ chữ số, trong đó có đúng ~j~ chữ số ~x~, và số dư khi chia cho ~3~ là ~k~.

Ta có công thức truy hồi: ~f[i][j][k]~ ~=~ ~f[i - 1][j][k - z] + 1~ với ~0 \le z \le 9~ và ~z \neq x~, và ~f[i][j][k]~ ~=~ ~f[i - 1][j - 1][k - z] + 1~ với ~z = x~.

Kết quả bài toán sẽ là ~\sum f[n][j][k]~, với ~l \le j \le r~.

Độ phức tạp: ~O~ (~n^2 \times 10 \times 3~).

Subtask 5: ~n \le 10^5~.

Vì ~r - l \le 10^4~, ta có thể thử duyệt ~p~ từ ~l~ tới ~r~, với ý nghĩa số ta xây dựng sẽ có đúng ~p~ chữ số ~x~.

Đến đây, ta cần làm hai bước.

  • Bước ~1~: Ta sẽ chọn ra ~p~ vị trí trong số ~n~ vị trí để đặt chữ số ~x~ vào. Dễ thấy, số cách chọn là ~\binom {n}{p}~. Cách tính tổ hợp các bạn có thể tham khảo tại VNOI Wiki.
  • Bước ~2~: Dễ thấy, ta sẽ cần phải chọn ra ~n - p~ số, sao cho số dư khi chia cho ~3~ là ~k - p \times x~. Điều này có thể thực hiện bằng quy hoạch động.

Độ phức tạp: ~O~ (~(r - l) \times log (n)~).

Subtask 6: ~n \le 10^9~.

Ta sẽ tối ưu hai bước trên.

  • Với bước ~1~: ta cần có cách để tính giai thừa lên tới ~10^9~. Điều này có thể thực hiện bằng trick baby step giant step.
  • Với bước ~2~: ta có thể sử dụng nhân ma trận để tăng tốc quy hoạch động.

Độ phức tạp: ~O~ (~(r - l) \times log (n) \times 3^3 \times log (n)~).

Code mẫu:

#include <bits/stdc++.h>
using namespace std;

#define int long long
#define pp pair <int, int>
#define fi first
#define se second
#define yes cout << "YES\n"
#define no cout << "NO\n"

const int N = 1e6 + 9;
const int mod = 1e9 + 7;

int n, k, x;

int cur[] = {4, 3, 3};

void add (int &u, int v){
    u += v;
    if (u >= mod) u -= mod;
}

vector <vector <int>> matrix_mult (const vector <vector <int>> &A, const vector <vector <int>> &B) {
    vector <vector <int>> result (3, vector <int> (3, 0));
    for (int i = 0; i < 3; ++i) for (int j = 0; j < 3; ++j) for (int k = 0; k < 3; ++k) add (result[i][j], A[i][k] * B[k][j] % mod);
    return result;
}

vector <vector <int>> matrix_exponentiation (vector <vector <int>> base, int power){
    vector <vector <int>> result = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; 
    while (power){
        if (power % 2 == 1) result = matrix_mult (result, base);
        base = matrix_mult (base, base);
        power >>= 1;
    }
    return result;
}

int solve (int i, int j){
    if (i < 0) return 0;
    if (i == 0){
        if (j % 3 == 0 && x != 0) return 1;
        return 0;
    }
    if (i == 1){
        if (j == 0) return cur[0];
        if (j == 1) return cur[1];
        if (j == 2) return cur[2];
    }
    vector <vector <int>> M = {
        {cur[0], cur[2], cur[1]},
        {cur[1], cur[0], cur[2]},
        {cur[2], cur[1], cur[0]}
    };
    vector <vector <int>> M_exp = matrix_exponentiation (M, i - 1);
    vector <int> F1 = {cur[0], cur[1], cur[2]};
    vector <int> Fi (3, 0);
    for (int row = 0; row < 3; ++row) for (int col = 0; col < 3; ++col) add (Fi[row], M_exp[row][col] * F1[col] % mod);
    return Fi[j];
}

int POW (int a, int b, int p = mod){
    a %= p;
    if (a == 0) return 0;
    int ret = 1;
    while (b > 0){
        if (b & 1){
            ret = (ret * a) % p; b--;
        }
        a *= a; a %= p; b >>= 1;
    }
    return ret;
}

int inv (int val){
    return POW (val, mod - 2, mod);
}

int b[] = {}; // b[i] = (i * 1e6)!

int fact (int n){
    if (n > mod) return 0LL;
    int res = b[n / (1000000)];
    for (int i = n / 1000000 * 1000000 + 1; i <= n; i++) res = res * i % mod;
    return res;
}

int comb (int n, int k){
    if (k > n) return 0;
    if (k < 0) return 0;
    int r = n - k, res = 1;
    while (n > 0){
        if (n % mod < k % mod) return 0;
        res *= fact (n % mod); res %= mod;
        res *= inv (fact (k % mod)); res %= mod;
        res *= inv (fact (r % mod)); res %= mod;
        n /= mod;
        k /= mod;
        r /= mod;
    }    
    return res;
}

int diff (int u, int v){
    u -= v;
    if (u < 0) u += mod;
    return u;
}

int l, r;

int ok1[N], ok2[N];

signed main (){
    ios_base::sync_with_stdio (false);
    cin.tie (NULL);
    cout.tie (NULL);
    if (fopen ("input.txt", "r")){
        freopen ("input.txt", "r", stdin);
        freopen ("output.txt", "w", stdout);
    }
    cin >> n >> l >> r >> k >> x;
    cur[x % 3]--;
    int res = 0;
    ok1[0] = comb (n - 1, l);
    ok2[0] = comb (n, l);
    for (int i = l + 1; i <= r; i++){
        if (i > n - 1) ok1[i - l] = 0;
        else ok1[i - l] = ok1[i - l - 1] * ((n - 1 - (i - 1)) % mod) % mod * inv (i) % mod;
        if (i > n) ok2[i - l] = 0;
        else ok2[i - l] = ok2[i - l - 1] * ((n - (i - 1)) % mod) % mod * inv (i) % mod;
    }
    for (int p = l; p <= r; p++){
        int req = x * p; req %= 3;
        req = k - req;
        req %= 3; req += 3; req %= 3;
        int ways_non_first = ok1[p - l];
        int ways_first = diff (ok2[p - l], ways_non_first);
        for (int dig = 1; dig <= 9; dig++){
            if (dig == x) continue;
            int req_new = x * p + dig;
            req_new %= 3;
            req_new = k - req_new;
            req_new %= 3; req_new += 3; req_new %= 3;
            add (res, ways_non_first * solve (n - p - 1, req_new) % mod);
        }
        if (x) add (res, ways_first * solve (n - p, req) % mod);
    }
    cout << res;
}

Bình luận

Hãy đọc nội quy trước khi bình luận.


Không có bình luận tại thời điểm này.