|
| 1 | +#!/usr/bin/env python |
| 2 | +# encoding: utf-8 |
| 3 | +# author: Lock |
| 4 | +# time: 2018/3/18 17:26 |
| 5 | + |
| 6 | +import tensorflow as tf |
| 7 | +from train import cnn_graph |
| 8 | +from train import get_random_captcha_text_and_image |
| 9 | +from train import vec2text, convert2gray |
| 10 | +from create_captcha_img import CAPTCHA_LIST, CAPTCHA_WIDTH, CAPTCHA_HEIGHT, CAPTCHA_LEN |
| 11 | + |
| 12 | + |
| 13 | +def captcha_to_text(image_list, height=CAPTCHA_HEIGHT, width=CAPTCHA_WIDTH): |
| 14 | + ''' |
| 15 | + 验证码图片转化为文本 |
| 16 | + :param image_list: |
| 17 | + :param height: |
| 18 | + :param width: |
| 19 | + :return: |
| 20 | + ''' |
| 21 | + x = tf.placeholder(tf.float32, [None, height * width]) |
| 22 | + keep_prob = tf.placeholder(tf.float32) |
| 23 | + y_conv = cnn_graph(x, keep_prob, (height, width)) |
| 24 | + saver = tf.train.Saver() |
| 25 | + with tf.Session() as sess: |
| 26 | + saver.restore(sess, tf.train.latest_checkpoint('.')) |
| 27 | + predict = tf.argmax(tf.reshape(y_conv, [-1, CAPTCHA_LEN, len(CAPTCHA_LIST)]), 2) |
| 28 | + vector_list = sess.run(predict, feed_dict={x: image_list, keep_prob: 1}) |
| 29 | + vector_list = vector_list.tolist() |
| 30 | + text_list = [vec2text(vector) for vector in vector_list] |
| 31 | + return text_list[0] |
| 32 | + |
| 33 | + |
| 34 | +def multi_test(height=CAPTCHA_HEIGHT, width=CAPTCHA_WIDTH): |
| 35 | + x = tf.placeholder(tf.float32, [None, height * width]) |
| 36 | + keep_prob = tf.placeholder(tf.float32) |
| 37 | + y_conv = cnn_graph(x, keep_prob, (height, width)) |
| 38 | + saver = tf.train.Saver() |
| 39 | + with tf.Session() as sess: |
| 40 | + saver.restore(sess, tf.train.latest_checkpoint('.')) |
| 41 | + while 1: |
| 42 | + text, image = get_random_captcha_text_and_image() |
| 43 | + image = convert2gray(image) |
| 44 | + image = image.flatten() / 255 |
| 45 | + image_list = [image] |
| 46 | + predict = tf.argmax(tf.reshape(y_conv, [-1, CAPTCHA_LEN, len(CAPTCHA_LIST)]), 2) |
| 47 | + vector_list = sess.run(predict, feed_dict={x: image_list, keep_prob: 1}) |
| 48 | + vector_list = vector_list.tolist() |
| 49 | + text_list = [vec2text(vector) for vector in vector_list] |
| 50 | + pre_text = text_list[0] |
| 51 | + flag = u'错误' |
| 52 | + if text == pre_text: |
| 53 | + flag = u'正确' |
| 54 | + print u"实际值(actual):%s, 预测值(predict):%s, 预测结果:%s" % (text, pre_text, flag,) |
| 55 | + |
| 56 | + |
| 57 | +if __name__ == '__main__': |
| 58 | + try: |
| 59 | + # 多个测试 |
| 60 | + multi_test() |
| 61 | + exit() |
| 62 | + |
| 63 | + text, image = get_random_captcha_text_and_image() |
| 64 | + image = convert2gray(image) |
| 65 | + image = image.flatten() / 255 |
| 66 | + pre_text = captcha_to_text([image]) |
| 67 | + flag = u'错误' |
| 68 | + if text == pre_text: |
| 69 | + flag = u'正确' |
| 70 | + print u"实际值(actual):%s, 预测值(predict):%s, 预测结果:%s" % (text, pre_text, flag,) |
| 71 | + except KeyboardInterrupt as e: |
| 72 | + print e.message |
0 commit comments