aladdin1995 committed
Commit aa6e6b1 · verified · 1 parent: 58a469f

Create app.py

Files changed (1):
  1. app.py +329 -0
app.py ADDED
@@ -0,0 +1,329 @@
# app.py
# Gradio UI for PromptEnhancerV2

import os
from threading import Thread
from transformers import TextIteratorStreamer, AutoTokenizer
import time
import logging
import re
import torch
import gradio as gr

from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor

# Try to import qwen_vl_utils; if it fails, fall back to a stub that returns empty image/video inputs.
try:
    from qwen_vl_utils import process_vision_info
except Exception:
    def process_vision_info(messages):
        return None, None

def replace_single_quotes(text):
    pattern = r"\B'([^']*)'\B"
    replaced_text = re.sub(pattern, r'"\1"', text)
    replaced_text = replaced_text.replace("’", "”").replace("‘", "“")
    return replaced_text

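# For illustration, the intended effect of replace_single_quotes (the quoted
# example below is an assumption, not taken from the original file):
#   replace_single_quotes("a 'red' fox")  ->  a "red" fox
# Curly single quotes are likewise normalized to their double-quote counterparts.
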
class PromptEnhancerV2:
    def __init__(self, models_root_path, device_map="auto", torch_dtype="bfloat16"):
        if not logging.getLogger(__name__).handlers:
            logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__name__)
        self.models_root_path = models_root_path  # kept for the tokenizer fallback in predict_stream

        # Map the dtype string to the corresponding torch dtype.
        if torch_dtype == "bfloat16":
            dtype = torch.bfloat16
        elif torch_dtype == "float16":
            dtype = torch.float16
        else:
            dtype = torch.float32

        self.model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            models_root_path,
            torch_dtype=dtype,
            attn_implementation="flash_attention_2",
            device_map=device_map,
        )
        self.processor = AutoProcessor.from_pretrained(models_root_path)

    # The default sys_prompt below is Chinese; it instructs the model to produce
    # a chain of thought and then rewrite the user's prompt.
    @torch.inference_mode()
    def predict(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.0,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda",
    ):
        org_prompt_cot = prompt_cot
        try:
            user_prompt_format = sys_prompt + "\n" + org_prompt_cot
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": user_prompt_format},
                    ],
                }
            ]

            text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            image_inputs, video_inputs = process_vision_info(messages)
            inputs = self.processor(
                text=[text],
                images=image_inputs,
                videos=video_inputs,
                padding=True,
                return_tensors="pt",
            )
            inputs = inputs.to(device)

            # Note: generation is greedy (do_sample=False) with top_k=5 / top_p=0.9
            # and a fixed 2048-token budget; the max_new_tokens argument is not used here.
            generated_ids = self.model.generate(
                **inputs,
                max_new_tokens=2048,
                temperature=float(temperature),
                do_sample=False,
                top_k=5,
                top_p=0.9
            )
            generated_ids_trimmed = [
                out_ids[len(in_ids):]
                for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
            ]
            output_text = self.processor.batch_decode(
                generated_ids_trimmed,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False,
            )
            output_res = output_text[0]
            # Expect one <think>...</think> block ("think>" appears twice) and take what follows it.
            assert output_res.count("think>") == 2
            prompt_cot = output_res.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = replace_single_quotes(prompt_cot)
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")

        return prompt_cot

    @torch.inference_mode()
    def predict_stream(
        self,
        prompt_cot,
        sys_prompt="请根据用户的输入,生成思考过程的思维链并改写提示词:",
        temperature=0.1,
        top_p=1.0,
        max_new_tokens=2048,
        device="cuda",
    ):
        org_prompt_cot = prompt_cot

        # Assemble the inputs, same as in predict().
        user_prompt_format = sys_prompt + "\n" + org_prompt_cot
        messages = [{"role": "user", "content": [{"type": "text", "text": user_prompt_format}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to(device)

        # Get a tokenizer (processor.tokenizer usually exists; keep a fallback just in case).
        tokenizer = getattr(self.processor, "tokenizer", None)
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained(self.models_root_path, trust_remote_code=True)

        streamer = TextIteratorStreamer(
            tokenizer=tokenizer,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        gen_kwargs = dict(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=float(temperature),
            do_sample=True,  # sampling enabled for streaming (predict() uses greedy decoding)
            top_k=5,
            top_p=0.9,
            streamer=streamer,
        )

        # Start generation in a worker thread; the main thread consumes the streamer.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()

        buffer = ""   # accumulates the full output (including the chain of thought)
        emitted = ""  # the portion of the rewritten prompt already yielded to the UI
        already_stripped_newline = False

        try:
            for piece in streamer:
                buffer += piece
                part = buffer.split('assistant')[-1]
                delta = part[len(emitted):]
                if delta:
                    emitted = part
                    yield emitted  # push the intermediate result to the frontend
        finally:
            thread.join()

        # If the second "think>" never arrives, fall back to the original prompt.
        # if emitted.strip() == "":
        #     yield replace_single_quotes(org_prompt_cot)
        try:
            assert emitted.count("think>") == 2
            prompt_cot = emitted.split("think>")[-1]
            if prompt_cot.startswith("\n"):
                prompt_cot = prompt_cot[1:]
            prompt_cot = emitted.split('assistant')[-1] + '\n \n Recaption:' + replace_single_quotes(prompt_cot)
            # prompt_cot = replace_single_quotes(prompt_cot)
            yield prompt_cot
        except Exception as e:
            prompt_cot = org_prompt_cot
            print(f"✗ Re-prompting failed, so we are using the original prompt. Error: {e}")
            yield prompt_cot


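# Minimal programmatic-usage sketch: how the class above can be driven without
# the Gradio UI. The function name, the model path argument, and the example
# prompt are placeholders, not values from this repo.
def _example_usage(model_path):
    enhancer = PromptEnhancerV2(model_path)
    # predict() returns the rewritten prompt, or the original prompt on failure.
    return enhancer.predict("A lone lighthouse on a stormy coast at dusk")
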
# -------------------------
# Gradio app helpers
# -------------------------

DEFAULT_MODEL_PATH = os.environ.get("MODEL_OUTPUT_PATH", "PromptEnhancer/PromptEnhancer-32B")

def ensure_enhancer(state, model_path, device_map, torch_dtype):
    """
    state: dict or None.
    Returns a state dict holding the loaded enhancer and its load settings,
    reloading the model only when the path, device_map, or dtype has changed.
    """
    need_reload = False
    if state is None or not isinstance(state, dict):
        need_reload = True
    else:
        prev_path = state.get("model_path")
        prev_map = state.get("device_map")
        prev_dtype = state.get("torch_dtype")
        if prev_path != model_path or prev_map != device_map or prev_dtype != torch_dtype:
            need_reload = True

    if need_reload:
        enhancer = PromptEnhancerV2(model_path, device_map=device_map, torch_dtype=torch_dtype)
        return {"enhancer": enhancer, "model_path": model_path, "device_map": device_map, "torch_dtype": torch_dtype}
    return state

def stream_single(prompt, sys_prompt, temperature, max_new_tokens, device,
                  model_path, device_map, torch_dtype, state):
    if not prompt or not str(prompt).strip():
        yield "", "Please enter a prompt first.", state
        return

    t0 = time.time()
    state = ensure_enhancer(state, model_path, device_map, torch_dtype)
    enhancer = state["enhancer"]

    emitted = ""
    try:
        for chunk in enhancer.predict_stream(
            prompt_cot=prompt,
            sys_prompt=sys_prompt,
            temperature=temperature,
            max_new_tokens=max_new_tokens,
            device=device
        ):
            emitted = chunk
            info = f"Received {len(emitted)} characters in {time.time()-t0:.2f}s"
            yield emitted, info, state
        # Emit one final status update when streaming finishes (optional).
        yield emitted, f"Done. Total time: {time.time()-t0:.2f}s", state
    except Exception as e:
        yield "", f"Inference failed: {e}", state


# Example prompts (Chinese and English)
test_list_zh = [
    "第三人称视角,赛车在城市赛道上飞驰,左上角是小地图,地图下面是当前名次,右下角仪表盘显示当前速度。",
    "韩系插画风女生头像,粉紫色短发+透明感腮红,侧光渲染。",
    "点彩派,盛夏海滨,两位渔夫正在搬运木箱,三艘帆船停在岸边,对角线构图。",
    "一幅由梵高绘制的梦境麦田,旋转的蓝色星云与燃烧的向日葵相纠缠。",
]
test_list_en = [
    "Create a painting depicting a 30-year-old white female white-collar worker on a business trip by plane.",
    "Depicted in the anime style of Studio Ghibli, a girl stands quietly at the deck with a gentle smile.",
    "Blue background, a lone girl gazes into the distant sea; her expression is sorrowful.",
    "A blend of expressionist and vintage styles, drawing a building with colorful walls.",
    "Paint a winter scene with crystalline ice hangings from an Antarctic research station.",
]

with gr.Blocks(title="Prompt Enhancer_V2") as demo:
    gr.Markdown("## Prompt Rewriter")
    with gr.Row():
        with gr.Column(scale=2):
            model_path = gr.Textbox(
                label="Model path (local path or Hugging Face repo)",
                value=DEFAULT_MODEL_PATH,
                placeholder="/apdcephfs_jn3/share_302243908/aladdinwang/model_weight/cot_taurus_v6_50/global_step0",
            )
            device_map = gr.Dropdown(
                choices=["auto", "cuda", "cpu"],
                value="auto",
                label="device_map (model loading placement)"
            )
            torch_dtype = gr.Dropdown(
                choices=["bfloat16", "float16", "float32"],
                value="bfloat16",
                label="torch_dtype"
            )

        with gr.Column(scale=3):
            sys_prompt = gr.Textbox(
                label="System prompt (usually no need to change)",
                value="请根据用户的输入,生成思考过程的思维链并改写提示词:",
                lines=3
            )
            with gr.Row():
                temperature = gr.Slider(0, 1, value=0.1, step=0.05, label="Temperature")
                max_new_tokens = gr.Slider(16, 4096, value=2048, step=16, label="Max New Tokens (used by the streaming path; predict() hardcodes 2048)")
                device = gr.Dropdown(choices=["cuda", "cpu"], value="cuda", label="Inference device")

    state = gr.State(value=None)

    with gr.Tab("Inference"):
        with gr.Row():
            with gr.Column(scale=2):
                prompt = gr.Textbox(label="Input prompt", lines=6, placeholder="Paste the prompt to rewrite here...")
                run_btn = gr.Button("Generate rewrite", variant="primary")
                gr.Examples(
                    examples=test_list_zh + test_list_en,
                    inputs=prompt,
                    label="Examples"
                )
            with gr.Column(scale=3):
                out_text = gr.Textbox(label="Rewritten prompt", lines=10)
                out_info = gr.Markdown("Ready.")

        run_btn.click(
            stream_single,
            inputs=[prompt, sys_prompt, temperature, max_new_tokens, device,
                    model_path, device_map, torch_dtype, state],
            outputs=[out_text, out_info, state]
        )

    gr.Markdown(
        "Tip: for any questions, email linqing1995@buaa.edu.cn"
    )

# Limit concurrency to avoid exhausting GPU memory under parallel requests.
# demo.queue(concurrency_count=1, max_size=10)
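# Note (assumption about the installed Gradio version): in Gradio 4.x the
# concurrency_count argument was removed; the equivalent call would be
# demo.queue(default_concurrency_limit=1, max_size=10).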
if __name__ == "__main__":
    # demo.launch(server_name="0.0.0.0", server_port=8080, show_error=True)
    demo.launch(show_error=True)