| """Create mask for subsequent steps.""" | |
def make_history_mask(xp, block):
    """Build a lower-triangular (history) mask for every sequence in a batch.

    Position ``i`` of a sequence may only look at positions ``j <= i``, so
    ``mask[b, i, j]`` is True exactly when ``j <= i``; the same (S, S)
    pattern is shared by all batch entries.

    Args:
        xp (module): NumPy-compatible array module; only ``arange`` and
            ``broadcast_to`` are used from it.
        block (ndarray): Block with dimensions: (B x S). Only its shape
            is read.

    Returns:
        ndarray: Boolean history mask with dimensions (B, S, S).

    """
    n_batch, seq_len = block.shape
    positions = xp.arange(seq_len)
    # Row index (query position) >= column index (key position) gives the
    # lower triangle including the diagonal.
    lower_tri = positions[:, None] >= positions[None, :]
    # Add a leading batch axis and expand it without copying the pattern.
    return xp.broadcast_to(
        lower_tri[None, :, :], (n_batch, seq_len, seq_len)
    )