modular_square_root_all_solutions.sf 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115
  1. #!/usr/bin/ruby
  2. # Find all solutions to the quadratic congruence:
  3. # x^2 = a (mod n)
  4. # Based on algorithm by Hugo van der Sanden:
  5. # https://github.com/danaj/Math-Prime-Util/pull/55
  6. func sqrtmod_all(a, n) {
  7. n = -n if (n < 0)
  8. n == 0 && return []
  9. n == 1 && return [0]
  10. a = (a % n)
  11. func sqrtmod_pk(a, p, k) {
  12. var pk = p**k
  13. if (p `divides` a) {
  14. if (pk `divides` a) {
  15. var low = p**(k >> 1)
  16. var high = (k.is_odd ? (low*p) : low)
  17. return (^low -> map {|i| high * i })
  18. }
  19. var a2 = a/p
  20. return [] if !(p `divides` a2)
  21. var pj = pk/p
  22. return __FUNC__(a2/p, p, k-2).map {|q|
  23. ^p -> map {|i| q*p + i*pj }...
  24. }
  25. }
  26. var q = sqrtmod(a, pk)
  27. q.is_nan && return []
  28. return [q, pk-q] if (p != 2)
  29. return [q] if (k == 1)
  30. return [q, pk-q] if (k == 2)
  31. var pj = p**(k-1)
  32. var q2 = (q * (pj-1))%pk
  33. return [q, pk-q, q2, pk-q2]
  34. }
  35. var roots = []
  36. n.factor_map {|p,k|
  37. sqrtmod_pk(a, p, k).map {|r| [r, p**k] }
  38. }.cartesian {|*a|
  39. roots << Math.chinese(a...)
  40. }
  41. return roots.sort
  42. }
  43. say sqrtmod_all(1, 8) #=> [1, 3, 5, 7]
  44. say sqrtmod_all(120, 5045) #=> [1165, 3880]
  45. say sqrtmod_all(4095, 8469) #=> [1110, 1713, 3933, 4536, 6756, 7359]
  46. # Run some tests
  47. assert_eq(
  48. sqrtmod_all(-1, 13**18 * 5**7)
  49. %n(633398078861605286438568 2308322911594648160422943 6477255756527023177780182 8152180589260066051764557)
  50. )
  51. assert_eq(
  52. sqrtmod_all(2466, 5967),
  53. [120 237 426 543 1446 1563 1752 1869 2109 2226 2415 2532 3435 3552 3741 3858 4098 4215 4404 4521 5424 5541 5730 5847]
  54. )
  55. assert_eq(
  56. sqrtmod_all(7281, 9954),
  57. %n(1233 1611 1707 2085 4551 4929 5025 5403 7869 8247 8343 8721)
  58. )
  59. assert_eq(
  60. sqrtmod_all(1701, 6300),
  61. %n[399, 651, 1449, 1701, 2499, 2751, 3549, 3801, 4599, 4851, 5649, 5901]
  62. )
  63. assert_eq(
  64. sqrtmod_all(306, 810),
  65. %n[66, 96, 174, 204, 336, 366, 444, 474, 606, 636, 714, 744]
  66. )
  67. assert_eq(
  68. sqrtmod_all(2754, 6561),
  69. %n[126, 603, 855, 1332, 1584, 2061, 2313, 2790, 3042, 3519, 3771, 4248, 4500, 4977, 5229, 5706, 5958, 6435]
  70. )
  71. assert_eq(
  72. sqrtmod_all(17640, 48465),
  73. %n[2865, 7905, 8250, 13290, 19020, 24060, 24405, 29445, 35175, 40215, 40560, 45600]
  74. )
  75. for n in (0..20), a in (^n) {
  76. var roots = sqrtmod_all(a, n)
  77. assert(roots.all {|r|
  78. r**2 % n == a
  79. }, "sqrtmod(#{a}, #{n}) = #{roots}")
  80. assert_eq(
  81. ^n -> grep {|k| mulmod(k, k, n) == a },
  82. roots
  83. )
  84. }