| import functools |
| import logging |
|
|
| import torch |
| import torch.nn.functional as F |
| import math |
| import numpy as np |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
def get_device_of(tensor):
    """Return the device index a tensor lives on.

    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        tensor {tensor} -- tensor

    Returns:
        int -- CUDA device index, or -1 when the tensor is on CPU
    """
    return tensor.get_device() if tensor.is_cuda else -1
|
|
|
|
def get_range_vector(size, device):
    """This function returns a range vector with the desired size, starting at 0.
    The vector is created directly on the target device, so no CPU-to-GPU copy
    is needed.
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        size {int} -- the size of range
        device {int} -- CUDA device index, or -1 for CPU

    Returns:
        torch.Tensor -- range vector (dtype torch.long)
    """
    if device > -1:
        # torch.arange allocates on the requested device directly; this replaces
        # the legacy torch.cuda.LongTensor(size).fill_(1).cumsum(0) - 1 trick,
        # which relies on a deprecated constructor and uninitialized storage.
        return torch.arange(0, size, dtype=torch.long, device=device)
    return torch.arange(0, size, dtype=torch.long)
|
|
|
|
def flatten_and_batch_shift_indices(indices, sequence_length):
    """Turn per-batch indices into indices into the flattened target.

    Each batch entry's indices are shifted by `batch_index * sequence_length`
    so they address rows of `target.view(-1, embedding_size)`.
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        indices {tensor} -- index tensor of shape (batch_size, d_1, ..., d_n)
        sequence_length {int} -- sequence length of the target

    Returns:
        tensor -- flat offset index tensor

    Raises:
        RuntimeError -- if any index falls outside [0, sequence_length)
    """
    if torch.max(indices) >= sequence_length or torch.min(indices) < 0:
        raise RuntimeError("All elements in indices should be in range (0, {})".format(sequence_length - 1))

    batch_size = indices.size(0)
    # one offset of `sequence_length` per batch entry, on the same device
    offsets = torch.arange(0, batch_size, dtype=torch.long, device=indices.device) * sequence_length
    # reshape to (batch_size, 1, ..., 1) so the offsets broadcast over indices
    offsets = offsets.view(batch_size, *([1] * (indices.dim() - 1)))

    return (indices + offsets).view(-1)
|
|
|
|
def batched_index_select(target, indices, flattened_indices=None):
    """Select values from `target` according to per-batch `indices`.

    The result has size ``(batch_size, d_1, ..., d_n, embedding_size)``.
    refer to https://github.com/allenai/allennlp/blob/master/allennlp/nn/util.py

    Arguments:
        target {torch.Tensor} -- target tensor of shape (batch, seq, embedding)
        indices {torch.LongTensor} -- index tensor of shape (batch, d_1, ..., d_n)

    Keyword Arguments:
        flattened_indices {Optional[torch.LongTensor]} -- precomputed flat
            indices into target.view(-1, embedding) (default: {None})

    Returns:
        torch.Tensor -- selected tensor
    """
    if flattened_indices is None:
        flattened_indices = flatten_and_batch_shift_indices(indices, target.size(1))

    embedding_size = target.size(-1)
    # collapse batch and sequence dims, gather the rows, then restore the shape
    gathered = target.reshape(-1, embedding_size).index_select(0, flattened_indices)
    return gathered.view(*(list(indices.size()) + [embedding_size]))
|
|
|
|
def get_padding_vector(size, dtype, device):
    """This function initializes a zero padding unit.

    Arguments:
        size {int or tuple} -- padding unit size
        dtype {torch.dtype} -- dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- zero-filled padding tensor on the requested device
    """
    if device > -1:
        # allocate directly on the GPU instead of creating on CPU and copying,
        # which saves a host-to-device transfer
        return torch.zeros(size, dtype=dtype, device=device)
    return torch.zeros(size, dtype=dtype)
