Commit a52e8d57 authored by Yann SOULLARD's avatar Yann SOULLARD


parent 14171694
...@@ -128,6 +128,28 @@ class CTCModel: ...@@ -128,6 +128,28 @@ class CTCModel:
""" """
return self.model_eval return self.model_eval
def get_loss_on_batch(self, inputs, verbose=False):
Computation the loss
inputs is a list of 4 elements:
x_features, y_label, x_len, y_len (similarly to the CTC in tensorflow)
:return: Probabilities (output of the TimeDistributedDense layer)
x = inputs[0]
x_len = inputs[2]
y = inputs[1]
y_len = inputs[3]
no_lab = True if 0 in y_len else False
if no_lab is False:
loss_data = self.model_train.predict_on_batch([x, y, x_len, y_len], verbose=verbose)
loss = np.sum(loss_data)
return loss, loss_data
def get_loss(self, inputs, verbose=False): def get_loss(self, inputs, verbose=False):
""" """
...@@ -982,7 +1004,7 @@ class CTCModel: ...@@ -982,7 +1004,7 @@ class CTCModel:
output.close() output.close()
def load_model(self, path_dir, optimizer): def load_model(self, path_dir, optimizer, file_weights=None):
""" Load a model in path_dir """ Load a model in path_dir
load model_train, model_pred and model_eval from json load model_train, model_pred and model_eval from json
load inputs and outputs from json load inputs and outputs from json
...@@ -1030,6 +1052,15 @@ class CTCModel: ...@@ -1030,6 +1052,15 @@ class CTCModel:
self.compile(optimizer) self.compile(optimizer)
if os.path.exists(file_weights):
elif os.path.exists(path_dir + file_weights):
self.model_train.load_weights(path_dir + file_weights)
def _standardize_input_data(data, names, shapes=None, def _standardize_input_data(data, names, shapes=None,
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment