From 23f058dae3c4442565fa070e4bc0751f96f97c9b Mon Sep 17 00:00:00 2001 From: Yann SOULLARD <yann.soullard@univ-rouen.fr> Date: Thu, 1 Feb 2018 11:56:40 +0100 Subject: [PATCH] maj error 999 --- CTCModel.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/CTCModel.py b/CTCModel.py index ed82020..d70cfea 100644 --- a/CTCModel.py +++ b/CTCModel.py @@ -511,7 +511,7 @@ class CTCModel: use_multiprocessing=use_multiprocessing, verbose=verbose) if 'ser' in metrics: - seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 0. + seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 1. outmetrics = [] if 'loss' in metrics: @@ -738,7 +738,7 @@ class CTCModel: out = self._predict_loop(f, ins, batch_size=batch_size, max_value=max_value, verbose=verbose, steps=steps, max_len=max_len) - out_decode = [dec_data[:list(dec_data).index(max_value)] for i,dec_data in enumerate(out)] + out_decode = [dec_data[:list(dec_data).index(max_value)] if max_value in dec_data else dec_data for i,dec_data in enumerate(out)] return out_decode def _predict_loop(self, f, ins, max_len=100, max_value=999, batch_size=32, verbose=0, steps=None): @@ -977,14 +977,15 @@ class CTCModel: self.compile(optimizer) - if os.path.exists(file_weights): - self.model_train.load_weights(file_weights) - self.model_pred.set_weights(self.model_train.get_weights()) - self.model_eval.set_weights(self.model_train.get_weights()) - elif os.path.exists(path_dir + file_weights): - self.model_train.load_weights(path_dir + file_weights) - self.model_pred.set_weights(self.model_train.get_weights()) - self.model_eval.set_weights(self.model_train.get_weights()) + if file_weights is not None: + if os.path.exists(file_weights): + self.model_train.load_weights(file_weights) + self.model_pred.set_weights(self.model_train.get_weights()) + self.model_eval.set_weights(self.model_train.get_weights()) + elif os.path.exists(path_dir + file_weights): + self.model_train.load_weights(path_dir + file_weights) + self.model_pred.set_weights(self.model_train.get_weights()) + self.model_eval.set_weights(self.model_train.get_weights()) -- GitLab