Commit 23f058da authored by Yann SOULLARD's avatar Yann SOULLARD

maj error 999

parent c2524a00
...@@ -511,7 +511,7 @@ class CTCModel: ...@@ -511,7 +511,7 @@ class CTCModel:
use_multiprocessing=use_multiprocessing, use_multiprocessing=use_multiprocessing,
verbose=verbose) verbose=verbose)
if 'ser' in metrics: if 'ser' in metrics:
seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 0. seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 1.
outmetrics = [] outmetrics = []
if 'loss' in metrics: if 'loss' in metrics:
...@@ -738,7 +738,7 @@ class CTCModel: ...@@ -738,7 +738,7 @@ class CTCModel:
out = self._predict_loop(f, ins, batch_size=batch_size, max_value=max_value, out = self._predict_loop(f, ins, batch_size=batch_size, max_value=max_value,
verbose=verbose, steps=steps, max_len=max_len) verbose=verbose, steps=steps, max_len=max_len)
out_decode = [dec_data[:list(dec_data).index(max_value)] for i,dec_data in enumerate(out)] out_decode = [dec_data[:list(dec_data).index(max_value)] if max_value in dec_data else dec_data for i,dec_data in enumerate(out)]
return out_decode return out_decode
def _predict_loop(self, f, ins, max_len=100, max_value=999, batch_size=32, verbose=0, steps=None): def _predict_loop(self, f, ins, max_len=100, max_value=999, batch_size=32, verbose=0, steps=None):
...@@ -977,14 +977,15 @@ class CTCModel: ...@@ -977,14 +977,15 @@ class CTCModel:
self.compile(optimizer) self.compile(optimizer)
if os.path.exists(file_weights): if file_weights is not None:
self.model_train.load_weights(file_weights) if os.path.exists(file_weights):
self.model_pred.set_weights(self.model_train.get_weights()) self.model_train.load_weights(file_weights)
self.model_eval.set_weights(self.model_train.get_weights()) self.model_pred.set_weights(self.model_train.get_weights())
elif os.path.exists(path_dir + file_weights): self.model_eval.set_weights(self.model_train.get_weights())
self.model_train.load_weights(path_dir + file_weights) elif os.path.exists(path_dir + file_weights):
self.model_pred.set_weights(self.model_train.get_weights()) self.model_train.load_weights(path_dir + file_weights)
self.model_eval.set_weights(self.model_train.get_weights()) self.model_pred.set_weights(self.model_train.get_weights())
self.model_eval.set_weights(self.model_train.get_weights())
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment