# Module-level replacement for the former
# `self.model_pred._check_num_samples(ins, batch_size, steps, 'steps')` call;
# a `CTCModel` method now invokes it as
# `check_num_samples(ins, batch_size=batch_size, steps=steps, steps_name='steps')`.
def check_num_samples(ins,
                      batch_size=None,
                      steps=None,
                      steps_name='steps'):
    """Determines how many samples `ins` holds and validates `batch_size`/`steps`.

    The sample count is read from the first dimension of the first entry
    of `ins`. It is left undefined (`None`) when `ins` is empty, when it
    contains symbolic tensors, or when its first entry has no `shape`
    attribute.

    # Arguments
        ins: List of tensors to be fed to the Keras function.
        batch_size: Integer batch size or `None` if not defined.
        steps: Total number of steps (batches of samples)
            before declaring `predict_loop` finished.
            Ignored with the default value of `None`.
        steps_name: The public API's parameter name for `steps`;
            only used when building the error messages.

    # Returns
        `int(ins[0].shape[0])` when `ins` is non-empty, contains no
        symbolic tensor and its first entry exposes a `shape` attribute
        (this holds whether or not `steps` is set). `None` in every
        other non-raising case.

    # Raises
        ValueError: when `steps` and `batch_size` are both set, because
            they are mutually exclusive.
        ValueError: when `steps` is `None` while `ins` is empty or
            contains at least one symbolic tensor (as reported by
            `K.is_tensor`).
    """
    if steps is not None and batch_size is not None:
        raise ValueError(
            'If ' + steps_name + ' is set, the `batch_size` must be None.')

    # Empty or symbolic inputs carry no countable sample dimension, so the
    # caller has to drive the loop with `steps` instead.
    if not ins or any(K.is_tensor(x) for x in ins):
        if steps is None:
            raise ValueError(
                'If your data is in the form of symbolic tensors, '
                'you should specify the `' + steps_name + '` argument '
                '(instead of the `batch_size` argument, '
                'because symbolic tensors are expected to produce '
                'batches of input data).')
        return None

    if hasattr(ins[0], 'shape'):
        return int(ins[0].shape[0])
    return None  # Edge case where ins == [static_learning_phase]