| | import json, time, random, os |
| | import numpy as np |
| | import torch |
| | from torch.nn import functional as F |
| |
|
# Best (smallest) elapsed time recorded per label; updated by record_time().
time_slot = {}
# Reference timestamp in nanoseconds, taken once at import; record_time()
# measures everything relative to this instant.
time_ref = time.time_ns()
| |
|
def record_time(name):
    """Record the smallest elapsed time (seconds since time_ref) seen for *name*.

    Repeated calls keep only the minimum, so time_slot[name] converges to the
    fastest observation for that label.
    """
    elapsed = (time.time_ns() - time_ref) / 1e9
    # 1e20 acts as "no measurement yet" so the first real value always wins.
    time_slot[name] = min(time_slot.get(name, 1e20), elapsed)
| |
|
class TOKENIZER():
    """Tokenizer wrapper with two modes.

    * Word/BPE mode: ``WORD_NAME`` is a list/tuple of two paths.  If both
      entries are equal, a ``PreTrainedTokenizerFast`` is loaded from that
      single tokenizer file; otherwise a ``GPT2TokenizerFast`` is built from
      the (vocab, merges) pair.
    * Char mode: ``WORD_NAME`` is a string; ``WORD_NAME + '.json'`` is read
      (UTF-16) as an ``{index: char}`` table, from which stoi/itos maps are
      derived.  ``UNKNOWN_CHAR`` must exist in that table.
    """

    def __init__(self, WORD_NAME, UNKNOWN_CHAR='\ue083'):
        if isinstance(WORD_NAME, list):
            self.charMode = False
            if WORD_NAME[0] == WORD_NAME[1]:
                from transformers import PreTrainedTokenizerFast
                self.tokenizer = PreTrainedTokenizerFast(tokenizer_file=WORD_NAME[0])
            else:
                from transformers import GPT2TokenizerFast
                self.tokenizer = GPT2TokenizerFast(WORD_NAME[0], WORD_NAME[1])
            self.vocab_size = len(self.tokenizer)
        else:
            self.charMode = True
            with open(WORD_NAME + '.json', "r", encoding="utf-16") as result_file:
                self.word_table = json.load(result_file)

            self.vocab_size = len(self.word_table)

            # word_table maps stringified index -> character.
            self.stoi = {v: int(k) for k, v in self.word_table.items()}
            self.itos = {int(k): v for k, v in self.word_table.items()}

            self.UNKNOWN_CHAR = self.stoi[UNKNOWN_CHAR]

    def refine_context(self, context):
        """Normalize a prompt: strip each line (incl. ideographic space and
        CR), drop empty lines, and prepend a single newline.

        Returns at least ``'\\n'`` (for empty/whitespace-only input).
        """
        lines = [line.strip().strip('\u3000').strip('\r')
                 for line in context.strip().split('\n')]
        lines = [line for line in lines if line != '']
        # Result always starts with '\n', so it can never be the empty string.
        return '\n' + ('\n'.join(lines)).strip()

    def sample_logits(self, out, x, ctx_len, temperature=1.0, top_p_usual=None, top_p_newline=None):
        """Sample a token index from logits ``out`` using top-p (nucleus)
        filtering and optional temperature.

        In char mode a separate ``top_p_newline`` is used when the last
        context token ``x[-1]`` is a newline.  The device branch is selected
        by the RWKV_RUN_DEVICE environment variable (must be set by caller).
        """
        lastChar = int(x[-1])

        probs = F.softmax(out, dim=-1)

        if self.charMode and self.itos[lastChar] == '\n':
            top_p = top_p_newline
        else:
            top_p = top_p_usual

        if os.environ["RWKV_RUN_DEVICE"] == "cpu":
            probs = probs.numpy()
            sorted_probs = np.sort(probs)[::-1]
            cumulative_probs = np.cumsum(sorted_probs)
            # Smallest probability still inside the top-p nucleus.
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                # BUG FIX: probs is a NumPy array here, which has no .pow()
                # method (that is a torch.Tensor API) — the original raised
                # AttributeError whenever temperature != 1.0 on CPU.
                probs = np.power(probs, 1.0 / temperature)
            probs = probs / np.sum(probs)
            return np.random.choice(a=len(probs), p=probs)
        else:
            sorted_probs = torch.sort(probs, descending=True)[0]
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1).cpu().numpy()
            cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
            probs[probs < cutoff] = 0
            if temperature != 1.0:
                probs = probs.pow(1.0 / temperature)
            # torch.multinomial normalizes internally, so no explicit division.
            return torch.multinomial(probs, num_samples=1)[0]
| |
|
def MaybeIsPrime(number):
    """Probabilistic primality check.

    True only when *number* passes both the Fermat test and the
    Miller-Rabin test; composites may (very rarely) slip through.
    """
    return FermatPrimalityTest(number) and MillerRabinPrimalityTest(number)
| |
|
| |
|
def FermatPrimalityTest(number):
    """Fermat probabilistic primality test with 3 random bases.

    Returns False for number <= 1.  A prime always passes; a composite may
    occasionally pass (Fermat liars, Carmichael numbers), so callers combine
    this with MillerRabinPrimalityTest.
    """
    if number <= 1:
        return False
    for _ in range(3):
        # FIX: the loop variable was named `time`, shadowing the imported
        # `time` module within this function — renamed to `_`.
        # Random base in [1, number - 1]; base 1 is a trivial pass.
        base = random.randint(2, number) - 1
        if pow(base, number - 1, number) != 1:
            return False
    return True
| |
|
| |
|
def MillerRabinPrimalityTest(number):
    """Miller-Rabin probabilistic primality test with 3 random witnesses.

    Deterministic for 2 (True), 1 and even numbers (False).  A prime always
    passes; a composite fails with probability >= 1 - (1/4)**3 per call.
    """
    if number == 2:
        return True
    elif number == 1 or number % 2 == 0:
        return False

    # Factor number - 1 as 2**timesTwoDividNumber * oddPartOfNumber (odd).
    oddPartOfNumber = number - 1
    timesTwoDividNumber = 0
    while oddPartOfNumber % 2 == 0:
        oddPartOfNumber = oddPartOfNumber // 2
        timesTwoDividNumber = timesTwoDividNumber + 1

    # FIX: the loop variable was named `time`, shadowing the imported `time`
    # module within this function — renamed to `_`.
    for _ in range(3):
        # Pick a random witness in [2, number - 1].
        while True:
            randomNumber = random.randint(2, number) - 1
            if randomNumber != 0 and randomNumber != 1:
                break

        randomNumberWithPower = pow(randomNumber, oddPartOfNumber, number)

        if (randomNumberWithPower != 1) and (randomNumberWithPower != number - 1):
            iterationNumber = 1

            # Repeatedly square; a prime must reach number - 1 before the
            # exponent exhausts the factors of two.
            while (iterationNumber <= timesTwoDividNumber - 1) and (randomNumberWithPower != number - 1):
                randomNumberWithPower = pow(randomNumberWithPower, 2, number)
                iterationNumber = iterationNumber + 1
            if randomNumberWithPower != (number - 1):
                return False

    return True
| |
|