from typing import Tuple

import torch

from espnet2.layers.label_aggregation import LabelAggregate


class LabelProcessor(torch.nn.Module):
    """Label aggregator for speaker diarization.

    A thin ``torch.nn.Module`` wrapper: all work is delegated to a single
    ``LabelAggregate`` instance built from the constructor arguments.

    Args:
        win_length: Window length, forwarded unchanged to ``LabelAggregate``
            (presumably measured in samples, given the ``Nsamples`` input
            axis documented on ``forward`` -- TODO confirm).
        hop_length: Hop length, forwarded unchanged to ``LabelAggregate``
            (presumably in samples -- TODO confirm).
        center: Flag forwarded unchanged to ``LabelAggregate``.
    """

    def __init__(
        self, win_length: int = 512, hop_length: int = 128, center: bool = True
    ):
        super().__init__()
        self.label_aggregator = LabelAggregate(win_length, hop_length, center)

    def forward(
        self, input: torch.Tensor, ilens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Aggregate sample-level labels into frame-level labels.

        Args:
            input: Label tensor of shape (Batch, Nsamples, Label_dim). The
                name shadows the ``input`` builtin but is kept so existing
                keyword callers keep working.
            ilens: Tensor of shape (Batch); presumably the number of valid
                samples in each batch item -- confirm against callers.

        Returns:
            A tuple ``(output, olens)`` exactly as produced by the wrapped
            ``LabelAggregate``:
                output: (Batch, Frames, Label_dim)
                olens: (Batch)
        """
        output, olens = self.label_aggregator(input, ilens)
        return output, olens