captcha_model.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from securimage_solver import captcha_input
import securimage_solver.config as config

IMAGE_WIDTH = config.IMAGE_WIDTH
IMAGE_HEIGHT = config.IMAGE_HEIGHT
CLASSES_NUM = config.CLASSES_NUM
CHARS_NUM = config.CHARS_NUM


def inputs(train, batch_size):
    """Returns a batch of captcha images and labels from the input pipeline."""
    return captcha_input.inputs(train, batch_size=batch_size)

def _conv2d(value, weight):
    """conv2d returns a 2d convolution layer with full stride."""
    return tf.compat.v1.nn.conv2d(value, weight, strides=[1, 1, 1, 1], padding='SAME')


def _max_pool_2x2(value, name):
    """max_pool_2x2 downsamples a feature map by 2X."""
    return tf.compat.v1.nn.max_pool(value, ksize=[1, 2, 2, 1],
                                    strides=[1, 2, 2, 1], padding='SAME', name=name)


def _weight_variable(name, shape):
    """weight_variable generates a weight variable of a given shape."""
    with tf.compat.v1.device('/cpu:0'):
        initializer = tf.compat.v1.truncated_normal_initializer(stddev=0.1)
        var = tf.compat.v1.get_variable(name, shape, initializer=initializer, dtype=tf.float32)
    return var


def _bias_variable(name, shape):
    """bias_variable generates a bias variable of a given shape."""
    with tf.compat.v1.device('/cpu:0'):
        initializer = tf.compat.v1.constant_initializer(0.1)
        var = tf.compat.v1.get_variable(name, shape, initializer=initializer, dtype=tf.float32)
    return var

def inference(images, keep_prob):
    """Builds the convolutional network and returns logits of shape
    [batch_size, CHARS_NUM, CLASSES_NUM], applying dropout with the given
    keep_prob to the fully connected layer."""
    images = tf.compat.v1.reshape(images, [-1, IMAGE_HEIGHT, IMAGE_WIDTH, 1])

    # Four conv + 2x2 max-pool blocks, each halving the spatial resolution.
    with tf.compat.v1.variable_scope('conv1') as scope:
        kernel = _weight_variable('weights', shape=[3, 3, 1, 64])
        biases = _bias_variable('biases', [64])
        pre_activation = tf.compat.v1.nn.bias_add(_conv2d(images, kernel), biases)
        conv1 = tf.compat.v1.nn.relu(pre_activation, name=scope.name)
    pool1 = _max_pool_2x2(conv1, name='pool1')

    with tf.compat.v1.variable_scope('conv2') as scope:
        kernel = _weight_variable('weights', shape=[3, 3, 64, 64])
        biases = _bias_variable('biases', [64])
        pre_activation = tf.compat.v1.nn.bias_add(_conv2d(pool1, kernel), biases)
        conv2 = tf.compat.v1.nn.relu(pre_activation, name=scope.name)
    pool2 = _max_pool_2x2(conv2, name='pool2')

    with tf.compat.v1.variable_scope('conv3') as scope:
        kernel = _weight_variable('weights', shape=[3, 3, 64, 64])
        biases = _bias_variable('biases', [64])
        pre_activation = tf.compat.v1.nn.bias_add(_conv2d(pool2, kernel), biases)
        conv3 = tf.compat.v1.nn.relu(pre_activation, name=scope.name)
    pool3 = _max_pool_2x2(conv3, name='pool3')

    with tf.compat.v1.variable_scope('conv4') as scope:
        kernel = _weight_variable('weights', shape=[3, 3, 64, 64])
        biases = _bias_variable('biases', [64])
        pre_activation = tf.compat.v1.nn.bias_add(_conv2d(pool3, kernel), biases)
        conv4 = tf.compat.v1.nn.relu(pre_activation, name=scope.name)
    pool4 = _max_pool_2x2(conv4, name='pool4')

    # Fully connected layer with dropout.
    with tf.compat.v1.variable_scope('local1') as scope:
        batch_size = images.get_shape()[0]
        reshape = tf.compat.v1.reshape(pool4, [batch_size, -1])
        dim = reshape.get_shape()[1]
        weights = _weight_variable('weights', shape=[dim, 1024])
        biases = _bias_variable('biases', [1024])
        local1 = tf.compat.v1.nn.relu(tf.compat.v1.matmul(reshape, weights) + biases, name=scope.name)
        local1_drop = tf.compat.v1.nn.dropout(local1, keep_prob)

    # Linear layer producing CLASSES_NUM logits for each of the CHARS_NUM characters.
    with tf.compat.v1.variable_scope('softmax_linear') as scope:
        weights = _weight_variable('weights', shape=[1024, CHARS_NUM * CLASSES_NUM])
        biases = _bias_variable('biases', [CHARS_NUM * CLASSES_NUM])
        softmax_linear = tf.compat.v1.add(tf.compat.v1.matmul(local1_drop, weights), biases, name=scope.name)

    return tf.compat.v1.reshape(softmax_linear, [-1, CHARS_NUM, CLASSES_NUM])

def loss(logits, labels):
    """Computes the mean softmax cross-entropy over all characters in the batch."""
    cross_entropy = tf.compat.v1.nn.softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cross_entropy_mean = tf.compat.v1.reduce_mean(cross_entropy, name='cross_entropy')
    tf.compat.v1.add_to_collection('losses', cross_entropy_mean)
    return tf.compat.v1.add_n(tf.compat.v1.get_collection('losses'), name='total_loss')

def training(loss):
    """Creates an Adam optimizer op that minimizes the total loss."""
    optimizer = tf.compat.v1.train.AdamOptimizer(1e-4)
    train_op = optimizer.minimize(loss)
    return train_op

def evaluation(logits, labels):
    """Counts the captchas in the batch whose characters are all predicted correctly."""
    correct_prediction = tf.compat.v1.equal(tf.compat.v1.argmax(logits, 2), tf.compat.v1.argmax(labels, 2))
    # Integer mean over the character axis is 1 only when every character matches.
    correct_batch = tf.compat.v1.reduce_mean(tf.compat.v1.cast(correct_prediction, tf.int32), 1)
    return tf.compat.v1.reduce_sum(tf.compat.v1.cast(correct_batch, tf.float32))

def output(logits):
    """Returns the predicted class index for each character position."""
    return tf.compat.v1.argmax(logits, 2)
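

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): a minimal
# TF1-style training loop, assuming captcha_input.inputs() returns fixed-size
# image/label batches from a queue-based pipeline. The batch size, step count,
# and keep_prob values below are placeholders chosen for the example only.
if __name__ == '__main__':
    tf.compat.v1.disable_eager_execution()  # needed when running under TF2's eager default
    with tf.compat.v1.Graph().as_default():
        images, labels = inputs(train=True, batch_size=128)
        logits = inference(images, keep_prob=0.5)
        total_loss = loss(logits, labels)
        train_op = training(total_loss)
        correct_count = evaluation(logits, labels)
        with tf.compat.v1.Session() as sess:
            sess.run(tf.compat.v1.global_variables_initializer())
            sess.run(tf.compat.v1.local_variables_initializer())
            coord = tf.compat.v1.train.Coordinator()
            threads = tf.compat.v1.train.start_queue_runners(sess=sess, coord=coord)
            for step in range(1000):
                _, loss_value = sess.run([train_op, total_loss])
                if step % 100 == 0:
                    print('step %d, loss %.4f' % (step, loss_value))
            coord.request_stop()
            coord.join(threads)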