|
|
|
|
def array2tensor(array, dtype, device):
    """Convert a numpy array into a torch tensor on the requested device.

    Arguments:
        array {numpy.array} -- numpy array
        dtype {torch.dtype} -- torch dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- tensor
    """
    result = torch.as_tensor(array, dtype=dtype)
    if device > -1:
        # asynchronous copy to the GPU when pinned memory allows it
        result = result.cuda(device=device, non_blocking=True)
    return result
|
|
|
|
def gelu(x):
    """Gaussian Error Linear Unit activation: x * Phi(x).

    Phi is the standard normal CDF, expressed exactly via the error function.
    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results):
    0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    Also see https://arxiv.org/abs/1606.08415
    refer to: https://github.com/huggingface/pytorch-transformers/blob/master/pytorch_transformers/modeling_bert.py
    """
    cdf = 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
    return x * cdf
|
|
|
|
def pad_vecs(vecs, padding_size, dtype, device):
    """Pad a batch of variable-length vector lists to a common length.

    Arguments:
        vecs {list} -- list of lists of (1, padding_size) tensors
        padding_size {int} -- padding dims
        dtype {torch.dtype} -- dtype
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- padded vectors of shape (batch, max_length, padding_size)
    """
    def _pad_unit():
        # one zero row of shape (1, padding_size) on the requested device
        unit = torch.zeros((1, padding_size), dtype=dtype)
        if device > -1:
            unit = unit.cuda(device=device, non_blocking=True)
        return unit

    max_length = max(len(vec) for vec in vecs)

    if max_length == 0:
        # every entry is empty: emit a single padding row per entry
        return torch.cat([_pad_unit().unsqueeze(0) for _ in vecs], 0)

    padded = []
    for vec in vecs:
        pieces = list(vec) + [_pad_unit() for _ in range(max_length - len(vec))]
        stacked = torch.cat(pieces, 0).unsqueeze(0)
        assert stacked.size() == (1, max_length, padding_size), "the size of pad vector is not correct"
        padded.append(stacked)
    return torch.cat(padded, 0)
|
|
|
|
def get_bilstm_minus(batch_seq_encoder_repr, span_list, seq_lens):
    """Compute span representations with the BiLSTM-minus method.

    For each sequence, the hidden states are split into forward/backward
    halves and each span vector is the difference of boundary states.

    Arguments:
        batch_seq_encoder_repr {list} -- batch sequence encoder representation
        span_list {list} -- span list of (start, end) pairs
        seq_lens {list} -- sequence length list

    Returns:
        tensor -- span representation vectors, one row per span
    """
    assert len(batch_seq_encoder_repr) == len(
        span_list), "the length of batch seq encoder repr is not equal to span list's length"
    assert len(span_list) == len(seq_lens), "the length of span list is not equal to batch seq lens's length"

    hidden_size = batch_seq_encoder_repr.size(-1)
    half = hidden_size // 2

    span_vecs = []
    for seq_repr, (start, end), length in zip(batch_seq_encoder_repr, span_list, seq_lens):
        # drop padded positions, then split into forward/backward halves
        valid_output = seq_repr[:length]
        fwd_output, bwd_output = valid_output.split(half, 1)
        fwd_vec = get_forward_segment(fwd_output, start, end, get_device_of(fwd_output))
        bwd_vec = get_backward_segment(bwd_output, start, end, get_device_of(bwd_output))
        span_vecs.append(torch.cat([fwd_vec, bwd_vec], 0).unsqueeze(0))
    return torch.cat(span_vecs, 0)
