-- stringGenerator.lua (4.6 KB)
  --- Random String Generator.
  -- Builds a model from the strings it is shown and produces new strings
  -- that resemble them.
  -- @module ROT.StringGenerator

  -- Pattern that matches the final segment of this module's require path
  -- (one separator character followed by the trailing name).
  local parentPattern = string.rep('.[^./\\]*', 1) .. '$'
  -- Load the parent package by requiring this module's path with that segment removed.
  local ROT = require(string.gsub((...), parentPattern, ''))
  -- Class object for this module, derived from the ROT base class.
  local StringGenerator = ROT.Class:extend("StringGenerator")
  --- Constructor.
  -- Invoked through ROT.StringGenerator:new().
  -- @tparam[opt] table options Overrides for the defaults listed below:
  -- @tparam[opt=false] boolean options.words Use word mode (tokens are words instead of characters)
  -- @tparam[opt=3] int options.order How many preceding letters/words make up the context
  -- @tparam[opt=0.001] number options.prior Default priority given to every known character/word
  function StringGenerator:init(options)
    local boundary = string.char(0)
    local settings = { words = false, order = 3, prior = 0.001 }

    -- Caller-supplied values take precedence over the defaults.
    for name, value in pairs(options or {}) do
      settings[name] = value
    end

    self._options = settings
    self._boundary = boundary
    self._suffix = string.char(0)
    self._data = {}

    -- The start-of-string context consists of `order` boundary markers.
    local prefix = {}
    for slot = 1, settings.order do
      prefix[slot] = boundary
    end
    self._prefix = prefix

    -- The boundary marker is registered with the default prior up front.
    self._priorValues = { [boundary] = settings.prior }
  end
  --- Remove all learned data.
  -- Forgets every observed context and every observed token. The boundary
  -- marker's prior is registered again, just as the constructor does, so the
  -- generator remains usable: sampling adds observed counts onto the prior
  -- table, and the boundary marker is always among the observed events.
  function StringGenerator:clear()
    self._data = {}
    self._priorValues = { [self._boundary] = self._options.prior }
  end
  --- Generate a string.
  -- Samples tokens one at a time, each conditioned on the tokens produced so
  -- far, until the boundary marker is drawn.
  -- @treturn string The generated string; in word mode the words are joined
  -- with the same separator that is used to split observed strings
  function StringGenerator:generate()
    local result = {}
    local token = self:_sample(self._prefix)

    -- Stop at the boundary marker. A nil sample (nothing could be picked) also
    -- ends generation; appending nil would never grow the list and the loop
    -- would not terminate.
    while token ~= nil and token ~= self._boundary do
      result[#result + 1] = token
      token = self:_sample(result)
    end

    -- Use _join rather than a bare concat so word mode keeps its separators.
    return self:_join(result)
  end
  --- Observe.
  -- Learn from a string: every token in it receives the default prior, and
  -- for each position the following token is recorded against the preceding
  -- context and against each shorter tail of that context.
  -- @tparam string s The string to observe
  function StringGenerator:observe(s)
    local order = self._options.order
    local tokens = self:_split(s)

    -- Register every token of the input with the default prior.
    for idx = 1, #tokens do
      self._priorValues[tokens[idx]] = self._options.prior
    end

    -- Surround the tokens with the start prefix and the end suffix.
    for idx = #self._prefix, 1, -1 do
      table.insert(tokens, 1, self._prefix[idx])
    end
    tokens[#tokens + 1] = self._suffix

    -- Slide a window of `order` tokens over the sequence; the token right
    -- after the window is the event observed for that window.
    for last = order, #tokens - 1 do
      local context = table.slice(tokens, last - order + 1, last)
      local event = tokens[last + 1]
      -- Record the event for the full context and for each shorter tail of it.
      for start = 1, #context do
        self:_observeEvent(table.slice(context, start), event)
      end
    end
  end
  --- Get stats.
  -- Summarises what has been learned so far.
  -- @treturn string Comma-separated summary: number of distinct tokens seen
  -- (the boundary marker is not counted), number of stored contexts, and
  -- number of distinct context/following-token pairs
  function StringGenerator:getStats()
    -- Skip the boundary marker explicitly instead of subtracting one, so the
    -- count stays correct when the marker is not in the prior table.
    local samples = 0
    for token in pairs(self._priorValues) do
      if token ~= self._boundary then
        samples = samples + 1
      end
    end

    local contexts, events = 0, 0
    for _, followers in pairs(self._data) do
      contexts = contexts + 1
      for _ in pairs(followers) do
        events = events + 1
      end
    end

    return table.concat({
      'distinct samples: ' .. samples,
      'dict size(cons): ' .. contexts,
      'dict size(evts): ' .. events,
    }, ', ')
  end
  --- Break a string into tokens.
  -- Calls str:split with a single space in word mode and with the empty
  -- string otherwise.
  -- @tparam string str The string to tokenise
  -- @return Whatever str:split returns for the chosen separator
  function StringGenerator:_split(str)
    local separator = ""
    if self._options.words then
      separator = " "
    end
    return str:split(separator)
  end
  --- Combine tokens back into one string.
  -- Tokens are separated by a single space in word mode and by nothing otherwise.
  -- @tparam table arr Sequence of tokens
  -- @treturn string The concatenated tokens
  function StringGenerator:_join(arr)
    local separator = ""
    if self._options.words then
      separator = " "
    end
    return table.concat(arr, separator)
  end
  --- Count one occurrence of an event following a context.
  -- @tparam table context Sequence of tokens forming the context
  -- @param event The token that followed the context
  function StringGenerator:_observeEvent(context, event)
    local key = self:_join(context)

    -- Create the per-context counter table on first use.
    local counts = self._data[key]
    if not counts then
      counts = {}
      self._data[key] = counts
    end

    counts[event] = (counts[event] or 0) + 1
  end
  --- Draw the next token for a context.
  -- The context is first reduced by _backoff to one that has stored data.
  -- When the prior option is set, each known token's weight is its prior plus
  -- the number of times it was observed after that context; otherwise only
  -- the observed counts are used.
  -- @tparam table context Sequence of tokens generated so far
  -- @return The token chosen by _pickRandom (nil when nothing can be picked)
  function StringGenerator:_sample(context)
    local key = self:_join(self:_backoff(context))

    -- _backoff can end on a context with no stored data (for example before
    -- anything has been observed); treat that as "no observations".
    local data = self._data[key] or {}

    if not self._options.prior then
      return self:_pickRandom(data)
    end

    local avail = {}
    for token, prior in pairs(self._priorValues) do
      avail[token] = prior
    end
    for token, count in pairs(data) do
      -- An observed event may have no registered prior; count it from zero
      -- instead of doing arithmetic on nil.
      avail[token] = (avail[token] or 0) + count
    end

    return self:_pickRandom(avail)
  end
  --- Reduce a context to one the model has data for.
  -- Works on a copy of the context: it is cut from the front, or padded at
  -- the front with boundary markers, to match the `order` option; leading
  -- tokens are then dropped until the joined context is a key in the stored
  -- data or the context is empty.
  -- @tparam table context Sequence of tokens
  -- @treturn table A new sequence; the input table is not modified
  function StringGenerator:_backoff(context)
    local order = self._options.order

    local trimmed = {}
    for idx = 1, #context do
      trimmed[idx] = context[idx]
    end

    if #trimmed > order then
      -- Too long: discard the oldest tokens.
      repeat
        table.remove(trimmed, 1)
      until #trimmed <= order
    else
      -- Too short: pad the front with boundary markers.
      while #trimmed < order do
        table.insert(trimmed, 1, self._boundary)
      end
    end

    -- Shorten from the front until a known context is reached.
    while #trimmed > 0 and not self._data[self:_join(trimmed)] do
      trimmed = table.slice(trimmed, 2)
    end

    return trimmed
  end
  --- Pick a key at random, weighted by its value.
  -- NOTE(review): self._rng is not assigned anywhere in this file; presumably
  -- it is supplied by the ROT.Class base - confirm.
  -- @tparam table data Map from candidate key to numeric weight
  -- @return The chosen key, or nil when data has no entries
  function StringGenerator:_pickRandom(data)
    local total = 0
    for _, weight in pairs(data) do
      total = total + weight
    end

    local rand = self._rng:random() * total
    local running = 0
    local last
    for key, weight in pairs(data) do
      running = running + weight
      if rand < running then
        return key
      end
      last = key
    end

    -- Floating-point rounding can leave rand equal to the final running sum,
    -- so no key passes the test above; fall back to the last key visited
    -- rather than returning nil.
    return last
  end
  153. return StringGenerator