diff --git a/examples/streaming_conformer/train.py b/examples/streaming_conformer/train.py
new file mode 100644
index 0000000000..94a6297117
--- /dev/null
+++ b/examples/streaming_conformer/train.py
@@ -0,0 +1,156 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import math
+import argparse
+from tensorflow_asr.utils import env_util
+
+env_util.setup_environment()
+import tensorflow as tf
+
+# Enable memory growth on every visible GPU; indexing physical_devices[0]
+# directly would crash on CPU-only hosts.
+physical_devices = tf.config.list_physical_devices('GPU')
+for device in physical_devices:
+    tf.config.experimental.set_memory_growth(device, True)
+
+DEFAULT_YAML = os.path.join(os.path.abspath(os.path.dirname(__file__)), "config.yml")
+
+tf.keras.backend.clear_session()
+
+parser = argparse.ArgumentParser(prog="Conformer Training")
+
+parser.add_argument("--config", type=str, default=DEFAULT_YAML, help="The file path of model configuration file")
+
+parser.add_argument("--max_ckpts", type=int, default=10, help="Max number of checkpoints to keep")
+
+parser.add_argument("--tfrecords", default=False, action="store_true", help="Whether to use tfrecords")
+
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
+parser.add_argument("--subwords", default=False, action="store_true", help="Use subwords")
+
+parser.add_argument("--tbs", type=int, default=None, help="Train batch size per replica")
+
+parser.add_argument("--ebs", type=int, default=None, help="Evaluation batch size per replica")
+
+parser.add_argument("--spx", type=int, default=1, help="Steps per execution for maximizing performance")
+
+parser.add_argument("--metadata", type=str, default=None, help="Path to file containing metadata")
+
+parser.add_argument("--static_length", default=False, action="store_true", help="Use static lengths")
+
+parser.add_argument("--devices", type=int, nargs="*", default=[0], help="Devices' ids to apply distributed training")
+
+parser.add_argument("--mxp", default=False, action="store_true", help="Enable mixed precision")
+
+args = parser.parse_args()
+
+tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
+
+strategy = env_util.setup_strategy(args.devices)
+
+from tensorflow_asr.configs.config import Config
+from tensorflow_asr.datasets.asr_dataset import ASRMaskedSliceDataset, ASRMaskedTFRecordDataset
+from tensorflow_asr.featurizers import speech_featurizers, text_featurizers
+from tensorflow_asr.models.transducer.streaming_conformer import StreamingConformer
+from tensorflow_asr.optimizers.schedules import TransformerSchedule
+
+config = Config(args.config)
+speech_featurizer = speech_featurizers.TFSpeechFeaturizer(config.speech_config)
+
+if args.sentence_piece:
+    print("Loading SentencePiece model ...")
+    text_featurizer = text_featurizers.SentencePieceFeaturizer(config.decoder_config)
+elif args.subwords:
+    print("Loading subwords ...")
+    text_featurizer = text_featurizers.SubwordFeaturizer(config.decoder_config)
+else:
+    print("Using characters ...")
+    text_featurizer = text_featurizers.CharFeaturizer(config.decoder_config)
+
+time_reduction_factor = config.model_config['encoder_subsampling']['strides'] * 2
+if args.tfrecords:
+    train_dataset = ASRMaskedTFRecordDataset(
+        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        **vars(config.learning_config.train_dataset_config)
+    )
+    eval_dataset = ASRMaskedTFRecordDataset(
+        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        **vars(config.learning_config.eval_dataset_config)
+    )
+else:
+    train_dataset = ASRMaskedSliceDataset(
+        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        time_reduction_factor=time_reduction_factor,
+        **vars(config.learning_config.train_dataset_config)
+    )
+    eval_dataset = ASRMaskedSliceDataset(
+        speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+        time_reduction_factor=time_reduction_factor,
+        **vars(config.learning_config.eval_dataset_config)
+    )
+
+train_dataset.load_metadata(args.metadata)
+eval_dataset.load_metadata(args.metadata)
+
+if not args.static_length:
+    speech_featurizer.reset_length()
+    text_featurizer.reset_length()
+
+global_batch_size = args.tbs or config.learning_config.running_config.batch_size
+global_batch_size *= strategy.num_replicas_in_sync
+
+# Scale the per-replica eval batch size once; starting from the already-scaled
+# global_batch_size would apply the replica count twice.
+global_eval_batch_size = args.ebs or args.tbs or config.learning_config.running_config.batch_size
+global_eval_batch_size *= strategy.num_replicas_in_sync
+
+train_data_loader = train_dataset.create(global_batch_size)
+eval_data_loader = eval_dataset.create(global_eval_batch_size)
+
+with strategy.scope():
+    # build model
+    streaming_conformer = StreamingConformer(**config.model_config, vocabulary_size=text_featurizer.num_classes)
+    streaming_conformer.make(speech_featurizer.shape)
+    streaming_conformer.summary(line_length=150)
+
+    optimizer = tf.keras.optimizers.Adam(
+        TransformerSchedule(
+            d_model=streaming_conformer.dmodel,
+            warmup_steps=config.learning_config.optimizer_config.pop("warmup_steps", 10000),
+            max_lr=(0.05 / math.sqrt(streaming_conformer.dmodel))
+        ),
+        **config.learning_config.optimizer_config
+    )
+
+    streaming_conformer.compile(
+        optimizer=optimizer,
+        experimental_steps_per_execution=args.spx,
+        global_batch_size=global_batch_size,
+        blank=text_featurizer.blank
+    )
+
+callbacks = [
+    tf.keras.callbacks.ModelCheckpoint(**config.learning_config.running_config.checkpoint),
+    tf.keras.callbacks.experimental.BackupAndRestore(config.learning_config.running_config.states_dir),
+    tf.keras.callbacks.TensorBoard(**config.learning_config.running_config.tensorboard)
+]
+
+# Batch sizes are baked into the tf.data pipelines above; Keras rejects an
+# explicit batch_size for dataset inputs.
+streaming_conformer.fit(
+    train_data_loader,
+    epochs=config.learning_config.running_config.num_epochs,
+    steps_per_epoch=train_dataset.total_steps,
+    validation_data=eval_data_loader,
+    validation_steps=eval_dataset.total_steps,
+    callbacks=callbacks,
+)
diff --git a/tensorflow_asr/datasets/asr_dataset.py b/tensorflow_asr/datasets/asr_dataset.py
index f957554cf3..c0edd02e92 100755
--- a/tensorflow_asr/datasets/asr_dataset.py
+++ b/tensorflow_asr/datasets/asr_dataset.py
@@ -366,3 +366,284 @@ def create(self, batch_size: int):
         dataset = tf.data.Dataset.from_tensor_slices(self.entries)
         dataset = dataset.map(self.load, num_parallel_calls=AUTOTUNE)
         return self.process(dataset, batch_size)
+
+
+class ASRMaskedSliceDataset(ASRSliceDataset):
+    """ Dataset for ASR with rolling mask """
+
+    def __init__(self,
+                 stage: str,
+                 speech_featurizer: SpeechFeaturizer,
+                 text_featurizer: TextFeaturizer,
+                 data_paths: list,
+                 augmentations: Augmentation = Augmentation(None),
+                 cache: bool = False,
+                 shuffle: bool = False,
+                 indefinite: bool = False,
+                 drop_remainder: bool = True,
+                 use_tf: bool = False,
+                 buffer_size: int = BUFFER_SIZE,
+                 history_window_size: int = 3,
+                 input_chunk_duration: int = 250,
+                 time_reduction_factor: int = 4,
+                 **kwargs):
+        super(ASRMaskedSliceDataset, self).__init__(
+            data_paths=data_paths, augmentations=augmentations,
+            cache=cache, shuffle=shuffle, stage=stage, buffer_size=buffer_size,
+            drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite,
+            speech_featurizer=speech_featurizer, text_featurizer=text_featurizer
+        )
+        self.speech_featurizer = speech_featurizer
+        self.text_featurizer = text_featurizer
+        self.history_window_size = history_window_size
+        # Chunk duration is given in milliseconds; convert it to samples.
+        self.input_chunk_size = input_chunk_duration * self.speech_featurizer.sample_rate // 1000
+        self.time_reduction_factor = time_reduction_factor
+
+        self.base_mask = tf.constant(0, dtype=tf.int32)
+        self.use_base_mask = False
+
+        # If max input length is known, pre-compute mask
+        if self.speech_featurizer.max_length:
+            num_frames = 1 + (self.speech_featurizer.max_length - self.speech_featurizer.frame_length) // self.speech_featurizer.frame_step
+            self.base_mask = self.calculate_mask(num_frames)
+            self.use_base_mask = True
+
+    def _recalculate_mask(self, num_frames):
+        frames_per_chunk = self.input_chunk_size // self.speech_featurizer.frame_step
+
+        def _calculate_mask(num_frames, frames_per_chunk, history_window_size):
+            mask = np.zeros((num_frames, num_frames), dtype=np.int32)
+            for i in range(num_frames):
+                # Frames within the same chunk can see each other; if the history
+                # window reaches back into earlier chunks, those chunks are visible in full
+                current_chunk_index = i // frames_per_chunk
+                history_chunk_index = (i - history_window_size) // frames_per_chunk
+                for curr in range(history_chunk_index, current_chunk_index + 1):
+                    for j in range(frames_per_chunk):
+                        base_index = curr * frames_per_chunk
+                        if base_index + j < 0 or base_index + j >= num_frames:
+                            continue
+                        mask[i, base_index + j] = 1
+            return mask
+
+        @tf.function(autograph=True)
+        def _calculate_mask_tf(num_frames, frames_per_chunk, history_window_size):
+            chunk_ids = tf.range(num_frames) // frames_per_chunk
+            num_chunks = tf.cast(tf.math.ceil(num_frames / frames_per_chunk), dtype=tf.int32)
+
+            # Create the first `frames_per_chunk` rows: only the first chunk is
+            # visible, so the trailing frames must be zeros (matches the numpy branch)
+            current = tf.ones((frames_per_chunk,), dtype=tf.int32)
+            trailing = tf.zeros(((num_chunks - 1) * frames_per_chunk,), dtype=tf.int32)
+            tmp_row = tf.concat((current, trailing), axis=0)
+            row = tf.slice(tmp_row, [0], [num_frames])
+            mask = tf.expand_dims(row, axis=0)
+
+            for i in range(1, frames_per_chunk):
+                mask = tf.concat((mask, [row]), axis=0)
+
+            # Create the following rows
+            for i in range(frames_per_chunk, num_frames):
+                tf.autograph.experimental.set_loop_options(
+                    shape_invariants=[(mask, tf.TensorShape([None, None]))]
+                )
+                curr_chunk_id = chunk_ids[i]
+                hist_i = tf.math.maximum(i - history_window_size, 0)
+                hist_chunk_id = chunk_ids[hist_i]
+
+                # Build the left-most run of zeros (chunks before the history window)
+                leading_chunk_id = hist_chunk_id - 1
+                num_leading_chunks = tf.math.maximum(leading_chunk_id + 1, 0)
+                leftmost_row = tf.zeros((num_leading_chunks * frames_per_chunk,), dtype=tf.int32)
+
+                # Build the currently visible chunks
+                num_hist_chunks = curr_chunk_id - hist_chunk_id
+                num_visible_chunks = num_hist_chunks + 1
+                curr_chunk_row = tf.ones((num_visible_chunks * frames_per_chunk,), dtype=tf.int32)
+
+                # Build the trailing zeros
+                num_trailing_chunks = tf.math.maximum((num_chunks - curr_chunk_id) - 1, 0)
+                trailing_chunk_row = tf.zeros((num_trailing_chunks * frames_per_chunk,), dtype=tf.int32)
+
+                # Merge chunks, clip to output size
+                tmp_row = tf.concat([leftmost_row, curr_chunk_row, trailing_chunk_row], axis=0)
+                row = tf.slice(tmp_row, [0], [num_frames])
+
+                mask = tf.concat((mask, [row]), axis=0)
+            return mask
+
+        if self.use_tf:
+            mask = _calculate_mask_tf(num_frames, frames_per_chunk, self.history_window_size)
+        else:
+            mask = tf.numpy_function(
+                _calculate_mask, inp=[num_frames, frames_per_chunk, self.history_window_size], Tout=tf.int32
+            )
+            mask.set_shape((None, None))
+
+        return mask
+
+    def calculate_mask(self, num_frames):
+        num_frames = math_util.get_reduced_length(num_frames, self.time_reduction_factor)
+
+        if self.use_base_mask:
+            return tf.slice(self.base_mask, [0, 0], [num_frames, num_frames])
+        return self._recalculate_mask(num_frames)
+
+    def preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
+        preprocessed_inputs = super(ASRMaskedSliceDataset, self).preprocess(path, audio, indices)
+
+        input_length = preprocessed_inputs[2]
+        mask = self.calculate_mask(input_length)
+
+        return (*preprocessed_inputs, mask)
+
+    def tf_preprocess(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
+        preprocessed_inputs = super(ASRMaskedSliceDataset, self).tf_preprocess(path, audio, indices)
+
+        input_length = preprocessed_inputs[2]
+        mask = self.calculate_mask(input_length)
+
+        return (*preprocessed_inputs, mask)
+
+    # -------------------------------- CREATION -------------------------------------
+
+    def parse(self, path: tf.Tensor, audio: tf.Tensor, indices: tf.Tensor):
+        """
+        Returns:
+            path, features, input_length, label, label_length, prediction, prediction_length, mask
+        """
+        if self.use_tf: data = self.tf_preprocess(path, audio, indices)
+        else: data = self.preprocess(path, audio, indices)
+
+        _, features, input_length, label, label_length, prediction, prediction_length, mask = data
+
+        return (
+            data_util.create_inputs(
+                inputs=features,
+                inputs_length=input_length,
+                predictions=prediction,
+                predictions_length=prediction_length,
+                mask=mask
+            ),
+            data_util.create_labels(
+                labels=label,
+                labels_length=label_length
+            )
+        )
+
+    def process(self, dataset: tf.data.Dataset, batch_size: int):
+        dataset = dataset.map(self.parse, num_parallel_calls=AUTOTUNE)
+        self.total_steps = math_util.get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)
+
+        if self.cache:
+            dataset = dataset.cache()
+
+        if self.shuffle:
+            dataset = dataset.shuffle(self.buffer_size, reshuffle_each_iteration=True)
+
+        if self.indefinite and self.total_steps:
+            dataset = dataset.repeat()
+
+        # PADDED BATCH the dataset
+        dataset = dataset.padded_batch(
+            batch_size=batch_size,
+            padded_shapes=(
+                data_util.create_inputs(
+                    inputs=tf.TensorShape(self.speech_featurizer.shape),
+                    inputs_length=tf.TensorShape([]),
+                    predictions=tf.TensorShape(self.text_featurizer.prepand_shape),
+                    predictions_length=tf.TensorShape([]),
+                    mask=tf.TensorShape([self.speech_featurizer.shape[0], self.speech_featurizer.shape[0]])
+                ),
+                data_util.create_labels(
+                    labels=tf.TensorShape(self.text_featurizer.shape),
+                    labels_length=tf.TensorShape([])
+                ),
+            ),
+            padding_values=(
+                data_util.create_inputs(
+                    inputs=0.,
+                    inputs_length=0,
+                    predictions=self.text_featurizer.blank,
+                    predictions_length=0,
+                    mask=0
+                ),
+                data_util.create_labels(
+                    labels=self.text_featurizer.blank,
+                    labels_length=0
+                )
+            ),
+            drop_remainder=self.drop_remainder
+        )
+
+        # PREFETCH to improve the throughput of the input pipeline
+        dataset = dataset.prefetch(AUTOTUNE)
+        return dataset
+
+    def create(self, batch_size: int):
+        self.read_entries()
+        if not self.total_steps or self.total_steps == 0: return None
+        dataset = tf.data.Dataset.from_generator(
+            self.generator,
+            output_types=(tf.string, tf.string, tf.string),
+            output_shapes=(tf.TensorShape([]), tf.TensorShape([]), tf.TensorShape([]))
+        )
+        return self.process(dataset, batch_size)
+
+
+class ASRMaskedTFRecordDataset(ASRMaskedSliceDataset, ASRTFRecordDataset):
+    """ Dataset for ASR using TFRecords with rolling mask """
+
+    def __init__(self,
+                 data_paths: list,
+                 tfrecords_dir: str,
+                 speech_featurizer: SpeechFeaturizer,
+                 text_featurizer: TextFeaturizer,
+                 stage: str,
+                 augmentations: Augmentation = Augmentation(None),
+                 tfrecords_shards: int = TFRECORD_SHARDS,
+                 cache: bool = False,
+                 shuffle: bool = False,
+                 use_tf: bool = False,
+                 indefinite: bool = False,
+                 drop_remainder: bool = True,
+                 buffer_size: int = BUFFER_SIZE,
+                 history_window_size: int = 3,
+                 input_chunk_duration: int = 250,
+                 time_reduction_factor: int = 4,
+                 **kwargs):
+        super(ASRMaskedTFRecordDataset, self).__init__(
+            stage=stage, speech_featurizer=speech_featurizer, text_featurizer=text_featurizer,
+            data_paths=data_paths, augmentations=augmentations, cache=cache, shuffle=shuffle, buffer_size=buffer_size,
+            drop_remainder=drop_remainder, use_tf=use_tf, indefinite=indefinite, history_window_size=history_window_size,
+            input_chunk_duration=input_chunk_duration, time_reduction_factor=time_reduction_factor,
+        )
+        if not self.stage: raise ValueError("stage must be defined, either 'train', 'eval' or 'test'")
+        self.tfrecords_dir = tfrecords_dir
+        if tfrecords_shards <= 0: raise ValueError("tfrecords_shards must be positive")
+        self.tfrecords_shards = tfrecords_shards
+        if not tf.io.gfile.exists(self.tfrecords_dir): tf.io.gfile.makedirs(self.tfrecords_dir)
+
+    def parse(self, record: tf.Tensor):
+        feature_description = {
+            "path": tf.io.FixedLenFeature([], tf.string),
+            "audio": tf.io.FixedLenFeature([], tf.string),
+            "indices": tf.io.FixedLenFeature([], tf.string)
+        }
+        example = tf.io.parse_single_example(record, feature_description)
+        if self.use_tf: data = self.tf_preprocess(**example)
+        else: data = self.preprocess(**example)
+
+        _, features, input_length, label, label_length, prediction, prediction_length, mask = data
+
+        return (
+            data_util.create_inputs(
+                inputs=features,
+                inputs_length=input_length,
+                predictions=prediction,
+                predictions_length=prediction_length,
+                mask=mask
+            ),
+            data_util.create_labels(
+                labels=label,
+                labels_length=label_length
+            )
+        )
diff --git a/tensorflow_asr/models/encoders/conformer.py b/tensorflow_asr/models/encoders/conformer.py
index de7b767fdd..6bc6cde65c 100644
--- a/tensorflow_asr/models/encoders/conformer.py
+++ b/tensorflow_asr/models/encoders/conformer.py
@@ -19,6 +19,7 @@
 from ..layers.positional_encoding import PositionalEncoding, PositionalEncodingConcat
 from ..layers.multihead_attention import MultiHeadAttention, RelPositionMultiHeadAttention
 from ...utils import shape_util
+from ..layers.depthwise_conv1d import DepthwiseConv1D
 
 L2 = tf.keras.regularizers.l2(1e-6)
 
@@ -143,20 +144,22 @@ def __init__(self,
                  depth_multiplier=1,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
+                 streaming=False,
                  name="conv_module",
                  **kwargs):
         super(ConvModule, self).__init__(name=name, **kwargs)
         self.ln = tf.keras.layers.LayerNormalization()
-        self.pw_conv_1 = tf.keras.layers.Conv2D(
+        self.pw_conv_1 = tf.keras.layers.Conv1D(
             filters=2 * input_dim, kernel_size=1, strides=1,
             padding="valid", name=f"{name}_pw_conv_1",
             kernel_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer
         )
         self.glu = GLU(name=f"{name}_glu")
-        self.dw_conv = tf.keras.layers.DepthwiseConv2D(
-            kernel_size=(kernel_size, 1), strides=1,
-            padding="same", name=f"{name}_dw_conv",
+        self.dw_conv = DepthwiseConv1D(
+            kernel_size=kernel_size, strides=1,
+            padding="same" if not streaming else "causal",
+            name=f"{name}_dw_conv",
             depth_multiplier=depth_multiplier,
             depthwise_regularizer=kernel_regularizer,
             bias_regularizer=bias_regularizer
@@ -167,7 +170,7 @@ def __init__(self,
             beta_regularizer=bias_regularizer
         )
         self.swish = tf.keras.layers.Activation(tf.nn.swish, name=f"{name}_swish_activation")
-        self.pw_conv_2 = tf.keras.layers.Conv2D(
+        self.pw_conv_2 = tf.keras.layers.Conv1D(
             filters=input_dim, kernel_size=1, strides=1,
             padding="valid", name=f"{name}_pw_conv_2",
             kernel_regularizer=kernel_regularizer,
@@ -178,15 +181,12 @@ def __init__(self,
 
     def call(self, inputs, training=False, **kwargs):
         outputs = self.ln(inputs, training=training)
-        B, T, E = shape_util.shape_list(outputs)
-        outputs = tf.reshape(outputs, [B, T, 1, E])
         outputs = self.pw_conv_1(outputs, training=training)
         outputs = self.glu(outputs)
         outputs = self.dw_conv(outputs, training=training)
         outputs = self.bn(outputs, training=training)
         outputs = self.swish(outputs)
         outputs = self.pw_conv_2(outputs, training=training)
-        outputs = tf.reshape(outputs, [B, T, E])
         outputs = self.do(outputs, training=training)
         outputs = self.res_add([inputs, outputs])
         return outputs
@@ -217,6 +217,7 @@ def __init__(self,
                  depth_multiplier=1,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
+                 streaming=False,
                  name="conformer_block",
                  **kwargs):
         super(ConformerBlock, self).__init__(name=name, **kwargs)
@@ -238,7 +239,8 @@ def __init__(self,
             dropout=dropout, name=f"{name}_conv_module",
             depth_multiplier=depth_multiplier,
             kernel_regularizer=kernel_regularizer,
-            bias_regularizer=bias_regularizer
+            bias_regularizer=bias_regularizer,
+            streaming=streaming
         )
         self.ffm2 = FFModule(
             input_dim=input_dim, dropout=dropout,
@@ -286,6 +288,7 @@ def __init__(self,
                  dropout=0.0,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
+                 streaming=False,
                  name="conformer_encoder",
                  **kwargs):
         super(ConformerEncoder, self).__init__(name=name, **kwargs)
@@ -338,6 +341,7 @@ def __init__(self,
                 depth_multiplier=depth_multiplier,
                 kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
+                streaming=streaming,
                 name=f"{name}_block_{i}"
             )
             self.conformer_blocks.append(conformer_block)
diff --git a/tensorflow_asr/models/layers/depthwise_conv1d.py b/tensorflow_asr/models/layers/depthwise_conv1d.py
new file mode 100644
index 0000000000..9fc99c4e15
--- /dev/null
+++ b/tensorflow_asr/models/layers/depthwise_conv1d.py
@@ -0,0 +1,213 @@
+"""
+    This implementation comes from the GitHub issue: https://github.com/tensorflow/tensorflow/issues/36935
+    Slight modifications have been made to support causal padding.
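+    With `padding='causal'`, inputs are left-padded by (kernel_size - 1) * dilation_rate
+    frames before the convolution, so each output frame depends only on current and past frames.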
+""" + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.keras import backend +from tensorflow.python.keras import constraints +from tensorflow.python.keras import initializers +from tensorflow.python.keras import regularizers +from tensorflow.python.keras.engine.input_spec import InputSpec +from tensorflow.python.keras.layers.convolutional import Conv1D + +from tensorflow.python.keras.utils import conv_utils +from tensorflow.python.keras.utils import tf_utils +from tensorflow.python.ops import array_ops +from tensorflow.python.util.tf_export import keras_export + + +class DepthwiseConv1D(Conv1D): + """Depthwise separable 1D convolution. + Depthwise Separable convolutions consist of performing + just the first step in a depthwise spatial convolution + (which acts on each input channel separately). + The `depth_multiplier` argument controls how many + output channels are generated per input channel in the depthwise step. + Arguments: + kernel_size: A single integer specifying the spatial + dimensions of the filters. + strides: A single integer specifying the strides + of the convolution. + Specifying any `stride` value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: one of `'valid'` or `'same'` (case-insensitive). + depth_multiplier: The number of depthwise convolution output channels + for each input channel. + The total number of depthwise convolution output + channels will be equal to `filters_in * depth_multiplier`. + data_format: A string, + one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, length, channels)` while `channels_first` + corresponds to inputs with shape + `(batch, channels, length)`. + The default is 'channels_last'. + activation: Activation function to use. + If you don't specify anything, no activation is applied + (ie. 'linear' activation: `a(x) = x`). + use_bias: Boolean, whether the layer uses a bias vector. + depthwise_initializer: Initializer for the depthwise kernel matrix. + bias_initializer: Initializer for the bias vector. + depthwise_regularizer: Regularizer function applied to + the depthwise kernel matrix. + bias_regularizer: Regularizer function applied to the bias vector. + activity_regularizer: Regularizer function applied to + the output of the layer (its 'activation'). + depthwise_constraint: Constraint function applied to + the depthwise kernel matrix. + bias_constraint: Constraint function applied to the bias vector. + Input shape: + 3D tensor with shape: + `[batch, channels, length]` if data_format='channels_first' + or 4D tensor with shape: + `[batch, length, channels]` if data_format='channels_last'. + Output shape: + 3D tensor with shape: + `[batch, filters, new_length]` if data_format='channels_first' + or 3D tensor with shape: + `[batch, new_length, filters]` if data_format='channels_last'. + `length` values might have changed due to padding. 
+ """ + + def __init__(self, + kernel_size, + strides=1, + padding='valid', + depth_multiplier=1, + data_format=None, + activation=None, + use_bias=True, + depthwise_initializer='glorot_uniform', + bias_initializer='zeros', + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + **kwargs): + super(DepthwiseConv1D, self).__init__( + filters=None, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + activation=activation, + use_bias=use_bias, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + bias_constraint=bias_constraint, + # autocast=False, + **kwargs) + self.depth_multiplier = depth_multiplier + self.depthwise_initializer = initializers.get(depthwise_initializer) + self.depthwise_regularizer = regularizers.get(depthwise_regularizer) + self.depthwise_constraint = constraints.get(depthwise_constraint) + self.bias_initializer = initializers.get(bias_initializer) + + def build(self, input_shape): + if len(input_shape) < 3: + raise ValueError('Inputs to `DepthwiseConv1D` should have rank 3. ' + 'Received input shape:', str(input_shape)) + input_shape = tensor_shape.TensorShape(input_shape) + + # TODO(pj1989): replace with channel_axis = self._get_channel_axis() + if self.data_format == 'channels_last': + channel_axis = -1 + elif self.data_format == 'channels_first': + channel_axis = 1 + + if input_shape.dims[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs to ' + '`DepthwiseConv1D` ' + 'should be defined. Found `None`.') + input_dim = int(input_shape[channel_axis]) + depthwise_kernel_shape = (self.kernel_size[0], input_dim, self.depth_multiplier) + + self.depthwise_kernel = self.add_weight( + shape=depthwise_kernel_shape, + initializer=self.depthwise_initializer, + name='depthwise_kernel', + regularizer=self.depthwise_regularizer, + constraint=self.depthwise_constraint) + + if self.use_bias: + self.bias = self.add_weight( + shape=(input_dim * self.depth_multiplier,), + initializer=self.bias_initializer, + name='bias', + regularizer=self.bias_regularizer, + constraint=self.bias_constraint) + else: + self.bias = None + # Set input spec. + self.input_spec = InputSpec(ndim=3, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs): + if self.padding == 'causal': + inputs = array_ops.pad(inputs, self._compute_causal_padding(inputs)) + if self.data_format == 'channels_last': + spatial_start_dim = 1 + else: + spatial_start_dim = 2 + + # Explicitly broadcast inputs and kernels to 4D. + # TODO(fchollet): refactor when a native depthwise_conv2d op is available. 
+        strides = self.strides * 2
+        inputs = array_ops.expand_dims(inputs, spatial_start_dim)
+        depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0)
+        dilation_rate = (1,) + self.dilation_rate
+
+        outputs = backend.depthwise_conv2d(
+            inputs,
+            depthwise_kernel,
+            strides=strides,
+            padding='valid' if self.padding == 'causal' else self.padding,
+            dilation_rate=dilation_rate,
+            data_format=self.data_format)
+
+        if self.use_bias:
+            outputs = backend.bias_add(
+                outputs,
+                self.bias,
+                data_format=self.data_format)
+
+        outputs = array_ops.squeeze(outputs, [spatial_start_dim])
+
+        if self.activation is not None:
+            return self.activation(outputs)
+
+        return outputs
+
+    @tf_utils.shape_type_conversion
+    def compute_output_shape(self, input_shape):
+        if self.data_format == 'channels_first':
+            length = input_shape[2]
+            out_filters = input_shape[1] * self.depth_multiplier
+        elif self.data_format == 'channels_last':
+            length = input_shape[1]
+            out_filters = input_shape[2] * self.depth_multiplier
+
+        # `kernel_size` and `strides` are 1-tuples, so unpack the scalars.
+        length = conv_utils.conv_output_length(length, self.kernel_size[0],
+                                               self.padding,
+                                               self.strides[0])
+        if self.data_format == 'channels_first':
+            return (input_shape[0], out_filters, length)
+        elif self.data_format == 'channels_last':
+            return (input_shape[0], length, out_filters)
+
+    def get_config(self):
+        config = super(DepthwiseConv1D, self).get_config()
+        config.pop('filters')
+        config.pop('kernel_initializer')
+        config.pop('kernel_regularizer')
+        config.pop('kernel_constraint')
+        config['depth_multiplier'] = self.depth_multiplier
+        config['depthwise_initializer'] = initializers.serialize(
+            self.depthwise_initializer)
+        config['depthwise_regularizer'] = regularizers.serialize(
+            self.depthwise_regularizer)
+        config['depthwise_constraint'] = constraints.serialize(
+            self.depthwise_constraint)
+        return config
diff --git a/tensorflow_asr/models/transducer/conformer.py b/tensorflow_asr/models/transducer/conformer.py
index b5d151e266..be5435444e 100644
--- a/tensorflow_asr/models/transducer/conformer.py
+++ b/tensorflow_asr/models/transducer/conformer.py
@@ -49,6 +49,7 @@ def __init__(self,
                  joint_trainable: bool = True,
                  kernel_regularizer=L2,
                  bias_regularizer=L2,
+                 streaming=False,
                  name: str = "conformer",
                  **kwargs):
         super(Conformer, self).__init__(
@@ -67,6 +68,7 @@ def __init__(self,
                 kernel_regularizer=kernel_regularizer,
                 bias_regularizer=bias_regularizer,
                 trainable=encoder_trainable,
+                streaming=streaming,
                 name=f"{name}_encoder"
             ),
             vocabulary_size=vocabulary_size,
diff --git a/tensorflow_asr/models/transducer/streaming_conformer.py b/tensorflow_asr/models/transducer/streaming_conformer.py
new file mode 100644
index 0000000000..d26d4681b4
--- /dev/null
+++ b/tensorflow_asr/models/transducer/streaming_conformer.py
@@ -0,0 +1,134 @@
+# Copyright 2020 Huy Le Nguyen (@usimarit)
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" http://arxiv.org/abs/1811.06621 """ + +import tensorflow as tf + +from ..layers.subsampling import TimeReduction +# from .transducer import Transducer +from ...utils import data_util, math_util +# from ...utils.utils import get_rnn, merge_two_last_dims, shape_list +from .conformer import Conformer + +L2 = tf.keras.regularizers.l2(1e-6) + +class StreamingConformer(Conformer): + """ + Attempt at implementing Streaming Conformer Transducer. (see: https://arxiv.org/pdf/2010.11395.pdf). + + Three main differences: + - Inputs are splits into chunks. + - Masking is used for MHSA to select the chunks to be used at each timestep. (Allows for parallel training.) + - Added parameter `streaming` to ConformerEncoder, ConformerBlock and ConvModule. Inside ConvModule, the layer DepthwiseConv2D has padding changed to "causal" when `streaming==True`. + + NOTE: Masking is applied just as regular masking along with the inputs. + """ + def __init__(self, + vocabulary_size: int, + encoder_subsampling: dict, + encoder_positional_encoding: str = "sinusoid", + encoder_dmodel: int = 144, + encoder_num_blocks: int = 16, + encoder_head_size: int = 36, + encoder_num_heads: int = 4, + encoder_mha_type: str = "relmha", + encoder_kernel_size: int = 32, + encoder_depth_multiplier: int = 1, + encoder_fc_factor: float = 0.5, + encoder_dropout: float = 0, + encoder_trainable: bool = True, + prediction_embed_dim: int = 512, + prediction_embed_dropout: int = 0, + prediction_num_rnns: int = 1, + prediction_rnn_units: int = 320, + prediction_rnn_type: str = "lstm", + prediction_rnn_implementation: int = 2, + prediction_layer_norm: bool = True, + prediction_projection_units: int = 0, + prediction_trainable: bool = True, + joint_dim: int = 1024, + joint_activation: str = "tanh", + prejoint_linear: bool = True, + postjoint_linear: bool = False, + joint_mode: str = "add", + joint_trainable: bool = True, + kernel_regularizer=L2, + bias_regularizer=L2, + name: str = "streaming_conformer", + **kwargs): + + self.streaming = True # Hardcoded value. Initializes Conformer with `streaming = True`. 
+        super(StreamingConformer, self).__init__(
+            vocabulary_size=vocabulary_size,
+            encoder_subsampling=encoder_subsampling,
+            encoder_positional_encoding=encoder_positional_encoding,
+            encoder_dmodel=encoder_dmodel,
+            encoder_num_blocks=encoder_num_blocks,
+            encoder_head_size=encoder_head_size,
+            encoder_num_heads=encoder_num_heads,
+            encoder_mha_type=encoder_mha_type,
+            encoder_depth_multiplier=encoder_depth_multiplier,
+            encoder_kernel_size=encoder_kernel_size,
+            encoder_fc_factor=encoder_fc_factor,
+            encoder_dropout=encoder_dropout,
+            encoder_trainable=encoder_trainable,
+            prediction_embed_dim=prediction_embed_dim,
+            prediction_embed_dropout=prediction_embed_dropout,
+            prediction_num_rnns=prediction_num_rnns,
+            prediction_rnn_units=prediction_rnn_units,
+            prediction_rnn_type=prediction_rnn_type,
+            prediction_rnn_implementation=prediction_rnn_implementation,
+            prediction_layer_norm=prediction_layer_norm,
+            prediction_projection_units=prediction_projection_units,
+            prediction_trainable=prediction_trainable,
+            joint_dim=joint_dim,
+            joint_activation=joint_activation,
+            prejoint_linear=prejoint_linear,
+            postjoint_linear=postjoint_linear,
+            joint_mode=joint_mode,
+            joint_trainable=joint_trainable,
+            kernel_regularizer=kernel_regularizer,
+            bias_regularizer=bias_regularizer,
+            streaming=self.streaming,
+            name=name,
+            **kwargs
+        )
+        self.dmodel = encoder_dmodel
+        self.time_reduction_factor = self.encoder.conv_subsampling.time_reduction_factor
+
+    def make(self, input_shape, prediction_shape=[None], batch_size=None):
+        inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size, dtype=tf.float32)
+        inputs_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
+        predictions = tf.keras.Input(shape=prediction_shape, batch_size=batch_size, dtype=tf.int32)
+        predictions_length = tf.keras.Input(shape=[], batch_size=batch_size, dtype=tf.int32)
+        mask = tf.keras.Input(shape=[None, None], batch_size=batch_size, dtype=tf.int32)
+        self(
+            data_util.create_inputs(
+                inputs=inputs,
+                inputs_length=inputs_length,
+                predictions=predictions,
+                predictions_length=predictions_length,
+                mask=mask
+            ),
+            training=False
+        )
+
+    def call(self, inputs, training=False, **kwargs):
+        enc = self.encoder(inputs["inputs"], training=training, mask=inputs["mask"], **kwargs)
+        pred = self.predict_net([inputs["predictions"], inputs["predictions_length"]], training=training, **kwargs)
+        logits = self.joint_net([enc, pred], training=training, **kwargs)
+        return data_util.create_logits(
+            logits=logits,
+            logits_length=math_util.get_reduced_length(inputs["inputs_length"], self.time_reduction_factor)
+        )
diff --git a/tensorflow_asr/utils/data_util.py b/tensorflow_asr/utils/data_util.py
index 2bcdca8d4e..7c7d05f9f5 100644
--- a/tensorflow_asr/utils/data_util.py
+++ b/tensorflow_asr/utils/data_util.py
@@ -20,7 +20,8 @@
 def create_inputs(inputs: tf.Tensor,
                   inputs_length: tf.Tensor,
                   predictions: tf.Tensor = None,
-                  predictions_length: tf.Tensor = None) -> dict:
+                  predictions_length: tf.Tensor = None,
+                  mask: tf.Tensor = None) -> dict:
     data = {
         "inputs": inputs,
         "inputs_length": inputs_length,
@@ -29,6 +30,8 @@ def create_inputs(inputs: tf.Tensor,
         data["predictions"] = predictions
     if predictions_length is not None:
         data["predictions_length"] = predictions_length
+    if mask is not None:
+        data["mask"] = mask
     return data
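
---
Note for reviewers: below is a minimal sketch (not part of the patch) of the rolling mask that
ASRMaskedSliceDataset builds. It mirrors the numpy branch of `_recalculate_mask`; the toy sizes
(6 reduced frames, 2 frames per chunk, a 1-frame history window) are made up for illustration.

import numpy as np

def rolling_mask(num_frames, frames_per_chunk, history_window_size):
    # Row i marks which (time-reduced) frames frame i may attend to.
    mask = np.zeros((num_frames, num_frames), dtype=np.int32)
    for i in range(num_frames):
        current_chunk = i // frames_per_chunk
        history_chunk = (i - history_window_size) // frames_per_chunk
        # Every chunk touched by the history window is visible in full.
        for chunk in range(history_chunk, current_chunk + 1):
            for j in range(frames_per_chunk):
                t = chunk * frames_per_chunk + j
                if 0 <= t < num_frames:
                    mask[i, t] = 1
    return mask

print(rolling_mask(num_frames=6, frames_per_chunk=2, history_window_size=1))
# [[1 1 0 0 0 0]
#  [1 1 0 0 0 0]
#  [1 1 1 1 0 0]
#  [0 0 1 1 0 0]
#  [0 0 1 1 1 1]
#  [0 0 0 0 1 1]]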
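A quick causality check for the new DepthwiseConv1D (again a sketch, assuming a TF build where
`tensorflow.python.keras` is importable and `_compute_causal_padding` accepts an inputs argument,
as the patched call() already assumes): with `padding='causal'`, perturbing only future frames
must leave earlier outputs unchanged.

import numpy as np
import tensorflow as tf
from tensorflow_asr.models.layers.depthwise_conv1d import DepthwiseConv1D

layer = DepthwiseConv1D(kernel_size=3, padding='causal')
x1 = tf.random.normal([1, 10, 4])  # (batch, time, channels)
# Same first 5 frames, different future frames.
x2 = tf.concat([x1[:, :5], tf.random.normal([1, 5, 4])], axis=1)

y1, y2 = layer(x1), layer(x2)
# Output frame t sees only input frames <= t, so the first 5 outputs match.
print(np.allclose(y1[:, :5].numpy(), y2[:, :5].numpy()))  # True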