diff --git a/CTCModel.py b/CTCModel.py
index d70cfea95786dfb9feeccdd99e0a82badfae2feb..7c8ac18625198728e3c4cb420b21584d39a04848 100644
--- a/CTCModel.py
+++ b/CTCModel.py
@@ -760,10 +760,12 @@ class CTCModel:
or list of arrays of predictions
(if the model has multiple outputs).
"""
- num_samples = self.model_pred._check_num_samples(ins, batch_size,
- steps,
- 'steps')
+ num_samples = check_num_samples(ins,
+ batch_size=batch_size,
+ steps=steps,
+ steps_name='steps')
+
if steps is not None:
# Step-based predictions.
# Since we do not know how many samples
@@ -1182,4 +1184,50 @@ def tf_edit_distance(hypothesis, truth, norm=False):
inputs are tf.Sparse_tensors """
- return tf.edit_distance(hypothesis, truth, normalize=norm, name='edit_distance')
\ No newline at end of file
+ return tf.edit_distance(hypothesis, truth, normalize=norm, name='edit_distance')
+
+
+def check_num_samples(ins,
+ batch_size=None,
+ steps=None,
+ steps_name='steps'):
+ """Checks the number of samples provided for training and evaluation.
+ The number of samples is not defined when running with `steps`,
+ in which case the number of samples is set to `None`.
+ # Arguments
+ ins: List of tensors to be fed to the Keras function.
+ batch_size: Integer batch size or `None` if not defined.
+ steps: Total number of steps (batches of samples)
+ before declaring `predict_loop` finished.
+ Ignored with the default value of `None`.
+ steps_name: The public API's parameter name for `steps`.
+ # Raises
+ ValueError: when `steps` is `None` and the attribute `ins.shape`
+ does not exist. Also raises ValueError when `steps` is not `None`
+ and `batch_size` is not `None` because they are mutually
+ exclusive.
+ # Returns
+ When `steps` is `None`, returns the number of samples to be
+ processed based on the size of the first dimension of the
+ first input Numpy array. When `steps` is not `None` and
+ `batch_size` is `None`, returns `None`.
+ # Raises
+ ValueError: In case of invalid arguments.
+ """
+ if steps is not None and batch_size is not None:
+ raise ValueError(
+ 'If ' + steps_name + ' is set, the `batch_size` must be None.')
+
+ if not ins or any(K.is_tensor(x) for x in ins):
+ if steps is None:
+ raise ValueError(
+ 'If your data is in the form of symbolic tensors, '
+ 'you should specify the `' + steps_name + '` argument '
+ '(instead of the `batch_size` argument, '
+ 'because symbolic tensors are expected to produce '
+ 'batches of input data).')
+ return None
+
+ if hasattr(ins[0], 'shape'):
+ return int(ins[0].shape[0])
+ return None # Edge case where ins == [static_learning_phase]