This documentation is automatically generated by online-judge-tools/verification-helper
#include "number/sqrt_mod.hpp"

素数 $p$ と整数 $a$ に対して,$x^2 \equiv a \pmod{p}$ を満たす $x$ を求める.解が存在しない場合は $-1$ を返す.Tonelli-Shanks algorithm に基づく.
// Usage example: a / al are the values whose square root is wanted and
// p / pl are the moduli; they are left as placeholders here.
int a, p;
int x = sqrt_mod<int>(a, p); // x^2 ≡ a (mod p), or -1
long long al, pl;
long long xl = sqrt_mod<long long>(al, pl); // uses __int128 internally
Int が int のとき内部で long long,long long のとき __int128 を乗算のオーバーフロー回避に使用する.
#pragma once
#include <algorithm>
#include <type_traits>
// Solve x^2 == a (mod p) for a prime p with the Tonelli-Shanks algorithm.
//
// Template parameter:
//   Int - `int` or `long long`. Products are evaluated in `long long` or
//         `__int128` respectively so that they cannot overflow.
// Parameters:
//   a - any integer; it is reduced into [0, p) first, so negative values and
//       values >= p are accepted.
//   p - the modulus, expected to be prime.
// Returns:
//   0 when a is a multiple of p, -1 when no solution exists, otherwise the
//   smaller of the two roots x and p - x.
template <class Int> Int sqrt_mod(Int a, Int p) {
    static_assert(std::is_same_v<Int, int> || std::is_same_v<Int, long long>,
                  "sqrt_mod supports only int and long long");
    using Long = std::conditional_t<std::is_same_v<Int, int>, long long, __int128>;

    // (x * y) mod p, with the product formed in the wider type.
    auto mul = [&](Int x, Int y) -> Int {
        return static_cast<Int>(static_cast<Long>(x) * y % p);
    };
    // x^n mod p by binary exponentiation.
    auto pow = [&](Int x, long long n) -> Int {
        Int ans = 1;
        for (Int base = x; n > 0; n /= 2) {
            if (n & 1) ans = mul(ans, base);
            base = mul(base, base);
        }
        return ans;
    };

    // Reduce a into [0, p) without forming `a % p + p`, which overflows Int
    // when p is larger than half of the type's maximum.
    a %= p;
    if (a < 0) a += p;
    // Test for zero only after the reduction so that non-zero multiples of p
    // are reported as having the root 0 instead of "no solution".
    if (a == 0) return 0;
    if (p == 2) return a;
    // Euler's criterion: a is a quadratic residue iff a^((p-1)/2) == 1.
    if (pow(a, (p - 1) / 2) != 1) return -1;

    // Smallest b that fails Euler's criterion, i.e. a quadratic non-residue.
    Int b = 1;
    while (pow(b, (p - 1) / 2) == 1) ++b;

    // Write p - 1 = m * 2^e with m odd.
    int e = 0;
    Int m = p - 1;
    while (m % 2 == 0) m /= 2, ++e;

    Int x = pow(a, (m - 1) / 2);
    Int y = mul(mul(x, x), a);  // a^m
    x = mul(x, a);              // a^((m+1)/2), so x^2 == a * y
    Int z = pow(b, m);          // element of order 2^e
    // Each round keeps x^2 == a * y while strictly lowering the order of y;
    // once y == 1, x is a square root of a.
    while (y != 1) {
        // j = least exponent with y^(2^j) == 1.
        int j = 0;
        for (Int t = y; t != 1; t = mul(t, t)) ++j;
        z = pow(z, 1LL << (e - j - 1));
        x = mul(x, z);
        z = mul(z, z);
        y = mul(y, z);
        e = j;
    }
    return std::min(x, p - x);
}
#line 2 "number/sqrt_mod.hpp"
#include <algorithm>
#include <type_traits>
// Solve x^2 == a (mod p) for a prime p with the Tonelli-Shanks algorithm.
//
// Template parameter:
//   Int - `int` or `long long`. Products are evaluated in `long long` or
//         `__int128` respectively so that they cannot overflow.
// Parameters:
//   a - any integer; it is reduced into [0, p) first, so negative values and
//       values >= p are accepted.
//   p - the modulus, expected to be prime.
// Returns:
//   0 when a is a multiple of p, -1 when no solution exists, otherwise the
//   smaller of the two roots x and p - x.
template <class Int> Int sqrt_mod(Int a, Int p) {
    static_assert(std::is_same_v<Int, int> || std::is_same_v<Int, long long>,
                  "sqrt_mod supports only int and long long");
    using Long = std::conditional_t<std::is_same_v<Int, int>, long long, __int128>;

    // (x * y) mod p, with the product formed in the wider type.
    auto mul = [&](Int x, Int y) -> Int {
        return static_cast<Int>(static_cast<Long>(x) * y % p);
    };
    // x^n mod p by binary exponentiation.
    auto pow = [&](Int x, long long n) -> Int {
        Int ans = 1;
        for (Int base = x; n > 0; n /= 2) {
            if (n & 1) ans = mul(ans, base);
            base = mul(base, base);
        }
        return ans;
    };

    // Reduce a into [0, p) without forming `a % p + p`, which overflows Int
    // when p is larger than half of the type's maximum.
    a %= p;
    if (a < 0) a += p;
    // Test for zero only after the reduction so that non-zero multiples of p
    // are reported as having the root 0 instead of "no solution".
    if (a == 0) return 0;
    if (p == 2) return a;
    // Euler's criterion: a is a quadratic residue iff a^((p-1)/2) == 1.
    if (pow(a, (p - 1) / 2) != 1) return -1;

    // Smallest b that fails Euler's criterion, i.e. a quadratic non-residue.
    Int b = 1;
    while (pow(b, (p - 1) / 2) == 1) ++b;

    // Write p - 1 = m * 2^e with m odd.
    int e = 0;
    Int m = p - 1;
    while (m % 2 == 0) m /= 2, ++e;

    Int x = pow(a, (m - 1) / 2);
    Int y = mul(mul(x, x), a);  // a^m
    x = mul(x, a);              // a^((m+1)/2), so x^2 == a * y
    Int z = pow(b, m);          // element of order 2^e
    // Each round keeps x^2 == a * y while strictly lowering the order of y;
    // once y == 1, x is a square root of a.
    while (y != 1) {
        // j = least exponent with y^(2^j) == 1.
        int j = 0;
        for (Int t = y; t != 1; t = mul(t, t)) ++j;
        z = pow(z, 1LL << (e - j - 1));
        x = mul(x, z);
        z = mul(z, z);
        y = mul(y, z);
        e = j;
    }
    return std::min(x, p - x);
}