weighted_dist.go 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246
  1. /*
  2. * Copyright (c) 2014, Yawning Angel <yawning at schwanenlied dot me>
  3. * All rights reserved.
  4. *
  5. * Redistribution and use in source and binary forms, with or without
  6. * modification, are permitted provided that the following conditions are met:
  7. *
  8. * * Redistributions of source code must retain the above copyright notice,
  9. * this list of conditions and the following disclaimer.
  10. *
  11. * * Redistributions in binary form must reproduce the above copyright notice,
  12. * this list of conditions and the following disclaimer in the documentation
  13. * and/or other materials provided with the distribution.
  14. *
  15. * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
  16. * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
  17. * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
  18. * ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
  19. * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
  20. * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
  21. * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
  22. * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
  23. * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
  24. * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
  25. * POSSIBILITY OF SUCH DAMAGE.
  26. */
  27. // Package probdist implements a weighted probability distribution suitable for
  28. // protocol parameterization. To allow for easy reproduction of a given
  29. // distribution, the drbg package is used as the random number source.
  30. package probdist // import "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/lyrebird/common/probdist"
  31. import (
  32. "bytes"
  33. "container/list"
  34. "fmt"
  35. "math/rand"
  36. "sync"
  37. "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/lyrebird/common/csrand"
  38. "gitlab.torproject.org/tpo/anti-censorship/pluggable-transports/lyrebird/common/drbg"
  39. )
  // minValues and maxValues clamp the number of candidate values genValues
// draws from when it picks how many values a distribution contains.
const minValues, maxValues = 1, 100
  // WeightedDist is a weighted distribution over a random subset of the
