From 23f058dae3c4442565fa070e4bc0751f96f97c9b Mon Sep 17 00:00:00 2001
From: Yann SOULLARD <yann.soullard@univ-rouen.fr>
Date: Thu, 1 Feb 2018 11:56:40 +0100
Subject: [PATCH] maj error 999

---
 CTCModel.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/CTCModel.py b/CTCModel.py
index ed82020..d70cfea 100644
--- a/CTCModel.py
+++ b/CTCModel.py
@@ -511,7 +511,7 @@ class CTCModel:
                           use_multiprocessing=use_multiprocessing,
                           verbose=verbose)
         if 'ser' in metrics:
-            seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 0.
+            seq_error = float(np.sum([1 for ler_data in ler_dataset if ler_data != 0])) / len(ler_dataset) if len(ler_dataset)>0 else 1.
 
         outmetrics = []
         if 'loss' in metrics:
@@ -738,7 +738,7 @@ class CTCModel:
         out = self._predict_loop(f, ins, batch_size=batch_size, max_value=max_value,
                                   verbose=verbose, steps=steps, max_len=max_len)
 
-        out_decode = [dec_data[:list(dec_data).index(max_value)] for i,dec_data in enumerate(out)]
+        out_decode = [dec_data[:list(dec_data).index(max_value)] if max_value in dec_data else dec_data for i,dec_data in enumerate(out)]
         return out_decode
 
     def _predict_loop(self, f, ins, max_len=100, max_value=999, batch_size=32, verbose=0, steps=None):
@@ -977,14 +977,15 @@ class CTCModel:
 
         self.compile(optimizer)
 
-        if os.path.exists(file_weights):
-            self.model_train.load_weights(file_weights)
-            self.model_pred.set_weights(self.model_train.get_weights())
-            self.model_eval.set_weights(self.model_train.get_weights())
-        elif os.path.exists(path_dir + file_weights):
-            self.model_train.load_weights(path_dir + file_weights)
-            self.model_pred.set_weights(self.model_train.get_weights())
-            self.model_eval.set_weights(self.model_train.get_weights())
+        if file_weights is not None:
+            if os.path.exists(file_weights):
+                self.model_train.load_weights(file_weights)
+                self.model_pred.set_weights(self.model_train.get_weights())
+                self.model_eval.set_weights(self.model_train.get_weights())
+            elif os.path.exists(path_dir + file_weights):
+                self.model_train.load_weights(path_dir + file_weights)
+                self.model_pred.set_weights(self.model_train.get_weights())
+                self.model_eval.set_weights(self.model_train.get_weights())
 
 
 
-- 
GitLab