diff --git a/examples/vision/captcha_ocr.py b/examples/vision/captcha_ocr.py
index a6bac599ff..90e9ba214b 100644
--- a/examples/vision/captcha_ocr.py
+++ b/examples/vision/captcha_ocr.py
@@ -359,30 +359,6 @@ def build_model():
 """
 
 
-def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
-    input_shape = ops.shape(y_pred)
-    num_samples, num_steps = input_shape[0], input_shape[1]
-    y_pred = ops.log(ops.transpose(y_pred, axes=[1, 0, 2]) + keras.backend.epsilon())
-    input_length = ops.cast(input_length, dtype="int32")
-
-    if greedy:
-        (decoded, log_prob) = tf.nn.ctc_greedy_decoder(
-            inputs=y_pred, sequence_length=input_length
-        )
-    else:
-        (decoded, log_prob) = tf.compat.v1.nn.ctc_beam_search_decoder(
-            inputs=y_pred,
-            sequence_length=input_length,
-            beam_width=beam_width,
-            top_paths=top_paths,
-        )
-    decoded_dense = []
-    for st in decoded:
-        st = tf.SparseTensor(st.indices, st.values, (num_samples, num_steps))
-        decoded_dense.append(tf.sparse.to_dense(sp_input=st, default_value=-1))
-    return (decoded_dense, log_prob)
-
-
 # Get the prediction model by extracting layers till the output layer
 prediction_model = keras.models.Model(
     model.input[0], model.get_layer(name="dense2").output
@@ -394,12 +370,14 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
 def decode_batch_predictions(pred):
     input_len = np.ones(pred.shape[0]) * pred.shape[1]
     # Use greedy search. For complex tasks, you can use beam search
-    results = ctc_decode(pred, input_length=input_len, greedy=True)[0][0][
-        :, :max_length
-    ]
+    results = ops.ctc_decode(pred, sequence_lengths=input_len, strategy="greedy")[0][0]
+    # Convert the SparseTensor to a dense tensor
+    dense_results = tf.sparse.to_dense(results, default_value=-1)
+    # Slice the dense tensor to keep only up to max_length
+    dense_results = dense_results[:, :max_length]
     # Iterate over the results and get back the text
     output_text = []
-    for res in results:
+    for res in dense_results:
         res = tf.strings.reduce_join(num_to_char(res)).numpy().decode("utf-8")
         output_text.append(res)
     return output_text