statistics.lua 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  -- `statistics` is the module table: the public samplers and error-function
  -- helpers are attached to it below and it is returned at the end of the file.
  -- ROOT_2 (sqrt(2)) scales erf_inv into standard-normal quantiles.
  local statistics, ROOT_2 = {}, math.sqrt(2.0)
  -- Closed-form approximations of the error function and of its inverse, as
  -- given in https://en.wikipedia.org/wiki/Error_function
  -- A (about 0.140) is the shape constant of the approximation; B, C and D are
  -- derived terms hoisted out of the two functions below.
  local A = 8 * (math.pi - 3.0) / (3.0 * math.pi * (4.0 - math.pi))
  local B = 4.0 / math.pi
  local C = 2.0 / (math.pi * A)
  local D = 1.0 / A

  --- Approximate error function.
  -- @tparam number x
  -- @treturn number approximately erf(x); odd in x and within [-1, 1].
  local function erf(x)
    if x == 0 then
      return 0
    end
    local x2 = x * x
    local ax2 = A * x2
    local magnitude = math.sqrt(1.0 - math.exp(-x2 * (B + ax2) / (1.0 + ax2)))
    if x > 0 then
      return magnitude
    end
    return -magnitude
  end

  --- Approximate inverse error function (the algebraic inverse of `erf` above).
  -- @tparam number x
  -- @treturn number|nil approximately erfInv(x), or nil when x <= -1 or x >= 1.
  local function erf_inv(x)
    if x == 0 then
      return 0
    end
    if x >= 1 or x <= -1 then
      return nil
    end
    local log_term = math.log(1 - x * x)
    local half = C + 0.5 * log_term
    local magnitude = math.sqrt(math.sqrt(half * half - D * log_term) - half)
    if x > 0 then
      return magnitude
    end
    return -magnitude
  end
  --- Approximate standard normal quantile function.
  -- Maps a probability u to z = sqrt(2) * erf_inv(2u - 1), i.e. the z-score with
  -- P(Z <= z) = u for Z ~ N(0, 1).
  -- At the ends, erf_inv(2u - 1) is undefined and returns nil. Return -math.huge
  -- (u at or below 0) or math.huge (u at or above 1) there instead of failing on
  -- `ROOT_2 * nil`. Note that math.random() can return exactly 0.
  -- @tparam number u probability, normally in [0, 1)
  -- @treturn number the corresponding z-score
  local function std_normal(u)
    local z = erf_inv(2.0 * u - 1.0)
    if z == nil then
      return (u < 0.5 and -math.huge) or math.huge
    end
    return ROOT_2 * z
  end
  -- Precomputed Poisson CDFs: cdf_table[k] is the CDF for lambda = 0.25 * k,
  -- k = 1 .. 100 (lambda 0.25 .. 25.0). See generate_cdf for the layout.
  local cdf_table = {}

  --- Builds the cumulative distribution of Poisson(lambda).
  -- @tparam number lambda the distribution mean
  -- @treturn table t with t[i] = P(X <= i) for i = 0 .. ceil(4 * lambda) - 1
  --   (index 0 included, so #t is ceil(4 * lambda) - 1).
  local function generate_cdf(lambda)
    local count = math.ceil(4 * lambda)
    local pdf = math.exp(-lambda)        -- P(X = 0)
    local cdf = pdf
    local t = { [0] = pdf }
    for i = 1, count - 1 do
      pdf = pdf * lambda / i             -- P(X = i) from P(X = i - 1)
      cdf = cdf + pdf
      t[i] = cdf
    end
    return t
  end

  for k = 1, 100 do
    cdf_table[k] = generate_cdf(0.25 * k)
  end

  --- Draws a Poisson(lambda) variate truncated to the range [0, max].
  -- If lambda rounds to a tabulated value (nearest multiple of 0.25, in
  -- [0.25, 25]), the sample is taken by inverse transform on that table.
  -- Otherwise a rounded normal approximation N(lambda, lambda) is used.
  -- @tparam number lambda the mean (assumed > 0; validated by the caller)
  -- @tparam number max the cap (assumed >= 1); any fractional part is dropped
  -- @treturn integer a count between 0 and floor(max), both inclusive
  local function poisson(lambda, max)
    -- Several paths below return `max` itself, so it must be a whole number.
    -- Flooring first also stops the table path from raising the cap past it.
    max = math.floor(max)
    if max < 2 then
      -- Truncated at 1: P(0) = exp(-lambda), all remaining mass goes to 1.
      if math.random() < math.exp(-lambda) then
        return 0
      end
      return 1
    end
    if lambda >= 2 * max then
      -- Heuristic shortcut: with the mean at least twice the cap, return the cap.
      return max
    end

    local u = math.random()
    local cdfs = cdf_table[math.floor(4 * lambda + 0.5)]
    if cdfs then
      if u < cdfs[0] then
        return 0
      end
      -- The table stops at count #cdfs; anything beyond it is folded into
      -- #cdfs + 1.
      local n = #cdfs
      if max > n then
        max = n + 1
      end
      if u >= cdfs[max - 1] then
        return max
      end
      -- From here cdfs[0] <= u < cdfs[max - 1]: the answer is the smallest i in
      -- [1, max - 1] with u < cdfs[i].
      if max > 4 then
        -- Binary search; invariant: u >= cdfs[lo] and the answer lies in (lo, hi].
        local lo, hi = 0, max
        while lo + 1 < hi do
          local mid = math.floor(0.5 * (lo + hi))
          if u < cdfs[mid] then
            hi = mid
          else
            lo = mid
          end
        end
        return hi
      end
      for i = 1, max - 1 do
        if u < cdfs[i] then
          return i
        end
      end
      return max
    end

    -- No table for this lambda: round a normal approximation and clamp it.
    if u <= 0 then
      -- math.random() can return exactly 0, where erf_inv(2u - 1) is undefined.
      -- This is the extreme lower tail, i.e. a count of 0.
      return 0
    end
    local x = lambda + math.sqrt(lambda) * std_normal(u)
    if x < 0.5 then
      return 0
    end
    if x >= max - 0.5 then
      return max
    end
    return math.floor(x + 0.5)
  end
  --- Public error-function helpers (closed-form approximations):
  -- `statistics.erf(x)` returns approximately erf(x).
  -- `statistics.erf_inv(x)` returns approximately erfInv(x), or nil when x is
  -- not strictly between -1 and 1.
  statistics.erf, statistics.erf_inv = erf, erf_inv
  --- Draws a random sample from the standard normal distribution (mean 0,
  -- standard deviation 1), truncated to [-3, 3].
  --
  -- Uses inverse-transform sampling on `math.random()`. The outermost 0.1% of
  -- each tail maps directly to the bound. This also keeps the inverse error
  -- function away from its undefined end points.
  --
  -- @return
  -- A real number between -3.0 and 3.0 (both inclusive).
  statistics.std_normal = function()
    local u = math.random()
    if u < 0.001 then
      return -3.0
    elseif u > 0.999 then
      return 3.0
    end
    -- The quantile at u = 0.001 is about -3.08 (and symmetrically +3.08 near
    -- 0.999), so clamp the remaining samples too to honour the documented range.
    return math.max(-3.0, math.min(3.0, std_normal(u)))
  end
  --- Draws a random sample from a normal distribution with mean `mu` and
  -- standard deviation `sigma`, truncated to three standard deviations.
  --
  -- Computed as mu + sigma * z, where z is a standard-normal sample clamped
  -- to [-3, 3] (the outermost 0.1% of each tail maps directly to the bound).
  --
  -- @param mu
  -- The distribution mean.
  -- @param sigma
  -- The distribution standard deviation.
  -- @return
  -- A real number between mu - 3*sigma and mu + 3*sigma (both inclusive).
  statistics.normal = function(mu, sigma)
    local u = math.random()
    if u < 0.001 then
      return mu - 3.0 * sigma
    elseif u > 0.999 then
      return mu + 3.0 * sigma
    end
    -- Between the cut-offs the quantile still reaches about +/-3.08, so clamp it
    -- to keep the result within the documented +/-3 sigma band.
    local z = math.max(-3.0, math.min(3.0, std_normal(u)))
    return mu + sigma * z
  end
  --- Draws a random sample from a Poisson distribution truncated at `max`.
  --
  -- @param lambda
  -- The distribution mean and variance (a number or numeric string).
  -- @param max
  -- The distribution maximum (a number or numeric string); any fractional
  -- part is dropped.
  -- @return
  -- An integer between 0 and max (both inclusive). Returns 0 when either
  -- argument is missing, non-numeric or NaN, or when lambda <= 0 or max < 1.
  statistics.poisson = function(lambda, max)
    lambda, max = tonumber(lambda), tonumber(max)
    -- `x ~= x` holds only for NaN. NaN compares false with everything, so it
    -- would otherwise pass the range checks and propagate into the result.
    if not lambda or not max or lambda ~= lambda or max ~= max
      or lambda <= 0 or max < 1 then
      return 0
    end
    -- Only whole counts are meaningful, and the sampler can return its `max`
    -- argument directly, so drop any fractional part before sampling.
    return poisson(lambda, math.floor(max))
  end

  -- The module table returned to `require`.
  return statistics