jiuhai committed
Commit e1e0641 · verified · 1 Parent(s): 6b4452f

Update app.py

Files changed (1):
  1. app.py +115 -318
app.py CHANGED
@@ -1,326 +1,123 @@
- import gradio as gr
- import os
- import torch
- import spaces
- 
- from llava import conversation as conversation_lib
- from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
- from llava.conversation import conv_templates, SeparatorStyle
- from llava.model.builder import load_pretrained_model
- from llava.utils import disable_torch_init
- from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, process_images
 
  from PIL import Image
- import argparse
- 
- from transformers import TextIteratorStreamer
- from threading import Thread
- 
- import subprocess
- # Install flash attention, skipping CUDA build if necessary
- subprocess.run(
-     "pip install flash-attn --no-build-isolation",
-     env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
-     shell=True,
- )
- 
- 
- # os.environ['GRADIO_TEMP_DIR'] = './gradio_tmp'
- no_change_btn = gr.Button()
- enable_btn = gr.Button(interactive=True)
- disable_btn = gr.Button(interactive=False)
- 
- argparser = argparse.ArgumentParser()
- argparser.add_argument("--server_name", default="0.0.0.0", type=str)
- argparser.add_argument("--port", default="6324", type=str)
- argparser.add_argument("--model-path", default="jiuhai/florence-phi-ms", type=str)
- argparser.add_argument("--model-base", type=str, default=None)
- argparser.add_argument("--num-gpus", type=int, default=1)
- argparser.add_argument("--conv-mode", type=str, default="llama3")
- argparser.add_argument("--temperature", type=float, default=0.2)
- argparser.add_argument("--max-new-tokens", type=int, default=512)
- argparser.add_argument("--num_frames", type=int, default=16)
- argparser.add_argument("--load-8bit", action="store_true")
- argparser.add_argument("--load-4bit", action="store_true")
- argparser.add_argument("--debug", action="store_true")
- 
- args = argparser.parse_args()
- model_path = args.model_path
- conv_mode = args.conv_mode
- filt_invalid = "cut"
- model_name = get_model_name_from_path(args.model_path)
- model_kwargs = {
-     "trust_remote_code": True,
-     "torch_dtype": torch.bfloat16,
-     "attn_implementation": "eager"
- }
- tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, device_map="cuda:0", **model_kwargs)
- our_chatbot = None
- 
- def upvote_last_response(state):
-     return ("",) + (disable_btn,) * 3
- 
- 
- def downvote_last_response(state):
-     return ("",) + (disable_btn,) * 3
- 
- 
- def flag_last_response(state):
-     return ("",) + (disable_btn,) * 3
- 
- def clear_history():
-     state = conv_templates[conv_mode].copy()
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
- 
- def add_text(state, imagebox, textbox, image_process_mode):
-     if state is None:
-         state = conv_templates[conv_mode].copy()
- 
-     if imagebox is not None:
-         textbox = DEFAULT_IMAGE_TOKEN + '\n' + textbox
-         image = Image.open(imagebox).convert('RGB')
- 
-     if imagebox is not None:
-         textbox = (textbox, image, image_process_mode)
- 
-     state.append_message(state.roles[0], textbox)
-     state.append_message(state.roles[1], None)
- 
-     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
- 
- def delete_text(state, image_process_mode):
-     state.messages[-1][-1] = None
-     prev_human_msg = state.messages[-2]
-     if type(prev_human_msg[1]) in (tuple, list):
-         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-     yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
- 
- def regenerate(state, image_process_mode):
-     state.messages[-1][-1] = None
-     prev_human_msg = state.messages[-2]
-     if type(prev_human_msg[1]) in (tuple, list):
-         prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
-     state.skip_next = False
-     return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
- 
- @spaces.GPU
- def generate(state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens):
-     prompt = state.get_prompt()
-     images = state.get_images(return_pil=True)
-     # prompt, image_args = process_image(prompt, images)
- 
-     ori_prompt = prompt
-     num_image_tokens = 0
- 
-     if images is not None and len(images) > 0:
-         if len(images) > 0:
-             if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
-                 raise ValueError("Number of images does not match number of <image> tokens in prompt")
- 
-             # images = [load_image_from_base64(image) for image in images]
-             image_sizes = [image.size for image in images]
-             images = process_images(images, image_processor, model.config)
- 
-             if type(images) is list:
-                 images = [image.to(model.device, dtype=torch.float16) for image in images]
-             else:
-                 images = images.to(model.device, dtype=torch.float16)
-         else:
-             images = None
-             image_sizes = None
-         image_args = {"images": images, "image_sizes": image_sizes}
-     else:
-         images = None
-         image_args = {}
- 
-     max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
-     max_new_tokens = 512
-     do_sample = True if temperature > 0.001 else False
-     stop_str = state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2
- 
-     input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
-     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=15)
- 
-     max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)
- 
-     if max_new_tokens < 1:
-         # yield json.dumps({"text": ori_prompt + "Exceeds max token length. Please start a new conversation, thanks.", "error_code": 0}).encode() + b"\0"
-         return
- 
-     thread = Thread(target=model.generate, kwargs=dict(
-         inputs=input_ids,
-         do_sample=do_sample,
-         temperature=temperature,
-         top_p=top_p,
-         max_new_tokens=max_new_tokens,
-         streamer=streamer,
-         use_cache=True,
-         pad_token_id=tokenizer.eos_token_id,
-         eos_token_id=[32007],
-         **image_args
-     ))
-     thread.start()
-     generated_text = ''
-     for new_text in streamer:
-         new_text = new_text.replace('<|end|>', "")
-         generated_text += new_text
-         if generated_text.endswith(stop_str):
-             generated_text = generated_text[:-len(stop_str)]
-         state.messages[-1][-1] = generated_text
-         yield (state, state.to_gradio_chatbot(), "", None) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
- 
-     yield (state, state.to_gradio_chatbot(), "", None) + (enable_btn,) * 5
- 
-     torch.cuda.empty_cache()
- 
- txt = gr.Textbox(
-     scale=4,
-     show_label=False,
-     placeholder="Enter text and press enter.",
-     container=False,
- )
- 
- 
- # title_markdown = ("""
- # # llava: Exploring The Design Space for Multimodal LLMs with Mixture of Encoders
- # [[Code](https://github.com/NVlabs/llava)] [[Model](https://huggingface.co/NVllava)] | 📚 [[Arxiv](https://arxiv.org/pdf/2408.15998)]
- # """)
- 
- title_markdown = ("""
- # Florence-phi
- """)
- 
- tos_markdown = ("""
- ### Terms of use
- By using this service, users are required to agree to the following terms:
- The service is a research preview intended for non-commercial use only. It provides only limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
- Please click the "Flag" button if you get an inappropriate answer! We will collect those to keep improving our moderator.
- For an optimal experience, please use a desktop computer for this demo, as mobile devices may compromise its quality.
- """)
- 
- 
- learn_more_markdown = ("""
- ### License
- The service is a research preview intended for non-commercial use only. Please contact us if you find any potential violation.
- """)
- 
- block_css = """
- #buttons button {
-     min-width: min(120px, 100%);
- }
- """
- 
- textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
- with gr.Blocks(title="llava", theme=gr.themes.Default(), css=block_css) as demo:
-     state = gr.State()
- 
-     gr.Markdown(title_markdown)
 
      with gr.Row():
          with gr.Column(scale=3):
-             imagebox = gr.Image(label="Input Image", type="filepath")
-             image_process_mode = gr.Radio(
-                 ["Crop", "Resize", "Pad", "Default"],
-                 value="Default",
-                 label="Preprocess for non-square image", visible=False)
- 
-             cur_dir = os.path.dirname(os.path.abspath(__file__))
-             # gr.Examples(examples=[
-             #     [f"{cur_dir}/assets/health-insurance.png", "Under which circumstances do I need to be enrolled in mandatory health insurance if I am an international student?"],
-             #     [f"{cur_dir}/assets/leasing-apartment.png", "I don't have any 3rd party renter's insurance now. Do I need to get one for myself?"],
-             #     [f"{cur_dir}/assets/nvidia.jpeg", "Who is the person in the middle?"],
-             #     [f"{cur_dir}/assets/animal-compare.png", "Are these two pictures showing the same kind of animal?"],
-             #     [f"{cur_dir}/assets/georgia-tech.jpeg", "Where is this photo taken?"]
-             # ], inputs=[imagebox, textbox], cache_examples=False)
- 
-             gr.Examples(examples=[
-                 [f"{cur_dir}/assets/animal-compare.png", "Are these two pictures showing the same kind of animal?"]
-             ], inputs=[imagebox, textbox], cache_examples=False)
- 
-             with gr.Accordion("Parameters", open=False) as parameter_row:
-                 temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature")
-                 top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P")
-                 max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens")
- 
-         with gr.Column(scale=8):
-             chatbot = gr.Chatbot(
-                 elem_id="chatbot",
-                 label="llava Chatbot",
-                 height=650,
-                 layout="panel",
              )
-         with gr.Row():
-             with gr.Column(scale=8):
-                 textbox.render()
-             with gr.Column(scale=1, min_width=50):
-                 submit_btn = gr.Button(value="Send", variant="primary")
-         with gr.Row(elem_id="buttons") as button_row:
-             upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
-             downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
-             flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
-             # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=False)
-             regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
-             clear_btn = gr.Button(value="🗑️ Clear", interactive=False)
- 
-     gr.Markdown(tos_markdown)
-     gr.Markdown(learn_more_markdown)
-     url_params = gr.JSON(visible=False)
- 
-     # Register listeners
-     btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
-     upvote_btn.click(
-         upvote_last_response,
-         [state],
-         [textbox, upvote_btn, downvote_btn, flag_btn]
-     )
-     downvote_btn.click(
-         downvote_last_response,
-         [state],
-         [textbox, upvote_btn, downvote_btn, flag_btn]
-     )
-     flag_btn.click(
-         flag_last_response,
-         [state],
-         [textbox, upvote_btn, downvote_btn, flag_btn]
-     )
- 
-     clear_btn.click(
-         clear_history,
-         None,
-         [state, chatbot, textbox, imagebox] + btn_list,
-         queue=False
-     )
- 
-     regenerate_btn.click(
-         delete_text,
-         [state, image_process_mode],
-         [state, chatbot, textbox, imagebox] + btn_list,
-     ).then(
-         generate,
-         [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-         [state, chatbot, textbox, imagebox] + btn_list,
-     )
-     textbox.submit(
-         add_text,
-         [state, imagebox, textbox, image_process_mode],
-         [state, chatbot, textbox, imagebox] + btn_list,
-     ).then(
-         generate,
-         [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-         [state, chatbot, textbox, imagebox] + btn_list,
-     )
- 
-     submit_btn.click(
-         add_text,
-         [state, imagebox, textbox, image_process_mode],
-         [state, chatbot, textbox, imagebox] + btn_list,
-     ).then(
-         generate,
-         [state, imagebox, textbox, image_process_mode, temperature, top_p, max_output_tokens],
-         [state, chatbot, textbox, imagebox] + btn_list,
      )
 
- demo.queue(
-     status_update_rate=10,
-     api_open=False
- ).launch()
 
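Aside: the removed chat app streamed partial replies by running `model.generate` on a worker thread and draining a `TextIteratorStreamer`, as in the `generate` function above. A minimal, self-contained sketch of that pattern (the `gpt2` checkpoint here is only a placeholder, not the demo's model):

```python
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

tok = AutoTokenizer.from_pretrained("gpt2")          # placeholder model
lm = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("The quick brown fox", return_tensors="pt")
streamer = TextIteratorStreamer(tok, skip_prompt=True, skip_special_tokens=True)

# generate() blocks, so it runs on a worker thread while the main thread streams
Thread(target=lm.generate, kwargs=dict(**inputs, max_new_tokens=32, streamer=streamer)).start()

for chunk in streamer:  # yields decoded text fragments as they are produced
    print(chunk, end="", flush=True)
```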
+ # gradio_blip3o_next_min.py
+ import time
+ from dataclasses import dataclass
+ 
+ import torch
  from PIL import Image
+ from transformers import AutoTokenizer
+ from blip3o.model import *
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ 
+ 
+ # -----------------------------
+ # Minimal config and runner
+ # -----------------------------
+ @dataclass
+ class T2IConfig:
+     device: str = "cuda:0"
+     dtype: torch.dtype = torch.bfloat16
+     # fixed generation config (no UI controls)
+     scale: int = 0
+     seq_len: int = 729
+     top_p: float = 0.95
+     top_k: int = 1200
+ 
+ 
+ class TextToImageInference:
+     def __init__(self, config: T2IConfig):
+         self.config = config
+         self.device = torch.device(config.device)
+         self._load_models()
+ 
+     def _load_models(self):
+         # Download the checkpoint from the Hub (or reuse a cached copy)
+         model_path = snapshot_download(repo_id='BLIP3o/BLIP3o-NEXT-GRPO-Geneval-3B')
+         self.model = blip3oQwenForInferenceLM.from_pretrained(
+             model_path, torch_dtype=self.config.dtype
+         ).to(self.device)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         if hasattr(self.tokenizer, "padding_side"):
+             # left-pad so generation continues directly from the prompt
+             self.tokenizer.padding_side = "left"
+ 
+     @torch.inference_mode()
+     def generate_image(self, prompt: str) -> Image.Image:
+         messages = [
+             {"role": "system", "content": "You are a helpful assistant."},
+             {
+                 "role": "user",
+                 "content": f"Please generate image based on the following caption: {prompt}",
+             },
+         ]
+         input_text = self.tokenizer.apply_chat_template(
+             messages, tokenize=False, add_generation_prompt=True
+         )
+         # Append the image-start token and the fixed scale token
+         input_text += f"<im_start><S{self.config.scale}>"
+ 
+         inputs = self.tokenizer(
+             [input_text], return_tensors="pt", padding=True, truncation=True
+         )
+ 
+         # generate_images returns a pair; only the decoded PIL images are kept
+         _, images = self.model.generate_images(
+             inputs.input_ids.to(self.device),
+             inputs.attention_mask.to(self.device),
+             max_new_tokens=self.config.seq_len,
+             do_sample=True,
+             top_p=self.config.top_p,
+             top_k=self.config.top_k,
+         )
+         return images[0]
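+ 
+ # Hypothetical direct use (no UI), assuming the checkpoint fits on the GPU:
+ #   runner = TextToImageInference(T2IConfig())
+ #   runner.generate_image("a red cube on a blue sphere").save("sample.png")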
+ 
+ 
+ # Try loading once at startup for simplicity
+ LOAD_ERROR = None
+ inference = None
+ try:
+     inference = TextToImageInference(T2IConfig())
+ except Exception as e:
+     LOAD_ERROR = f"❌ Failed to load model: {e}"
+ 
+ 
+ def run_generate(prompt, progress=gr.Progress(track_tqdm=True)):
+     t0 = time.time()
+     if LOAD_ERROR:
+         return None, LOAD_ERROR
+     if not prompt or not prompt.strip():
+         return None, "⚠️ Please enter a prompt."
+ 
+     try:
+         img = inference.generate_image(prompt.strip())
+         return img, f"✅ Done in {time.time() - t0:.2f}s."
+     except torch.cuda.OutOfMemoryError:
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+         return None, "❌ CUDA OOM. Try reducing other GPU workloads."
+     except Exception as e:
+         return None, f"❌ Error: {e}"
+ 
+ 
+ with gr.Blocks(title="BLIP3o-NEXT-GRPO-Geneval — Text ➜ Image") as demo:
+     gr.Markdown("# BLIP3o-NEXT-GRPO-Geneval — Text ➜ Image")
 
      with gr.Row():
          with gr.Column(scale=3):
+             prompt = gr.Textbox(
+                 label="Prompt",
+                 placeholder="Describe the image you want to generate...",
+                 lines=4,
              )
+             run_btn = gr.Button("Generate", variant="primary")
+ 
+         with gr.Column(scale=4):
+             out_img = gr.Image(label="Generated Image", format="png")
+             status = gr.Markdown("")
+ 
+     run_btn.click(
+         fn=run_generate,
+         inputs=[prompt],
+         outputs=[out_img, status],
+         queue=True,
+         api_name="generate",
      )
 
+ if __name__ == "__main__":
+     demo.queue().launch(share=True)
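Because the click handler registers `api_name="generate"`, the new app also exposes generation as a programmatic endpoint once it is running. A minimal client sketch, assuming the default local URL (a Space name or the `share=True` link should work the same way):

```python
# pip install gradio_client
from gradio_client import Client

client = Client("http://127.0.0.1:7860/")  # assumed address; use the URL printed at launch

# predict returns one value per output component: (image filepath, status markdown)
image_path, status = client.predict("a red cube on a blue sphere", api_name="/generate")
print(status, image_path)
```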