captcha_input.py 2.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from __future__ import absolute_import
  2. from __future__ import division
  3. from __future__ import print_function
  4. import os.path
  5. import tensorflow as tf
  6. import securimage_solver.config as config
  7. RECORD_DIR = config.RECORD_DIR
  8. TRAIN_FILE = config.TRAIN_FILE
  9. VALID_FILE = config.VALID_FILE
  10. IMAGE_WIDTH = config.IMAGE_WIDTH
  11. IMAGE_HEIGHT = config.IMAGE_HEIGHT
  12. CLASSES_NUM = config.CLASSES_NUM
  13. CHARS_NUM = config.CHARS_NUM
  14. def read_and_decode(filename_queue):
  15. reader = tf.compat.v1.TFRecordReader()
  16. _, serialized_example = reader.read(filename_queue)
  17. features = tf.compat.v1.parse_single_example(
  18. serialized_example,
  19. features={
  20. 'image_raw': tf.compat.v1.FixedLenFeature([], tf.compat.v1.string),
  21. 'label_raw': tf.compat.v1.FixedLenFeature([], tf.compat.v1.string),
  22. })
  23. image = tf.compat.v1.decode_raw(features['image_raw'], tf.compat.v1.int16)
  24. image.set_shape([IMAGE_HEIGHT * IMAGE_WIDTH])
  25. image = tf.compat.v1.cast(image, tf.compat.v1.float32) * (1. / 255) - 0.5
  26. reshape_image = tf.compat.v1.reshape(image, [IMAGE_HEIGHT, IMAGE_WIDTH, 1])
  27. label = tf.compat.v1.decode_raw(features['label_raw'], tf.compat.v1.uint8)
  28. label.set_shape([CHARS_NUM * CLASSES_NUM])
  29. reshape_label = tf.compat.v1.reshape(label, [CHARS_NUM, CLASSES_NUM])
  30. return tf.compat.v1.cast(reshape_image, tf.compat.v1.float32), tf.compat.v1.cast(reshape_label, tf.compat.v1.float32)
  31. def inputs(train, batch_size):
  32. filename = os.path.join(RECORD_DIR,
  33. TRAIN_FILE if train else VALID_FILE)
  34. with tf.compat.v1.name_scope('input'):
  35. filename_queue = tf.compat.v1.train.string_input_producer([filename])
  36. image, label = read_and_decode(filename_queue)
  37. if train:
  38. images, sparse_labels = tf.compat.v1.train.shuffle_batch([image, label],
  39. batch_size=batch_size,
  40. num_threads=6,
  41. capacity=2000 + 3 * batch_size,
  42. min_after_dequeue=2000)
  43. else:
  44. images, sparse_labels = tf.compat.v1.train.batch([image, label],
  45. batch_size=batch_size,
  46. num_threads=6,
  47. capacity=2000 + 3 * batch_size)
  48. return images, sparse_labels