import os
import gc
import json
import torch

from torch.nn import Module
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers import AutoModel
from transformers.generation.logits_process import LogitsProcessor
from typing import Dict, Union, Optional, Tuple


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings occupies one layer
    # transformer.final_layernorm and lm_head occupy one layer
    # transformer.layers occupies 28 layers
    # 30 layers in total, distributed across num_gpus GPUs
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: on Linux, torch.embedding is called with weight and input on
    # different devices, which raises a RuntimeError
    # on Windows, model.device is set to transformer.word_embeddings.device
    # on Linux, model.device is set to lm_head.device
    # when chat or stream_chat is called, input_ids is moved to model.device;
    # if transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
    # so transformer.word_embeddings, transformer.final_layernorm and lm_head
    # are all placed on the first GPU here
    # this file comes from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
    # with only minor modifications to support ChatGLM3
    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map
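
# Illustrative note (not in the original file): with num_gpus=2 the map above
# keeps the embedding, rotary embedding, final layernorm and output layer on
# GPU 0 together with roughly the first half of the 28 encoder layers, and the
# remaining layers on GPU 1, e.g.
#
#     dm = auto_configure_device_map(2)
#     dm['transformer.encoder.layers.0']   # -> 0
#     dm['transformer.encoder.layers.27']  # -> 1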


def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    if num_gpus < 2 and device_map is None:
        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
    else:
        from accelerate import dispatch_model

        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()

        if device_map is None:
            device_map = auto_configure_device_map(num_gpus)

        model = dispatch_model(model, device_map=device_map)

    return model
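
# Hedged usage sketch (the checkpoint name is an assumption, and AutoTokenizer
# is not imported by this file):
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#     model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
#     response, history = model.chat(tokenizer, "Hello", history=[])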


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If generation produced NaN or Inf logits, reset all scores and force
        # a fixed token (id 5) so that decoding can continue.
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        metadata, content = response.split("\n", maxsplit=1)
        if not metadata.strip():
            content = content.strip()
            # Replace the model's training-time placeholder with a concrete year.
            content = content.replace("[[训练时间]]", "2023年")
        else:
            if use_tool:
                # Drop the surrounding code-fence lines, leaving the tool_call(...) expression.
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                # eval resolves the tool_call(...) expression emitted by the model
                # into a plain dict of arguments.
                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content
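
# Illustrative example (hypothetical tool name and raw output, not from the
# original file): given a completion such as
#
#     'get_weather\n```python\ntool_call(city="Beijing")\n```'
#
# process_response(output, use_tool=True) returns
#     {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
# whereas a plain reply (empty metadata line) is returned as a stripped string.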


def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    functions = params["functions"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    echo = params.get("echo", True)

    messages = process_chatglm_messages(messages, functions=functions)
    query, role = messages[-1]["content"], messages[-1]["role"]

    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        print(f"Input length larger than {model.config.seq_length}")

    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only the last streamed result carries a finish_reason; mark it as "stop".
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()
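
# Hedged usage sketch (SimpleMessage is a hypothetical stand-in; any object
# with .role / .content / .function_call attributes works, e.g. the request
# models of the accompanying OpenAI-style API server):
#
#     params = {
#         "messages": [SimpleMessage(role="user", content="Hi", function_call=None)],
#         "functions": None,
#         "temperature": 0.8,
#         "top_p": 0.8,
#         "max_tokens": 256,
#         "echo": False,
#     }
#     for chunk in generate_stream_chatglm3(model, tokenizer, params):
#         print(chunk["text"])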


def process_chatglm_messages(messages, functions=None):
    # Note: the incoming messages are objects with .role / .content / .function_call
    # attributes (e.g. the API server's request models), not plain dicts.
    _messages = messages
    messages = []

    if functions:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": functions
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )
        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})

    return messages


def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Drain the streaming generator and return only the final chunk
    # (the one whose finish_reason is "stop").
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue
            break

    return reply, stop_found
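

# A minimal, self-contained check of apply_stopping_strings (illustrative
# addition; it needs no model or GPU).
if __name__ == "__main__":
    # A complete stop string is cut off and reported as found.
    text, found = apply_stopping_strings("Answer: 42<|observation|>extra", ["<|observation|>"])
    assert (text, found) == ("Answer: 42", True)

    # A partially generated stop string at the end is trimmed, but not reported as found.
    text, found = apply_stopping_strings("Answer: 42<|obser", ["<|observation|>"])
    assert (text, found) == ("Answer: 42", False)

    print("apply_stopping_strings checks passed")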