// integers in [minValue, maxValue]. Reset and Sample hold the embedded mutex
// while rebuilding or reading the tables.
type WeightedDist struct {
	sync.Mutex

	// Inclusive bounds supplied to New.
	minValue, maxValue int
	// Chooses genBiasedWeights (true) or genUniformWeights (false) in Reset.
	biased bool

	// Offsets from minValue that Sample can return (after adding minValue).
	values []int
	// Raw, unnormalized weight for each entry of values.
	weights []float64

	// Vose alias-method tables produced by genTables.
	alias []int
	prob  []float64
}
  55. // New creates a weighted distribution of values ranging from min to max
  56. // based on a HashDrbg initialized with seed. Optionally, bias the weight
  57. // generation to match the ScrambleSuit non-uniform distribution from
  58. // obfsproxy.
  59. func New(seed *drbg.Seed, min, max int, biased bool) (w *WeightedDist) {
  60. w = &WeightedDist{minValue: min, maxValue: max, biased: biased}
  61. if max <= min {
  62. panic(fmt.Sprintf("wDist.Reset(): min >= max (%d, %d)", min, max))
  63. }
  64. w.Reset(seed)
  65. return
  66. }
  67. // genValues creates a slice containing a random number of random values
  68. // that when scaled by adding minValue will fall into [min, max].
  69. func (w *WeightedDist) genValues(rng *rand.Rand) {
  70. nValues := (w.maxValue + 1) - w.minValue
  71. values := rng.Perm(nValues)
  72. if nValues < minValues {
  73. nValues = minValues
  74. }
  75. if nValues > maxValues {
  76. nValues = maxValues
  77. }
  78. nValues = rng.Intn(nValues) + 1
  79. w.values = values[:nValues]
  80. }
  81. // genBiasedWeights generates a non-uniform weight list, similar to the
  82. // ScrambleSuit prob_dist module.
  83. func (w *WeightedDist) genBiasedWeights(rng *rand.Rand) {
  84. w.weights = make([]float64, len(w.values))
  85. culmProb := 0.0
  86. for i := range w.weights {
  87. p := (1.0 - culmProb) * rng.Float64()
  88. w.weights[i] = p
  89. culmProb += p
  90. }
  91. }
  92. // genUniformWeights generates a uniform weight list.
  93. func (w *WeightedDist) genUniformWeights(rng *rand.Rand) {
  94. w.weights = make([]float64, len(w.values))
  95. for i := range w.weights {
  96. w.weights[i] = rng.Float64()
  97. }
  98. }
  99. // genTables calculates the alias and prob tables used for Vose's Alias method.
  100. // Algorithm taken from http://www.keithschwarz.com/darts-dice-coins/
  101. func (w *WeightedDist) genTables() {
  102. n := len(w.weights)
  103. var sum float64
  104. for _, weight := range w.weights {
  105. sum += weight
  106. }
  107. // Create arrays $Alias$ and $Prob$, each of size $n$.
  108. alias := make([]int, n)
  109. prob := make([]float64, n)
  110. // Create two worklists, $Small$ and $Large$.
  111. small := list.New()
  112. large := list.New()
  113. scaled := make([]float64, n)
  114. for i, weight := range w.weights {
  115. // Multiply each probability by $n$.
  116. p_i := weight * float64(n) / sum
  117. scaled[i] = p_i
  118. // For each scaled probability $p_i$:
  119. if scaled[i] < 1.0 {
  120. // If $p_i < 1$, add $i$ to $Small$.
  121. small.PushBack(i)
  122. } else {
  123. // Otherwise ($p_i \ge 1$), add $i$ to $Large$.
  124. large.PushBack(i)
  125. }
  126. }
  127. // While $Small$ and $Large$ are not empty: ($Large$ might be emptied first)
  128. for small.Len() > 0 && large.Len() > 0 {
  129. // Remove the first element from $Small$; call it $l$.
  130. l := small.Remove(small.Front()).(int)
  131. // Remove the first element from $Large$; call it $g$.
  132. g := large.Remove(large.Front()).(int)
  133. // Set $Prob[l] = p_l$.
  134. prob[l] = scaled[l]
  135. // Set $Alias[l] = g$.
  136. alias[l] = g
  137. // Set $p_g := (p_g + p_l) - 1$. (This is a more numerically stable option.)
  138. scaled[g] = (scaled[g] + scaled[l]) - 1.0
  139. if scaled[g] < 1.0 {
  140. // If $p_g < 1$, add $g$ to $Small$.
  141. small.PushBack(g)
  142. } else {
  143. // Otherwise ($p_g \ge 1$), add $g$ to $Large$.
  144. large.PushBack(g)
  145. }
  146. }
  147. // While $Large$ is not empty:
  148. for large.Len() > 0 {
  149. // Remove the first element from $Large$; call it $g$.
  150. g := large.Remove(large.Front()).(int)
  151. // Set $Prob[g] = 1$.
  152. prob[g] = 1.0
  153. }
  154. // While $Small$ is not empty: This is only possible due to numerical instability.
  155. for small.Len() > 0 {
  156. // Remove the first element from $Small$; call it $l$.
  157. l := small.Remove(small.Front()).(int)
  158. // Set $Prob[l] = 1$.
  159. prob[l] = 1.0
  160. }
  161. w.prob = prob
  162. w.alias = alias
  163. }
  164. // Reset generates a new distribution with the same min/max based on a new
  165. // seed.
  166. func (w *WeightedDist) Reset(seed *drbg.Seed) {
  167. // Initialize the deterministic random number generator.
  168. drbg, _ := drbg.NewHashDrbg(seed)
  169. rng := rand.New(drbg)
  170. w.Lock()
  171. defer w.Unlock()
  172. w.genValues(rng)
  173. if w.biased {
  174. w.genBiasedWeights(rng)
  175. } else {
  176. w.genUniformWeights(rng)
  177. }
  178. w.genTables()
  179. }
  180. // Sample generates a random value according to the distribution.
  181. func (w *WeightedDist) Sample() int {
  182. var idx int
  183. w.Lock()
  184. defer w.Unlock()
  185. // Generate a fair die roll from an $n$-sided die; call the side $i$.
  186. i := csrand.Intn(len(w.values))
  187. // Flip a biased coin that comes up heads with probability $Prob[i]$.
  188. if csrand.Float64() <= w.prob[i] {
  189. // If the coin comes up "heads," return $i$.
  190. idx = i
  191. } else {
  192. // Otherwise, return $Alias[i]$.
  193. idx = w.alias[i]
  194. }
  195. return w.minValue + w.values[idx]
  196. }
  197. // String returns a dump of the distribution table.
  198. func (w *WeightedDist) String() string {
  199. var buf bytes.Buffer
  200. buf.WriteString("[ ")
  201. for i, v := range w.values {
  202. p := w.weights[i]
  203. if p > 0.01 { // Squelch tiny probabilities.
  204. buf.WriteString(fmt.Sprintf("%d: %f ", v, p))
  205. }
  206. }
  207. buf.WriteString("]")
  208. return buf.String()
  209. }