Skip to content
GitLab
Projects
Groups
Snippets
Help
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
C
CTCModel
Project overview
Project overview
Details
Activity
Releases
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Issues
0
Issues
0
List
Boards
Labels
Service Desk
Milestones
Operations
Operations
Incidents
Analytics
Analytics
Repository
Value Stream
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Create a new issue
Commits
Issue Boards
Open sidebar
TextRecognition
CTCModel
Commits
a52e8d57
Commit
a52e8d57
authored
Jan 09, 2018
by
Yann SOULLARD
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
MAJ CTCModel
parent
14171694
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
32 additions
and
1 deletion
+32
-1
CTCModel.py
CTCModel.py
+32
-1
No files found.
CTCModel.py
View file @
a52e8d57
...
...
@@ -128,6 +128,28 @@ class CTCModel:
"""
return
self
.
model_eval
def
get_loss_on_batch
(
self
,
inputs
,
verbose
=
False
):
"""
Computation the loss
inputs is a list of 4 elements:
x_features, y_label, x_len, y_len (similarly to the CTC in tensorflow)
:return: Probabilities (output of the TimeDistributedDense layer)
"""
x
=
inputs
[
0
]
x_len
=
inputs
[
2
]
y
=
inputs
[
1
]
y_len
=
inputs
[
3
]
no_lab
=
True
if
0
in
y_len
else
False
if
no_lab
is
False
:
loss_data
=
self
.
model_train
.
predict_on_batch
([
x
,
y
,
x_len
,
y_len
],
verbose
=
verbose
)
loss
=
np
.
sum
(
loss_data
)
return
loss
,
loss_data
def
get_loss
(
self
,
inputs
,
verbose
=
False
):
"""
...
...
@@ -982,7 +1004,7 @@ class CTCModel:
output
.
close
()
def
load_model
(
self
,
path_dir
,
optimizer
):
def
load_model
(
self
,
path_dir
,
optimizer
,
file_weights
=
None
):
""" Load a model in path_dir
load model_train, model_pred and model_eval from json
load inputs and outputs from json
...
...
@@ -1030,6 +1052,15 @@ class CTCModel:
self
.
compile
(
optimizer
)
if
os
.
path
.
exists
(
file_weights
):
self
.
model_train
.
load_weights
(
file_weights
)
self
.
model_pred
.
set_weights
(
self
.
model_train
.
get_weights
())
self
.
model_eval
.
set_weights
(
self
.
model_train
.
get_weights
())
elif
os
.
path
.
exists
(
path_dir
+
file_weights
):
self
.
model_train
.
load_weights
(
path_dir
+
file_weights
)
self
.
model_pred
.
set_weights
(
self
.
model_train
.
get_weights
())
self
.
model_eval
.
set_weights
(
self
.
model_train
.
get_weights
())
def
_standardize_input_data
(
data
,
names
,
shapes
=
None
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment