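# Inference helpers for a GPT-style TTS backbone (ChatTTS-layout `models`
# dict with 'gpt' and 'tokenizer' entries): `infer_code` samples discrete
# audio codes from text, and `refine_text` decodes a refined text sequence
# terminated by the [Ebreak] token.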
import torch
import torch.nn.functional as F
from transformers.generation import TopKLogitsWarper, TopPLogitsWarper
from ..utils.infer_utils import CustomRepetitionPenaltyLogitsProcessorRepeat


def infer_code(
    models,
    text,
    spk_emb=None,
    top_P=0.7,
    top_K=20,
    temperature=0.3,
    repetition_penalty=1.05,
    max_new_token=2048,
    stream=False,
    **kwargs
):
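    """Sample discrete audio codes for `text` from the GPT model.

    Each prompt is wrapped in [Stts]...[Ptts] control tokens; when
    `spk_emb` is given, its L2-normalized value is injected at the
    [spk_emb] position of the input embeddings. Sampling applies top-p
    and top-k warping plus a windowed repetition penalty, and stops at
    the audio EOS code.
    """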
    device = next(models['gpt'].parameters()).device

    if not isinstance(text, list):
        text = [text]

    if not isinstance(temperature, list):
        temperature = [temperature] * models['gpt'].num_vq
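
    # Wrap each prompt in TTS control tokens; [spk_emb] marks where the
    # speaker embedding will be spliced into the input embeddings.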
    if spk_emb is not None:
        text = [f'[Stts][spk_emb]{i}[Ptts]' for i in text]
    else:
        text = [f'[Stts][empty_spk]{i}[Ptts]' for i in text]
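
    # Tokenize the prompts and replicate the text ids across all num_vq
    # code heads.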
    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    input_ids = text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=torch.bool, device=device)

    inputs = {
        'input_ids': input_ids,
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }
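
    # Build the input embeddings, then splice the normalized speaker
    # embedding into the [spk_emb] slots (matched by token id in the
    # first VQ channel).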
    emb = models['gpt'].get_emb(**inputs)
    if spk_emb is not None:
        emb[inputs['input_ids'][..., 0] == models['tokenizer'].convert_tokens_to_ids('[spk_emb]')] = F.normalize(
            spk_emb.to(device).to(emb.dtype)[None].expand(len(text), -1), p=2.0, dim=1, eps=1e-12
        )
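
    # The last code-embedding index is reserved as the audio EOS token.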
    num_code = models['gpt'].emb_code[0].num_embeddings - 1

    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, num_code, 16)
        )
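
    # Autoregressively sample audio codes (infer_text=False), optionally
    # streaming partial results.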
    result = models['gpt'].generate(
        emb, inputs['input_ids'],
        temperature=torch.tensor(temperature, device=device),
        attention_mask=inputs['attention_mask'],
        LogitsWarpers=LogitsWarpers,
        LogitsProcessors=LogitsProcessors,
        eos_token=num_code,
        max_new_token=max_new_token,
        infer_text=False,
        stream=stream,
        **kwargs
    )

    return result


def refine_text(
    models,
    text,
    top_P=0.7,
    top_K=20,
    temperature=0.7,
    repetition_penalty=1.0,
    max_new_token=384,
    prompt='',
    **kwargs
):
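    """Decode a refined text sequence for `text` with the GPT model.

    Each prompt is wrapped in [Sbreak]...[Pbreak] tokens followed by the
    optional control `prompt`; generation runs in text mode
    (`infer_text=True`) until the [Ebreak] token is produced.
    """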
    device = next(models['gpt'].parameters()).device

    if not isinstance(text, list):
        text = [text]
    assert len(text), 'text should not be empty'

    text = [f"[Sbreak]{i}[Pbreak]{prompt}" for i in text]
    text_token = models['tokenizer'](text, return_tensors='pt', add_special_tokens=False, padding=True).to(device)
    text_mask = torch.ones(text_token['input_ids'].shape, dtype=torch.bool, device=device)
    inputs = {
        'input_ids': text_token['input_ids'][..., None].expand(-1, -1, models['gpt'].num_vq),
        'text_mask': text_mask,
        'attention_mask': text_token['attention_mask'],
    }
    LogitsWarpers = []
    if top_P is not None:
        LogitsWarpers.append(TopPLogitsWarper(top_P, min_tokens_to_keep=3))
    if top_K is not None:
        LogitsWarpers.append(TopKLogitsWarper(top_K, min_tokens_to_keep=3))

    LogitsProcessors = []
    if repetition_penalty is not None and repetition_penalty != 1:
        LogitsProcessors.append(
            CustomRepetitionPenaltyLogitsProcessorRepeat(repetition_penalty, len(models['tokenizer']), 16)
        )
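
    # Text-mode decoding (infer_text=True); the generate stream yields a
    # single final result, taken via next().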
    result = models['gpt'].generate(
        models['gpt'].get_emb(**inputs), inputs['input_ids'],
        temperature=torch.tensor([temperature], device=device),
        attention_mask=inputs['attention_mask'],
        LogitsWarpers=LogitsWarpers,
        LogitsProcessors=LogitsProcessors,
        eos_token=torch.tensor(models['tokenizer'].convert_tokens_to_ids('[Ebreak]'), device=device)[None],
        max_new_token=max_new_token,
        infer_text=True,
        stream=False,
        **kwargs
    )

    return next(result)
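
# Minimal usage sketch (hypothetical: assumes a `models` dict built by the
# surrounding package, with a loaded 'gpt' module and 'tokenizer'):
#
#     refined = refine_text(models, 'Hello world.')
#     codes = infer_code(models, 'Hello world.', spk_emb=None, stream=False)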