captcha_api.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import sys
  5. from os import path, environ
  6. from datetime import datetime
  7. from PIL import Image
  8. import numpy as np
  9. import tensorflow as tf
  10. from tensorflow.python.platform import gfile
  11. from securimage_solver.captcha_model import *
  12. from securimage_solver.trim import trim
  13. import securimage_solver.config as config
  14. from io import BytesIO
  15. class CaptchaApi():
  16. def __init__(self):
  17. self.IMAGE_WIDTH = config.IMAGE_WIDTH
  18. self.IMAGE_HEIGHT = config.IMAGE_HEIGHT
  19. self.CHAR_SETS = config.CHAR_SETS
  20. self.CLASSES_NUM = config.CLASSES_NUM
  21. self.CHARS_NUM = config.CHARS_NUM
  22. self.checkpoint_path = path.join(path.dirname(__file__), 'captcha_train')
  23. environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
  24. environ['CUDA_VISIBLE_DEVICES'] = '-1'# Added by me
  25. def one_hot_to_texts(self, recog_result):
  26. texts = []
  27. for i in range(recog_result.shape[0]):
  28. index = recog_result[i]
  29. texts.append(''.join([self.CHAR_SETS[i] for i in index]))
  30. return texts
  31. def input_data(self, IMAGE_BYTES):
  32. images = np.zeros([1, self.IMAGE_HEIGHT*self.IMAGE_WIDTH], dtype='float32')
  33. image = Image.open(BytesIO(IMAGE_BYTES))
  34. image_ = trim(image)
  35. image.close()
  36. image = image_
  37. image_gray = image.convert('L')
  38. image_resize = image_gray.resize(size=(self.IMAGE_WIDTH,self.IMAGE_HEIGHT))
  39. input_img = np.array(image_resize, dtype='float32')
  40. input_img = np.multiply(input_img.flatten(), 1./255) - 0.5
  41. images[0,:] = input_img
  42. return images
  43. def predict(self, IMAGE_BYTES):
  44. with tf.compat.v1.Graph().as_default(), tf.compat.v1.device('/cpu:0'):
  45. input_images = self.input_data(IMAGE_BYTES)
  46. images = tf.compat.v1.constant(input_images)
  47. logits = inference(images, keep_prob=1)
  48. result = output(logits)
  49. saver = tf.compat.v1.train.Saver()
  50. sess = tf.compat.v1.Session()
  51. saver.restore(sess, tf.compat.v1.train.latest_checkpoint(self.checkpoint_path))
  52. recog_result = sess.run(result)
  53. sess.close()
  54. text = self.one_hot_to_texts(recog_result)
  55. return text[0]
  56. if __name__ == '__main__':
  57. capi = CaptchaApi()
  58. print(capi.predict("images/7NwHCn_141c1458-b5e4-439f-be01-8a8b30c6cbd8.png"))