From 943535eba59f2109100d4a54b36bf5843af8bbae Mon Sep 17 00:00:00 2001 From: Baoshuo Date: Wed, 30 Nov 2022 22:02:45 +0800 Subject: [PATCH] =?UTF-8?q?#108.=20=E5=A4=9A=E9=A1=B9=E5=BC=8F=E4=B9=98?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://loj.ac/s/1646626 --- LibreOJ/108/108.cpp | 51 +++++++++++++++++++++++++++++---------------- 1 file changed, 33 insertions(+), 18 deletions(-) diff --git a/LibreOJ/108/108.cpp b/LibreOJ/108/108.cpp index cd624911..96950668 100644 --- a/LibreOJ/108/108.cpp +++ b/LibreOJ/108/108.cpp @@ -1,16 +1,28 @@ #include #include -#include -#include #include using std::cin; using std::cout; const char endl = '\n'; -const double PI = std::acos(-1); +const int mod = 998244353; -void fast_fourier_transform(std::valarray>& a) { +constexpr long long binpow(long long a, long long b) { + a %= mod; + + long long res = 1; + + while (b) { + if (b & 1) res = res * a % mod; + a = a * a % mod; + b >>= 1; + } + + return res; +} + +void number_theoretic_transform(std::valarray& a) { if (a.size() == 1) return; // assert(a.size() == 1 << std::__lg(a.size())); @@ -29,18 +41,20 @@ void fast_fourier_transform(std::valarray>& a) { } for (int len = 2; len <= a.size(); len <<= 1) { - std::complex wlen(std::cos(2 * PI / len), std::sin(2 * PI / len)); + int m = len >> 1; + + long long wlen = binpow(3, (mod - 1) / len); for (int i = 0; i < a.size(); i += len) { - std::complex w(1); + long long w = 1; - for (int j = 0; j < len / 2; j++) { - std::complex u = a[i + j], - v = a[i + j + len / 2] * w; + for (int j = 0; j < m; j++) { + long long u = a[i + j], + v = a[i + j + m] * w % mod; - a[i + j] = u + v; - a[i + j + len / 2] = u - v; - w *= wlen; + a[i + j] = (u + v) % mod; + a[i + j + m] = ((u - v) % mod + mod) % mod; + w = w * wlen % mod; } } } @@ -54,8 +68,9 @@ int main() { cin >> n >> m; - int k = 1 << (std::__lg(n + m) + 1); - std::valarray> f(k), g(k); + int k = 1 << (std::__lg(n + m) + 1), + inv = binpow(k, mod - 2); + std::valarray f(k), g(k); for (int i = 0; i <= n; i++) { cin >> f[i]; @@ -65,18 +80,18 @@ int main() { cin >> g[i]; } - fast_fourier_transform(f); - fast_fourier_transform(g); + number_theoretic_transform(f); + number_theoretic_transform(g); for (int i = 0; i < k; i++) { f[i] *= g[i]; } - fast_fourier_transform(f); + number_theoretic_transform(f); std::reverse(std::begin(f) + 1, std::end(f)); for (int i = 0; i <= n + m; i++) { - cout << static_cast(std::round(f[i].real() / k)) << ' '; + cout << f[i] * inv % mod << ' '; } cout << endl;