Commit a52e8d57 authored by Yann SOULLARD's avatar Yann SOULLARD

MAJ CTCModel

parent 14171694
......@@ -128,6 +128,28 @@ class CTCModel:
"""
return self.model_eval
def get_loss_on_batch(self, inputs, verbose=False):
"""
Computation the loss
inputs is a list of 4 elements:
x_features, y_label, x_len, y_len (similarly to the CTC in tensorflow)
:return: Probabilities (output of the TimeDistributedDense layer)
"""
x = inputs[0]
x_len = inputs[2]
y = inputs[1]
y_len = inputs[3]
no_lab = True if 0 in y_len else False
if no_lab is False:
loss_data = self.model_train.predict_on_batch([x, y, x_len, y_len], verbose=verbose)
loss = np.sum(loss_data)
return loss, loss_data
def get_loss(self, inputs, verbose=False):
"""
......@@ -982,7 +1004,7 @@ class CTCModel:
output.close()
def load_model(self, path_dir, optimizer):
def load_model(self, path_dir, optimizer, file_weights=None):
""" Load a model in path_dir
load model_train, model_pred and model_eval from json
load inputs and outputs from json
......@@ -1030,6 +1052,15 @@ class CTCModel:
self.compile(optimizer)
if os.path.exists(file_weights):
self.model_train.load_weights(file_weights)
self.model_pred.set_weights(self.model_train.get_weights())
self.model_eval.set_weights(self.model_train.get_weights())
elif os.path.exists(path_dir + file_weights):
self.model_train.load_weights(path_dir + file_weights)
self.model_pred.set_weights(self.model_train.get_weights())
self.model_eval.set_weights(self.model_train.get_weights())
def _standardize_input_data(data, names, shapes=None,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment