| """Length bonus module.""" | |
| from typing import Any | |
| from typing import List | |
| from typing import Tuple | |
| import torch | |
| from espnet.nets.scorer_interface import BatchScorerInterface | |
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search.

    A stateless scorer: every vocabulary entry gets the same score of 1 on
    each call, independent of the prefix tokens and of the scorer state.
    """

    def __init__(self, n_vocab: int):
        """Initialize class.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search

        """
        # Only the vocabulary size is kept; it fixes the width of the scores.
        self.n = n_vocab

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): Prefix tokens (not read by this scorer).
            state: Scorer state for prefix tokens (not read by this scorer).
            x (torch.Tensor): Encoder feature that generates ys; only its
                device and dtype are used.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token with shape `(n_vocab,)`, all equal
                to 1 and created with the device and dtype of `x`,
                and None (no state is carried over).

        """
        one = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        # expand() gives a view over the single element instead of a copy.
        bonus = one.expand(self.n)
        return bonus, None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): Prefix tokens; only the size of the first
                dimension (n_batch) is used.
            states (List[Any]): Scorer states for prefix tokens (not read
                by this scorer).
            xs (torch.Tensor): The encoder feature that generates ys; only
                its device and dtype are used.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of
                `(n_batch, n_vocab)`, all equal to 1 and created with the
                device and dtype of `xs`,
                and None in place of the next state list.

        """
        one = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        n_batch = ys.shape[0]
        # expand() gives a view over the single element instead of a copy.
        bonus = one.expand(n_batch, self.n)
        return bonus, None