# bayes.py

import random
import numpy as np


def get_codes(data):
    # arguments
    #   data    list of (text, author code) tuples
    # returns
    #   codes   set of distinct author codes
    codes = []
    for pair in data:
        codes.append(pair[1])
    codes = set(codes)
    return codes

def traintestsplit(data, testnum, seed=None):
    # arguments
    #   data        list of (text, author code) tuples
    #   testnum     int, number of held-out test items
    #   seed        optional int, seeds the shuffle for reproducibility
    # returns
    #   trainpairs  list of training pairs
    #   testpairs   list of test pairs
    if seed is not None:
        random.seed(seed)
    random.shuffle(data)  # shuffles the data set in place
    return data[testnum:], data[:testnum]

def getcatlogprobs(train_data):
    # arguments
    #   train_data    list of (text, author code) tuples
    # returns
    #   code_logprob  dict mapping each code to its log prior probability
    # The log prior for an author is defined as:
    #   log(number of texts by that author / total number of texts)
    # Count how often each author appears
    author_counts = {}
    for text, author in train_data:
        if author in author_counts:
            author_counts[author] += 1
        else:
            author_counts[author] = 1
    # Work in log space: summing log probabilities avoids the numerical
    # underflow that multiplying many small probabilities would cause
    log_probs = {}
    total_num_of_pairs = len(train_data)
    for author, count in author_counts.items():
        probability = count / total_num_of_pairs
        log_probs[author] = np.log(probability)  # count >= 1, so probability > 0
    return log_probs

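# Added illustrative sketch (not part of the original pipeline): on a made-up
# 10-text corpus where author 'A' wrote 3 texts and 'B' wrote 7, the priors
# come out to log(3/10) ~= -1.204 and log(7/10) ~= -0.357.
def _example_catlogprobs():
    toy = [("a b", "A")] * 3 + [("c d", "B")] * 7
    return getcatlogprobs(toy)  # {'A': -1.2039..., 'B': -0.3566...}
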
def getwords(listofpairs):
    # arguments
    #   listofpairs   list of (text, author code) tuples
    # returns
    #   d             dict mapping each code (author, etc.) to a dict of word counts,
    #                 e.g. {'the': 27}
    d = {}
    for text, code in listofpairs:
        words = text.split()
        # Add the author/code to d if it is not there yet
        if code not in d:
            d[code] = {}
        # Count how many times the author uses each word
        for word in words:
            if word in d[code]:
                d[code][word] += 1
            else:
                d[code][word] = 1
    return d

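# Added illustrative sketch (not part of the original pipeline): two texts by
# the same made-up author produce one nested count dictionary for that author.
def _example_getwords():
    return getwords([("the cat sat", "A"), ("the dog", "A")])
    # -> {'A': {'the': 2, 'cat': 1, 'sat': 1, 'dog': 1}}
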
def get_vocab(listofpairs):
    # Collects every word that appears in any text and returns the vocabulary as a set
    vocab = set()
    for text, code in listofpairs:
        words = text.split()
        vocab.update(words)
    return vocab

def addone(cnts, vocab):
    # arguments
    #   cnts    dict mapping each code (author) to a dict of word counts, e.g. {'the': 27}
    #   vocab   set of every word in the corpus
    # returns
    #   cnts    the same dict with add-one (Laplace) smoothing applied
    #
    # Adds one to every count so that no vocabulary word has a count of zero
    for code in cnts:
        for word in vocab:
            if word in cnts[code]:
                cnts[code][word] += 1
            else:
                cnts[code][word] = 1  # unseen word: smoothed count is 0 + 1 = 1
    return cnts

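# Added illustrative sketch (not part of the original pipeline): with vocabulary
# {'the', 'cat'} and counts {'A': {'the': 27}}, smoothing gives
# {'A': {'the': 28, 'cat': 1}}, so every vocabulary word has a finite log probability.
def _example_addone():
    return addone({'A': {'the': 27}}, {'the', 'cat'})
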
def logprobwords(cnts):
    # arguments
    #   cnts    dict mapping each code (author) to a dict of (smoothed) word counts
    # returns
    #   cnts    the same dict with each count replaced by the word's log probability,
    #           defined as log(occurrences of word by author / total words by author)
    for code in cnts:
        total_count = sum(cnts[code].values())  # total number of words attributed to this author
        for word in cnts[code]:
            proportion = cnts[code][word] / total_count
            cnts[code][word] = np.log(proportion)
    return cnts

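# Added illustrative sketch (not part of the original pipeline): for smoothed counts
# {'A': {'the': 28, 'cat': 1}} the total is 29, so
# log P('the' | A) = log(28/29) ~= -0.035 and log P('cat' | A) = log(1/29) ~= -3.367.
def _example_logprobwords():
    return logprobwords({'A': {'the': 28, 'cat': 1}})
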
def testone(catlogs, catwordlogs, pair):
    # arguments
    #   catlogs      dict, log prior probability of each author (how often they appear in the training set)
    #   catwordlogs  dict, per-author dicts mapping words to log probabilities
    #   pair         a single (text, code) tuple
    # returns
    #   y            the correct category (the true author)
    #   yhat         the predicted category (the predicted author)
    text, y = pair  # y is the true author, text is the document to classify
    words = text.split()
    scores = {}
    # Score every author: start from the prior, then add the log probability of each word.
    # The author with the highest total score is the prediction.
    for code in catlogs:
        scores[code] = catlogs[code]  # start from the author's log prior
        for word in words:
            scores[code] += catwordlogs[code][word]
    yhat = max(scores, key=scores.get)  # author with the highest score
    return y, yhat

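# Added illustrative sketch (not part of the original pipeline), with made-up numbers:
# for the text "the cat", score(author) = log prior + log P('the') + log P('cat'),
# so author 'B' wins with log(0.5 * 0.5 * 0.5) ~= -2.08 over 'A' with
# log(0.5 * 0.9 * 0.1) ~= -3.10, and testone returns ('B', 'B').
def _example_testone():
    catlogs = {'A': np.log(0.5), 'B': np.log(0.5)}
    catwordlogs = {'A': {'the': np.log(0.9), 'cat': np.log(0.1)},
                   'B': {'the': np.log(0.5), 'cat': np.log(0.5)}}
    return testone(catlogs, catwordlogs, ("the cat", 'B'))
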
def bayes(data, trainsplit=100):
    # trainsplit is the number of pairs held out for testing; the rest are used for training
    train_data, test_data = traintestsplit(data, trainsplit)
    correct = 0
    total = len(test_data)
    # Build the vocabulary from the full data set so every test word gets a smoothed count
    vocab = get_vocab(data)
    # Compute the category priors and per-author word counts from the training pairs
    catlogs = getcatlogprobs(train_data)
    catwordlogs = getwords(train_data)
    # Error checking: verify that addone increments a randomly chosen count by exactly one
    random_author = random.choice(list(catwordlogs.keys()))
    random_word = random.choice(list(catwordlogs[random_author].keys()))
    before_val = catwordlogs[random_author][random_word]
    catwordlogs = addone(catwordlogs, vocab)
    if before_val != catwordlogs[random_author][random_word] - 1:
        print("error")  # the count did not grow by one
    # End of error checking
    catwordlogs = logprobwords(catwordlogs)
    # Evaluate on the held-out test pairs
    for pair in test_data:
        actual, predicted = testone(catlogs, catwordlogs, pair)
        if actual == predicted:
            correct += 1
    return correct, total

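# Added illustrative sketch (not part of the original pipeline) of the full run on a
# tiny made-up data set; with so little data the accuracy is not meaningful, it only
# shows the call shape: bayes returns (number correct, number of test pairs).
def _example_bayes():
    toy = [(f"word{i} the a of", str(i % 2)) for i in range(20)]
    return bayes(toy, trainsplit=5)  # e.g. (k, 5) for some 0 <= k <= 5
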
def get_data(file_number):
    # Open the selected data file and return a list of (text, author code) tuples
    file_number = int(file_number)
    data_files = {1: "Gungor_2018_VictorianAuthorAttribution_data-train.csv"}
    data_dir = 'data/'
    print(f"opening: {data_dir}{data_files[file_number]}")
    with open(data_dir + data_files[file_number], 'r', encoding='ISO-8859-1') as f:
        t = f.read()
    lines = t.split('\n')[1:-1]  # drop the header row and the trailing empty line
    if file_number == 1:
        data = [tuple(line.split(',')) for line in lines]
    return data

def main():
    print("0: Test , 1: Gungor_2018_VictorianAuthorAttribution_data-train.csv")
    file_number = input("Select File Number:")
    print("Enter the number of held-out test pairs for trainsplit (recommended value is 100):")
    training_num = int(input("Enter Number:"))
    if int(file_number) != 0:
        print(bayes(get_data(file_number), training_num))
    else:
        test_bayes(get_data(1))

# Some tests provided for the data set Gungor_2018_VictorianAuthorAttribution_data-train.csv
def test_bayes(data, test_dopairs=True):
    res = data
    print(len(res) == 53678)
    print(res[3][1] == '1')
    codes = get_codes(data)
    print(len(codes) == 45)
    print('2' in codes)
    trainps, testps = traintestsplit(data, 100, 1234)
    vocab = get_vocab(trainps)
    print(len(trainps) + len(testps) == len(res))
    print(len(testps) == 100)
    lps = getcatlogprobs(trainps)
    print(len(lps) == 45)
    print(np.abs(lps['3'] + 5.5276) < .1)
    print(np.abs(lps['40'] + 4.82743) < .1)
    counts = getwords(trainps)
    print(np.abs(len(counts['8']) - 9994) < 2)
    print(np.abs(counts['9']['there'] - 3259) < 2)
    print(np.abs(counts['19']['apple'] - 42) < 2)
    initialcount = counts['3']['the']
    counts = addone(counts, vocab)
    print(counts['3']['the'] == initialcount + 1)
    print(np.abs(counts['19']['apple'] - 43) < 2)
    counts = logprobwords(counts)
    print(np.abs(counts['19']['apple'] + 10.494) < .1)
    print(np.abs(counts['41']['cats'] + 13.733) < .1)

    def dopairs(catlogs, catwordlogs, manypairs):
        correct = 0
        total = len(manypairs)
        for pair in manypairs:
            actual, predicted = testone(catlogs, catwordlogs, pair)
            if actual == predicted:
                correct += 1
        return correct, total

    if test_dopairs:
        print(dopairs(lps, counts, testps) == (81, 100))
    print("End")


if __name__ == '__main__':
    main()