12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758 |
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os.path
- import tensorflow as tf
- import securimage_solver.config as config
# Project settings pulled from the config module, grouped by use.
# Where the TFRecord files live and which file backs each split
# (joined together in ``inputs``).
RECORD_DIR, TRAIN_FILE, VALID_FILE = (
    config.RECORD_DIR, config.TRAIN_FILE, config.VALID_FILE)
# Geometry of a decoded image (see ``read_and_decode``).
IMAGE_WIDTH, IMAGE_HEIGHT = config.IMAGE_WIDTH, config.IMAGE_HEIGHT
# Label layout: one row of CLASSES_NUM values per each of CHARS_NUM characters.
CLASSES_NUM, CHARS_NUM = config.CLASSES_NUM, config.CHARS_NUM
def read_and_decode(filename_queue):
    """Read one serialized example from a TFRecord filename queue and decode it.

    Args:
        filename_queue: queue of TFRecord filenames, consumed by a
            ``tf.compat.v1.TFRecordReader``.

    Returns:
        ``(image, label)`` float32 tensors. ``image`` has shape
        ``[IMAGE_HEIGHT, IMAGE_WIDTH, 1]`` with the raw values multiplied by
        1/255 and shifted by -0.5; ``label`` has shape
        ``[CHARS_NUM, CLASSES_NUM]``.
    """
    # Both payloads are stored as raw byte strings.
    feature_spec = {
        'image_raw': tf.compat.v1.FixedLenFeature([], tf.compat.v1.string),
        'label_raw': tf.compat.v1.FixedLenFeature([], tf.compat.v1.string),
    }
    _, record = tf.compat.v1.TFRecordReader().read(filename_queue)
    parsed = tf.compat.v1.parse_single_example(record, features=feature_spec)

    # Pixels were serialized as int16 and each record must hold exactly
    # IMAGE_HEIGHT * IMAGE_WIDTH of them. Scaling by 1/255 then subtracting
    # 0.5 centres the values around zero (maps 0..255 to -0.5..0.5, assuming
    # the writer stored 8-bit intensities -- confirm against the writer).
    pixels = tf.compat.v1.decode_raw(parsed['image_raw'], tf.compat.v1.int16)
    pixels.set_shape([IMAGE_HEIGHT * IMAGE_WIDTH])
    normalized = tf.compat.v1.cast(pixels, tf.compat.v1.float32) * (1. / 255) - 0.5
    image = tf.compat.v1.reshape(normalized, [IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    # Labels are a flat uint8 buffer of CHARS_NUM * CLASSES_NUM entries,
    # presumably one one-hot row per character (verify against the writer).
    flat_label = tf.compat.v1.decode_raw(parsed['label_raw'], tf.compat.v1.uint8)
    flat_label.set_shape([CHARS_NUM * CLASSES_NUM])
    label = tf.compat.v1.reshape(flat_label, [CHARS_NUM, CLASSES_NUM])

    return (tf.compat.v1.cast(image, tf.compat.v1.float32),
            tf.compat.v1.cast(label, tf.compat.v1.float32))
def inputs(train, batch_size, num_threads=6, min_after_dequeue=2000):
    """Build a batched ``(images, labels)`` pipeline from a TFRecord file.

    Uses the TF1 queue-based input API: the filename producer and the batch
    queues register queue runners on the graph, so the returned tensors only
    yield data after those runners have been started.

    Args:
        train: if truthy, read ``TRAIN_FILE`` and shuffle examples; otherwise
            read ``VALID_FILE`` and batch examples without shuffling.
        batch_size: number of examples per batch.
        num_threads: number of threads enqueuing decoded examples. With more
            than one thread, example order is non-deterministic even in the
            unshuffled (validation) branch; pass 1 for a stable order.
        min_after_dequeue: minimum number of examples kept in the shuffle
            queue after each dequeue (more means better mixing and more
            memory). It also sizes the queue capacity in both branches.

    Returns:
        ``(images, labels)`` float32 tensors as produced by
        ``read_and_decode``, batched along a new leading dimension of size
        ``batch_size``.
    """
    record_file = TRAIN_FILE if train else VALID_FILE
    filename = os.path.join(RECORD_DIR, record_file)
    # Capacity = shuffle floor plus headroom for three batches (with the
    # defaults this is the original 2000 + 3 * batch_size).
    capacity = min_after_dequeue + 3 * batch_size

    with tf.compat.v1.name_scope('input'):
        filename_queue = tf.compat.v1.train.string_input_producer([filename])
        image, label = read_and_decode(filename_queue)

        # Labels are dense [CHARS_NUM, CLASSES_NUM] tensors, not sparse ids.
        if train:
            images, labels = tf.compat.v1.train.shuffle_batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity,
                min_after_dequeue=min_after_dequeue)
        else:
            images, labels = tf.compat.v1.train.batch(
                [image, label],
                batch_size=batch_size,
                num_threads=num_threads,
                capacity=capacity)
    return images, labels
|