from __future__ import annotations

import logging
import math
from threading import Thread

import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AwqConfig,
    GenerationConfig,
    StoppingCriteria,
    StoppingCriteriaList,
    TextIteratorStreamer,
    TextStreamer,
)

logger = logging.getLogger(__name__)
class ThinkStoppingCriteria(StoppingCriteria):
    """Stops generation once the model has emitted "</think> true" or "</think> false"."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        # [1:] skips the special (BOS) token the tokenizer prepends.
        self.true_sequence = tokenizer("</think> true").input_ids[1:]
        self.false_sequence = tokenizer("</think> false").input_ids[1:]
        self.matched_sequence = None

    def __call__(self, input_ids, scores, **kwargs):
        for sequence in (self.true_sequence, self.false_sequence):
            if input_ids.shape[1] >= len(sequence):
                # Check whether the generated tokens end with the target sequence.
                if input_ids[0, -len(sequence):].tolist() == sequence:
                    self.matched_sequence = (
                        "</think> true" if sequence is self.true_sequence else "</think> false"
                    )
                    return True
        return False
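
# Minimal sketch (not part of the original file) showing how the suffix match in
# ThinkStoppingCriteria.__call__ fires. The stub tokenizer and its token ids are
# invented purely for illustration; a real run would use the model's tokenizer.
def _demo_stopping_criteria():
    from types import SimpleNamespace

    class _StubTokenizer:
        # Pretend "</think> true" -> [0, 10, 11] and "</think> false" -> [0, 10, 12],
        # where the leading 0 stands in for the BOS token that __init__ strips.
        _ids = {"</think> true": [0, 10, 11], "</think> false": [0, 10, 12]}

        def __call__(self, text):
            return SimpleNamespace(input_ids=list(self._ids[text]))

    criteria = ThinkStoppingCriteria(_StubTokenizer())
    generated = torch.tensor([[5, 7, 10, 12]])  # ends with the "</think> false" suffix
    assert criteria(generated, scores=None)
    assert criteria.matched_sequence == "</think> false"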
class Rank1:
    def __init__(
        self,
        model_name_or_path: str = "",
        # Small values for demo purposes; typically larger in practice.
        context_size: int = 4000,
        max_output_tokens: int = 1024,
        **kwargs,
    ):
        self.context_size = context_size
        self.max_output_tokens = max_output_tokens

        # Initialize tokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.tokenizer.padding_side = "left"
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Cache commonly used token IDs
        self.true_token = self.tokenizer(" true", add_special_tokens=False).input_ids[0]
        self.false_token = self.tokenizer(" false", add_special_tokens=False).input_ids[0]

        # Load the AWQ model; device_map="auto" places it on the available devices.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map="auto",
            trust_remote_code=True,
            attn_implementation="flash_attention_2",
        )

        # Stop generation as soon as "</think> true" or "</think> false" appears.
        self.stopping_criteria = StoppingCriteriaList([
            ThinkStoppingCriteria(self.tokenizer)
        ])

        # Greedy-decoding generation config; the stop strings are enforced via
        # self.stopping_criteria at generate() time.
        self.generation_config = GenerationConfig(
            max_new_tokens=max_output_tokens,
            do_sample=False,
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=self.tokenizer.eos_token_id,
        )

        # Create text streamer
        self.streamer = TextStreamer(self.tokenizer)
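
# Sketch (not part of the original file) of how the pieces above could be wired
# together to score one query-passage pair. Assumptions: the prompt format below
# is illustrative, and the relevance score is read off as a softmax over the
# logits of the cached " true" / " false" tokens at the final generation step.
def _demo_score(reranker: Rank1, query: str, passage: str) -> float:
    prompt = (
        "Determine if the following passage is relevant to the query. "
        "Answer only with 'true' or 'false'.\n"
        f"Query: {query}\nPassage: {passage}\n<think>"
    )
    inputs = reranker.tokenizer(prompt, return_tensors="pt").to(reranker.model.device)
    with torch.no_grad():
        outputs = reranker.model.generate(
            **inputs,
            generation_config=reranker.generation_config,
            stopping_criteria=reranker.stopping_criteria,
            return_dict_in_generate=True,
            output_scores=True,
        )
    # The last score step produced the true/false token; compare the two logits.
    final_logits = outputs.scores[-1][0]
    pair = torch.stack([final_logits[reranker.true_token], final_logits[reranker.false_token]])
    return torch.softmax(pair, dim=0)[0].item()  # probability of " true"

# Example usage (hypothetical checkpoint path):
#   reranker = Rank1("path/to/rank1-awq-checkpoint")
#   score = _demo_score(reranker, "what is python?", "Python is a programming language.")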