|
|
|
|
def get_forward_segment(forward_rnn_output, s, e, device):
    """Span representation from a forward RNN: state(e-1) - state(s-1).

    Arguments:
        forward_rnn_output {tensor} -- forward rnn output, shape (seq_len, hidden)
        s {int} -- span start (inclusive)
        e {int} -- span end (exclusive)
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- span representation vector of length hidden
    """
    _, hidden_size = forward_rnn_output.size()

    # empty span: return a zero vector on the right device
    if s >= e:
        zero = torch.zeros(hidden_size, dtype=forward_rnn_output.dtype)
        return zero.cuda(device=device, non_blocking=True) if device > -1 else zero

    end_state = forward_rnn_output[e - 1]
    # at the sequence start there is no previous state to subtract
    return end_state if s == 0 else end_state - forward_rnn_output[s - 1]
|
|
|
|
def get_backward_segment(backward_rnn_output, s, e, device):
    """Span representation from a backward RNN: state(s) - state(e).

    Arguments:
        backward_rnn_output {tensor} -- backward rnn output, shape (seq_len, hidden)
        s {int} -- span start (inclusive)
        e {int} -- span end (exclusive)
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- span representation vector of length hidden
    """
    seq_len, hidden_size = backward_rnn_output.size()

    # empty span: return a zero vector on the right device
    if s >= e:
        zero = torch.zeros(hidden_size, dtype=backward_rnn_output.dtype)
        return zero.cuda(device=device, non_blocking=True) if device > -1 else zero

    start_state = backward_rnn_output[s]
    # at the sequence end there is no following state to subtract
    return start_state if e == seq_len else start_state - backward_rnn_output[e]
|
|
|
|
def get_dist_vecs(span_list, max_sent_len, device):
    """This function gets distance embedding (a one-hot over span length).

    Arguments:
        span_list {list} -- span list of (start, end) pairs, with e - s < max_sent_len
        max_sent_len {int} -- number of one-hot dimensions
        device {int} -- device = -1 if cpu, device >= 0 if gpu

    Returns:
        tensor -- (num_spans, max_sent_len) float tensor, one-hot at index e - s
    """
    dist_vecs = []
    for s, e in span_list:
        assert s <= e, "span start is greater than end"

        # build the one-hot row directly instead of materializing a full
        # max_sent_len x max_sent_len identity matrix per span (O(L) vs O(L^2))
        vec = torch.zeros(max_sent_len)
        vec[e - s] = 1.0
        if device > -1:
            vec = vec.cuda(device=device, non_blocking=True)

        dist_vecs.append(vec)

    return torch.stack(dist_vecs)
|
|
|
|
def get_conv_vecs(batch_token_repr, span_list, span_batch_size, conv_layer):
    """Compute span representations by running token vectors through a conv layer.

    Arguments:
        batch_token_repr {list} -- batch token representation
        span_list {list} -- span list of (start, end) pairs
        span_batch_size {int} -- span convolution batch size
        conv_layer {nn.Module} -- convolution layer

    Returns:
        tensor -- conv vectors, one row per span
    """
    assert len(batch_token_repr) == len(span_list), "the length of batch token repr is not equal to span list's length"

    # collect the token vectors covered by each span; empty spans contribute
    # an empty list and will be filled with padding by pad_vecs
    span_token_vecs = []
    for token_repr, (start, end) in zip(batch_token_repr, span_list):
        span_token_vecs.append([] if start == end else list(token_repr[start:end].split(1)))

    # run the conv layer over fixed-size chunks of spans to bound memory use
    conv_outputs = []
    for chunk_start in range(0, len(span_token_vecs), span_batch_size):
        chunk = span_token_vecs[chunk_start:chunk_start + span_batch_size]
        padded = pad_vecs(chunk, conv_layer.get_input_dims(),
                          batch_token_repr[0].dtype, get_device_of(batch_token_repr[0]))
        conv_outputs.append(conv_layer(padded))
    return torch.cat(conv_outputs, dim=0)
|
|
|
|
def get_n_trainable_parameters(model):
    """Count the trainable parameters of a model.

    Arguments:
        model {nn.Module} -- model

    Returns:
        int -- the number of trainable parameters of the model
    """
    # numel() is the product of all dims, which matches the reduce-by-multiply
    # over param.size(); parameters with requires_grad=False are skipped
    return sum(param.numel() for param in model.parameters() if param.requires_grad)
