| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import math |
| from typing import List |
|
|
| import torch |
| from sherpa import RnntConformerModel, greedy_search, modified_beam_search |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| LOG_EPS = math.log(1e-10) |
|
|
|
|
| @torch.no_grad() |
| def run_model_and_do_greedy_search( |
| model: RnntConformerModel, |
| features: List[torch.Tensor], |
| ) -> List[List[int]]: |
| """Run RNN-T model with the given features and use greedy search |
| to decode the output of the model. |
| |
| Args: |
| model: |
| The RNN-T model. |
| features: |
| A list of 2-D tensors. Each entry is of shape |
| (num_frames, feature_dim). |
| Returns: |
| Return a list-of-list containing the decoding token IDs. |
| """ |
| features_length = torch.tensor( |
| [f.size(0) for f in features], |
| dtype=torch.int64, |
| ) |
| features = pad_sequence( |
| features, |
| batch_first=True, |
| padding_value=LOG_EPS, |
| ) |
|
|
| device = model.device |
| features = features.to(device) |
| features_length = features_length.to(device) |
|
|
| encoder_out, encoder_out_length = model.encoder( |
| features=features, |
| features_length=features_length, |
| ) |
|
|
| hyp_tokens = greedy_search( |
| model=model, |
| encoder_out=encoder_out, |
| encoder_out_length=encoder_out_length.cpu(), |
| ) |
| return hyp_tokens |
|
|
|
|
| @torch.no_grad() |
| def run_model_and_do_modified_beam_search( |
| model: RnntConformerModel, |
| features: List[torch.Tensor], |
| num_active_paths: int, |
| ) -> List[List[int]]: |
| """Run RNN-T model with the given features and use greedy search |
| to decode the output of the model. |
| |
| Args: |
| model: |
| The RNN-T model. |
| features: |
| A list of 2-D tensors. Each entry is of shape |
| (num_frames, feature_dim). |
| num_active_paths: |
| Used only when decoding_method is modified_beam_search. |
| It specifies number of active paths for each utterance. Due to |
| merging paths with identical token sequences, the actual number |
| may be less than "num_active_paths". |
| Returns: |
| Return a list-of-list containing the decoding token IDs. |
| """ |
| features_length = torch.tensor( |
| [f.size(0) for f in features], |
| dtype=torch.int64, |
| ) |
| features = pad_sequence( |
| features, |
| batch_first=True, |
| padding_value=LOG_EPS, |
| ) |
|
|
| device = model.device |
| features = features.to(device) |
| features_length = features_length.to(device) |
|
|
| encoder_out, encoder_out_length = model.encoder( |
| features=features, |
| features_length=features_length, |
| ) |
|
|
| hyp_tokens = modified_beam_search( |
| model=model, |
| encoder_out=encoder_out, |
| encoder_out_length=encoder_out_length.cpu(), |
| num_active_paths=num_active_paths, |
| ) |
| return hyp_tokens |
|
|