lattice.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. package bn256
  2. import (
  3. "math/big"
  4. )
  5. var half = new(big.Int).Rsh(Order, 1)
  6. var curveLattice = &lattice{
  7. vectors: [][]*big.Int{
  8. {bigFromBase10("147946756881789319000765030803803410728"), bigFromBase10("147946756881789319010696353538189108491")},
  9. {bigFromBase10("147946756881789319020627676272574806254"), bigFromBase10("-147946756881789318990833708069417712965")},
  10. },
  11. inverse: []*big.Int{
  12. bigFromBase10("147946756881789318990833708069417712965"),
  13. bigFromBase10("147946756881789319010696353538189108491"),
  14. },
  15. det: bigFromBase10("43776485743678550444492811490514550177096728800832068687396408373151616991234"),
  16. }
  17. var targetLattice = &lattice{
  18. vectors: [][]*big.Int{
  19. {bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697761"), bigFromBase10("9931322734385697763"), bigFromBase10("9931322734385697764")},
  20. {bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("4965661367192848882"), bigFromBase10("-9931322734385697762")},
  21. {bigFromBase10("-9931322734385697762"), bigFromBase10("-4965661367192848881"), bigFromBase10("4965661367192848881"), bigFromBase10("-4965661367192848882")},
  22. {bigFromBase10("9931322734385697763"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881"), bigFromBase10("-4965661367192848881")},
  23. },
  24. inverse: []*big.Int{
  25. bigFromBase10("734653495049373973658254490726798021314063399421879442165"),
  26. bigFromBase10("147946756881789319000765030803803410728"),
  27. bigFromBase10("-147946756881789319005730692170996259609"),
  28. bigFromBase10("1469306990098747947464455738335385361643788813749140841702"),
  29. },
  30. det: new(big.Int).Set(Order),
  31. }
  // lattice bundles a lattice basis with the data decompose needs to perform
// Babai rounding against it.
type lattice struct {
	// vectors is the basis: vectors[j][i] is coordinate i of basis vector j,
	// matching how decompose indexes it.
	vectors [][]*big.Int

	// inverse[i] is multiplied by the input scalar and then divided by det to
	// produce the i-th rounding coefficient in decompose.
	inverse []*big.Int

	// det is the common denominator for every entry of inverse.
	det *big.Int
}
  37. // decompose takes a scalar mod Order as input and finds a short, positive decomposition of it wrt to the lattice basis.
  38. func (l *lattice) decompose(k *big.Int) []*big.Int {
  39. n := len(l.inverse)
  40. // Calculate closest vector in lattice to <k,0,0,...> with Babai's rounding.
  41. c := make([]*big.Int, n)
  42. for i := 0; i < n; i++ {
  43. c[i] = new(big.Int).Mul(k, l.inverse[i])
  44. round(c[i], l.det)
  45. }
  46. // Transform vectors according to c and subtract <k,0,0,...>.
  47. out := make([]*big.Int, n)
  48. temp := new(big.Int)
  49. for i := 0; i < n; i++ {
  50. out[i] = new(big.Int)
  51. for j := 0; j < n; j++ {
  52. temp.Mul(c[j], l.vectors[j][i])
  53. out[i].Add(out[i], temp)
  54. }
  55. out[i].Neg(out[i])
  56. out[i].Add(out[i], l.vectors[0][i]).Add(out[i], l.vectors[0][i])
  57. }
  58. out[0].Add(out[0], k)
  59. return out
  60. }
  61. func (l *lattice) Precompute(add func(i, j uint)) {
  62. n := uint(len(l.vectors))
  63. total := uint(1) << n
  64. for i := uint(0); i < n; i++ {
  65. for j := uint(0); j < total; j++ {
  66. if (j>>i)&1 == 1 {
  67. add(i, j)
  68. }
  69. }
  70. }
  71. }
  72. func (l *lattice) Multi(scalar *big.Int) []uint8 {
  73. decomp := l.decompose(scalar)
  74. maxLen := 0
  75. for _, x := range decomp {
  76. if x.BitLen() > maxLen {
  77. maxLen = x.BitLen()
  78. }
  79. }
  80. out := make([]uint8, maxLen)
  81. for j, x := range decomp {
  82. for i := 0; i < maxLen; i++ {
  83. out[i] += uint8(x.Bit(i)) << uint(j)
  84. }
  85. }
  86. return out
  87. }
  // round sets num to num/denom rounded to the nearest integer, overwriting num
// in place.
//
// denom must be positive. The quotient comes from big.Int.DivMod (Euclidean
// division). For positive denom, num therefore holds floor(num/denom) and the
// remainder r satisfies 0 <= r < denom. The quotient is then incremented when
// r/denom > 1/2, i.e. when 2r > denom. Exact ties (2r == denom) round toward
// negative infinity.
//
// The threshold is derived from denom itself. A fixed cut-off (such as
// Order/2) is only correct for one particular denominator; any other lattice
// determinant would get skewed rounding. For denom == Order, comparing
// 2r > denom is exactly equivalent to r > floor(Order/2).
func round(num, denom *big.Int) {
	r := new(big.Int)
	num.DivMod(num, denom, r)

	// r is only needed for this comparison, so double it in place:
	// 2r > denom  <=>  r/denom > 1/2.
	if r.Lsh(r, 1).Cmp(denom) > 0 {
		num.Add(num, big.NewInt(1))
	}
}