|
|
|
|
def js_div(p, q, reduction='batchmean'):
    """js_div calculates Jensen Shannon Divergence (JSD).

    JSD(p, q) = 0.5 * KL(p || m) + 0.5 * KL(q || m), with m = (p + q) / 2.

    Note: ``F.kl_div(input, target)`` expects ``input`` in log-space and
    computes KL(target || exp(input)). The previous implementation passed raw
    probabilities as ``input`` and the mixture as ``target``, so it computed
    neither KL(p || m) nor KL(q || m); the mixture must be log-transformed
    and passed as the first argument instead.

    Args:
        p (tensor): distribution p (probabilities, not log-probabilities)
        q (tensor): distribution q (probabilities, not log-probabilities)
        reduction (str, optional): reduction. Defaults to 'batchmean'.

    Returns:
        tensor: JS divergence
    """
    log_m = (0.5 * (p + q)).log()
    return (F.kl_div(log_m, p, reduction=reduction) + F.kl_div(log_m, q, reduction=reduction)) * 0.5
|
|
|
|
def load_weight_from_pretrained_model(model, pretrained_state_dict, prefix=""):
    """load_weight_from_pretrained_model This function loads weight from pretrained model.

    For each parameter of `model` (except decoder parameters), the pretrained
    state dict is searched under the parameter's full name and under the name
    with its first one or two dot-separated components stripped — presumably
    to handle wrapper-module prefixes; TODO confirm against callers. Each
    candidate is also tried with `prefix` prepended. Only same-shaped matches
    are copied; unmatched parameters keep their current values.

    Arguments:
        model {nn.Module} -- model
        pretrained_state_dict {dict} -- state dict of pretrained model

    Keyword Arguments:
        prefix {str} -- prefix for pretrained model (default: {""})
    """

    # start from the model's own weights so unmatched entries are preserved
    model_state_dict = model.state_dict()

    filtered_state_dict = {}
    for k, v in model_state_dict.items():
        # decoder parameters are never overwritten by the pretrained weights
        if 'decoder' in k:
            continue

        k = k.split('.')
        # candidates: full name, then the name minus its first one or two
        # components (e.g. to skip an outer wrapper-module name)
        for candi_name in ['.'.join(k), '.'.join(k[1:]), '.'.join(k[2:])]:
            # accept only if the name exists AND the shapes agree
            if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
                filtered_state_dict['.'.join(k)] = pretrained_state_dict[candi_name]
                break

            # retry the same candidate with the caller-supplied prefix prepended
            candi_name = prefix + candi_name
            if candi_name in pretrained_state_dict and v.size() == pretrained_state_dict[candi_name].size():
                filtered_state_dict['.'.join(k)] = pretrained_state_dict[candi_name]
                break

    # log which parameters were actually taken from the pretrained model
    logger.info("Load weights parameters:")
    for name in filtered_state_dict:
        logger.info(name)

    # merge the matched pretrained weights into the model's state and reload
    model_state_dict.update(filtered_state_dict)
    model.load_state_dict(model_state_dict)
|
|
|
|
def clone_weights(first_module, second_module):
    """Tie (share) the weights of `first_module` with those of `second_module`.

    refers to: https://huggingface.co/transformers/v1.2.0/_modules/pytorch_transformers/modeling_utils.html#PreTrainedModel

    Arguments:
        first_module {nn.Module} -- first module
        second_module {nn.Module} -- second module
    """
    # share the weight tensor itself rather than copying it
    first_module.weight = second_module.weight

    bias = getattr(first_module, 'bias', None)
    if bias is not None:
        # zero-extend the bias so it matches the (possibly larger) tied weight
        extra = first_module.weight.shape[0] - bias.shape[0]
        first_module.bias.data = torch.nn.functional.pad(bias.data, (0, extra), 'constant', 0)
|
|