Commit 86ec9bfe authored by Yann SOULLARD's avatar Yann SOULLARD

Update for the latest Keras version

parent 095b18f3
......@@ -3,21 +3,25 @@ import tensorflow as tf
import numpy as np
import os
#import warpctc_tensorflow
from keras import Input
from keras.engine import Model
from keras.layers import Lambda
from keras.models import model_from_json
from keras.models import model_from_json, Sequential
import pickle
from tensorflow.python.ops import ctc_ops as ctc
from keras.utils import Sequence, GeneratorEnqueuer, OrderedEnqueuer
import warnings
from keras.utils.generic_utils import Progbar
from keras.layers import TimeDistributed, Activation, Dense
#from ocr_ctc.utils.utils_analysis import tf_edit_distance
#from ocr_ctc.utils.utils_keras import Kreshape_To1D
from keras.preprocessing import sequence
"""
authors: Yann Soullard, Cyprien Ruffino (2017)
LITIS lab, university of Rouen (France)
......@@ -50,8 +54,14 @@ class CTCModel:
self.model_train = None
self.model_pred = None
self.model_eval = None
self.inputs = inputs
self.outputs = outputs
if not isinstance(inputs, list):
self.inputs = [inputs]
else:
self.inputs = inputs
if not isinstance(outputs, list):
self.outputs = [outputs]
else:
self.outputs = outputs
self.greedy = greedy
self.beam_width = beam_width
......@@ -59,6 +69,7 @@ class CTCModel:
self.charset = charset
def compile(self, optimizer):
"""
Configures the CTC Model for training.
......@@ -760,12 +771,14 @@ class CTCModel:
or list of arrays of predictions
(if the model has multiple outputs).
"""
# num_samples = self.model_pred._check_num_samples(ins, batch_size,
# steps,
# 'steps')
num_samples = check_num_samples(ins,
batch_size=batch_size,
steps=steps,
steps_name='steps')
if steps is not None:
# Step-based predictions.
# Since we do not know how many samples
......@@ -828,7 +841,7 @@ class CTCModel:
y_pred, labels, input_length, label_length = args
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)#, ignore_longer_outputs_than_inputs=True)
# return warpctc_tensorflow.ctc(y_pred, labels, label_length, input_length)
@staticmethod
def ctc_complete_decoding_lambda_func(args, **arguments):
......@@ -931,7 +944,7 @@ class CTCModel:
output.close()
def load_model(self, path_dir, optimizer, file_weights=None):
def load_model(self, path_dir, optimizer, file_weights=None, change_parameters=False, init_last_layer=False):
""" Load a model in path_dir
load model_train, model_pred and model_eval from json
load inputs and outputs from json
......@@ -972,10 +985,11 @@ class CTCModel:
param = p.load()
input.close()
self.greedy = param['greedy'] if 'greedy' in param.keys() else self.greedy
self.beam_width = param['beam_width'] if 'beam_width' in param.keys() else self.beam_width
self.top_paths = param['top_paths'] if 'top_paths' in param.keys() else self.top_paths
self.charset = param['charset'] if 'charset' in param.keys() else self.charset
if not change_parameters:
self.greedy = param['greedy'] if 'greedy' in param.keys() else self.greedy
self.beam_width = param['beam_width'] if 'beam_width' in param.keys() else self.beam_width
self.top_paths = param['top_paths'] if 'top_paths' in param.keys() else self.top_paths
self.charset = param['charset'] if 'charset' in param.keys() and self.charset is None else self.charset
self.compile(optimizer)
......@@ -989,6 +1003,65 @@ class CTCModel:
self.model_pred.set_weights(self.model_train.get_weights())
self.model_eval.set_weights(self.model_train.get_weights())
if init_last_layer:
labels = Input(name='labels', shape=[None])
input_length = Input(name='input_length', shape=[1])
label_length = Input(name='label_length', shape=[1])
# new_model_init = Sequential() # new model
# for layer in self.model_init.layers[:-2]:
# new_model_init.add(layer)
# new_model_init.add(TimeDistributed(Dense(len(self.charset) + 1), name="DenseSoftmax"))
# new_model_init.add(Activation('softmax', name='Softmax'))
# self.model_init = new_model_init
new_layer = Input(name='input', shape=self.model_init.layers[0].output_shape[1:])
self.inputs = [new_layer]
for layer in self.model_init.layers[1:-2]:
new_layer = layer(new_layer)
new_layer = TimeDistributed(Dense(len(self.charset) + 1), name="DenseSoftmax")(new_layer)
new_layer = Activation('softmax', name='Softmax')(new_layer)
self.outputs = [new_layer]
# new_model_train = Sequential() # new model
# nb_layers = len(self.model_train.layers)
# new_layer = Input(name='input',
# shape=self.model_train.layers[0].output_shape[1:])
# for layer in self.model_train.layers[1:-6]:
# new_layer = layer(new_layer)
# new_layer = TimeDistributed(Dense(len(self.charset) + 1), name="DenseSoftmax")(new_layer)
# new_layer = Activation('softmax', name='Softmax')(new_layer)
# Lambda layer for computing the loss function
loss_out = Lambda(self.ctc_loss_lambda_func, output_shape=(1,), name='CTCloss')(
self.outputs + [labels, input_length, label_length])
# Lambda layer for the decoding function
out_decoded_dense = Lambda(self.ctc_complete_decoding_lambda_func, output_shape=(None, None),
name='CTCdecode', arguments={'greedy': self.greedy,
'beam_width': self.beam_width,
'top_paths': self.top_paths}, dtype="float32")(
self.outputs + [input_length])
# Lambda layer to perform an analysis (CER and SER)
out_analysis = Lambda(self.ctc_complete_analysis_lambda_func, output_shape=(None,), name='CTCanalysis',
arguments={'greedy': self.greedy,
'beam_width': self.beam_width, 'top_paths': self.top_paths},
dtype="float32")(
self.outputs + [labels, input_length, label_length])
# create Keras models
self.model_init = Model(inputs=self.inputs, outputs=self.outputs)
self.model_train = Model(inputs=self.inputs + [labels, input_length, label_length], outputs=loss_out)
self.model_pred = Model(inputs=self.inputs + [input_length], outputs=out_decoded_dense)
self.model_eval = Model(inputs=self.inputs + [labels, input_length, label_length], outputs=out_analysis)
# Compile models
self.model_train.compile(loss={'CTCloss': lambda yt, yp: yp}, optimizer=optimizer)
self.model_pred.compile(loss={'CTCdecode': lambda yt, yp: yp}, optimizer=optimizer)
self.model_eval.compile(loss={'CTCanalysis': lambda yt, yp: yp}, optimizer=optimizer)
def _standardize_input_data(data, names, shapes=None,
......@@ -1230,4 +1303,4 @@ def check_num_samples(ins,
if hasattr(ins[0], 'shape'):
return int(ins[0].shape[0])
return None # Edge case where ins == [static_learning_phase]
return None # Edge case where ins == [static_learning_phase]
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment