Hướng dẫn giải của duong3982oj Contest 04 - Miku đắt tiền
Nộp một lời giải chính thức trước khi tự giải là một hành động có thể bị ban.
Tác giả:
Subtask 1: ~n \le 6~.
Với subtask này, vì ~n \le 6~, nên số lượng số cần xét sẽ ~\le 10^6~. Ta chỉ cần duyệt qua mọi số và kiểm tra.
Độ phức tạp: ~O~ (~10^n \times n~).
Subtask 2: ~n \le 1000~.
Quy hoạch động: Gọi ~f[i][j][k]~ là có bao nhiêu số có ~i~ chữ số, trong đó có đúng ~j~ chữ số ~x~, và số dư khi chia cho ~3~ là ~k~.
Ta có công thức truy hồi: ~f[i][j][k]~ ~=~ ~f[i - 1][j][k - z] + 1~ với ~0 \le z \le 9~ và ~z \neq x~, và ~f[i][j][k]~ ~=~ ~f[i - 1][j - 1][k - z] + 1~ với ~z = x~.
Kết quả bài toán sẽ là ~\sum f[n][j][k]~, với ~l \le j \le r~.
Độ phức tạp: ~O~ (~n^2 \times 10 \times 3~).
Subtask 5: ~n \le 10^5~.
Vì ~r - l \le 10^4~, ta có thể thử duyệt ~p~ từ ~l~ tới ~r~, với ý nghĩa số ta xây dựng sẽ có đúng ~p~ chữ số ~x~.
Đến đây, ta cần làm hai bước.
- Bước ~1~: Ta sẽ chọn ra ~p~ vị trí trong số ~n~ vị trí để đặt chữ số ~x~ vào. Dễ thấy, số cách chọn là ~\binom {n}{p}~. Cách tính tổ hợp các bạn có thể tham khảo tại VNOI Wiki.
- Bước ~2~: Dễ thấy, ta sẽ cần phải chọn ra ~n - p~ số, sao cho số dư khi chia cho ~3~ là ~k - p \times x~. Điều này có thể thực hiện bằng quy hoạch động.
Độ phức tạp: ~O~ (~(r - l) \times log (n)~).
Subtask 6: ~n \le 10^9~.
Ta sẽ tối ưu hai bước trên.
- Với bước ~1~: ta cần có cách để tính giai thừa lên tới ~10^9~. Điều này có thể thực hiện bằng trick baby step giant step.
- Với bước ~2~: ta có thể sử dụng nhân ma trận để tăng tốc quy hoạch động.
Độ phức tạp: ~O~ (~(r - l) \times log (n) \times 3^3 \times log (n)~).
Code mẫu:
#include <bits/stdc++.h> using namespace std; #define int long long #define pp pair <int, int> #define fi first #define se second #define yes cout << "YES\n" #define no cout << "NO\n" const int N = 1e6 + 9; const int mod = 1e9 + 7; int n, k, x; int cur[] = {4, 3, 3}; void add (int &u, int v){ u += v; if (u >= mod) u -= mod; } vector <vector <int>> matrix_mult (const vector <vector <int>> &A, const vector <vector <int>> &B) { vector <vector <int>> result (3, vector <int> (3, 0)); for (int i = 0; i < 3; ++i) for (int j = 0; j < 3; ++j) for (int k = 0; k < 3; ++k) add (result[i][j], A[i][k] * B[k][j] % mod); return result; } vector <vector <int>> matrix_exponentiation (vector <vector <int>> base, int power){ vector <vector <int>> result = {{1, 0, 0}, {0, 1, 0}, {0, 0, 1}}; while (power){ if (power % 2 == 1) result = matrix_mult (result, base); base = matrix_mult (base, base); power >>= 1; } return result; } int solve (int i, int j){ if (i < 0) return 0; if (i == 0){ if (j % 3 == 0 && x != 0) return 1; return 0; } if (i == 1){ if (j == 0) return cur[0]; if (j == 1) return cur[1]; if (j == 2) return cur[2]; } vector <vector <int>> M = { {cur[0], cur[2], cur[1]}, {cur[1], cur[0], cur[2]}, {cur[2], cur[1], cur[0]} }; vector <vector <int>> M_exp = matrix_exponentiation (M, i - 1); vector <int> F1 = {cur[0], cur[1], cur[2]}; vector <int> Fi (3, 0); for (int row = 0; row < 3; ++row) for (int col = 0; col < 3; ++col) add (Fi[row], M_exp[row][col] * F1[col] % mod); return Fi[j]; } int POW (int a, int b, int p = mod){ a %= p; if (a == 0) return 0; int ret = 1; while (b > 0){ if (b & 1){ ret = (ret * a) % p; b--; } a *= a; a %= p; b >>= 1; } return ret; } int inv (int val){ return POW (val, mod - 2, mod); } int b[] = {}; // b[i] = (i * 1e6)! int fact (int n){ if (n > mod) return 0LL; int res = b[n / (1000000)]; for (int i = n / 1000000 * 1000000 + 1; i <= n; i++) res = res * i % mod; return res; } int comb (int n, int k){ if (k > n) return 0; if (k < 0) return 0; int r = n - k, res = 1; while (n > 0){ if (n % mod < k % mod) return 0; res *= fact (n % mod); res %= mod; res *= inv (fact (k % mod)); res %= mod; res *= inv (fact (r % mod)); res %= mod; n /= mod; k /= mod; r /= mod; } return res; } int diff (int u, int v){ u -= v; if (u < 0) u += mod; return u; } int l, r; int ok1[N], ok2[N]; signed main (){ ios_base::sync_with_stdio (false); cin.tie (NULL); cout.tie (NULL); if (fopen ("input.txt", "r")){ freopen ("input.txt", "r", stdin); freopen ("output.txt", "w", stdout); } cin >> n >> l >> r >> k >> x; cur[x % 3]--; int res = 0; ok1[0] = comb (n - 1, l); ok2[0] = comb (n, l); for (int i = l + 1; i <= r; i++){ if (i > n - 1) ok1[i - l] = 0; else ok1[i - l] = ok1[i - l - 1] * ((n - 1 - (i - 1)) % mod) % mod * inv (i) % mod; if (i > n) ok2[i - l] = 0; else ok2[i - l] = ok2[i - l - 1] * ((n - (i - 1)) % mod) % mod * inv (i) % mod; } for (int p = l; p <= r; p++){ int req = x * p; req %= 3; req = k - req; req %= 3; req += 3; req %= 3; int ways_non_first = ok1[p - l]; int ways_first = diff (ok2[p - l], ways_non_first); for (int dig = 1; dig <= 9; dig++){ if (dig == x) continue; int req_new = x * p + dig; req_new %= 3; req_new = k - req_new; req_new %= 3; req_new += 3; req_new %= 3; add (res, ways_non_first * solve (n - p - 1, req_new) % mod); } if (x) add (res, ways_first * solve (n - p, req) % mod); } cout << res; }
Bình luận