IFMedTechdemo committed on
Commit da048ad · verified · 1 Parent(s): c0fa0f9

Update app.py

Files changed (1)
  1. app.py +83 -17
app.py CHANGED
@@ -1,4 +1,3 @@
-
 import os
 import time
 import torch
@@ -18,13 +17,18 @@ from transformers import (
     Qwen2_5_VLForConditionalGeneration,
     TextIteratorStreamer
 )
+from huggingface_hub import snapshot_download
 from qwen_vl_utils import process_vision_info


+
+
 # Suppress the warning about uninitialized weights
 warnings.filterwarnings('ignore', message='Some weights.*were not initialized')


+
+
 # Try importing Qwen3VL if available
 try:
     from transformers import Qwen3VLForConditionalGeneration
@@ -32,18 +36,27 @@ except ImportError:
     Qwen3VLForConditionalGeneration = None


+
+
 MAX_MAX_NEW_TOKENS = 4096
 DEFAULT_MAX_NEW_TOKENS = 2048
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
+CACHE_DIR = os.getenv("HF_CACHE_DIR", "./models")
+
+


 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


+
+
 print(f"Initial Device: {device}")
 print(f"CUDA Available: {torch.cuda.is_available()}")


+
+
 # Load Chandra-OCR
 try:
     MODEL_ID_V = "datalab-to/chandra"
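Note: CACHE_DIR is new in this commit and feeds the Dots.OCR download below. MAX_INPUT_TOKEN_LENGTH, by contrast, is defined but never applied in the hunks shown; a guard along these lines could enforce it before generation (a hypothetical helper, not part of this commit):

```python
# Hypothetical guard, not part of this commit: keep only the last
# MAX_INPUT_TOKEN_LENGTH tokens of an over-long prompt before generate().
def trim_inputs(inputs, max_len=MAX_INPUT_TOKEN_LENGTH):
    if inputs["input_ids"].shape[-1] > max_len:
        inputs["input_ids"] = inputs["input_ids"][:, -max_len:]
        inputs["attention_mask"] = inputs["attention_mask"][:, -max_len:]
    return inputs
```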
@@ -52,7 +65,8 @@ try:
         model_v = Qwen3VLForConditionalGeneration.from_pretrained(
             MODEL_ID_V,
             trust_remote_code=True,
-            torch_dtype=torch.float16
+            torch_dtype=torch.float16,
+            device_map="auto"
         ).eval()
         print("✓ Chandra-OCR loaded")
     else:
@@ -64,6 +78,8 @@ except Exception as e:
     print(f"✗ Chandra-OCR: Failed to load - {str(e)}")


+
+
 # Load Nanonets-OCR2-3B
 try:
     MODEL_ID_X = "nanonets/Nanonets-OCR2-3B"
@@ -71,7 +87,8 @@ try:
     model_x = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         MODEL_ID_X,
         trust_remote_code=True,
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        device_map="auto"
     ).eval()
     print("✓ Nanonets-OCR2-3B loaded")
 except Exception as e:
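Note: every loader in this commit swaps a bare torch_dtype=torch.float16 for the torch_dtype + device_map="auto" pair. device_map="auto" requires the accelerate package and lets from_pretrained() place (and, if necessary, shard) the fp16 weights across the available GPU/CPU at load time, which is why the explicit .to(device) calls disappear later in the diff. A minimal standalone sketch of the pattern, reusing the Nanonets checkpoint from above:

```python
import torch
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

MODEL_ID = "nanonets/Nanonets-OCR2-3B"

# device_map="auto" needs `pip install accelerate`; weights are placed on
# the available device(s) inside from_pretrained(), so no later .to() call.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    device_map="auto",
).eval()
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
```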
@@ -80,14 +97,31 @@ except Exception as e:
     print(f"✗ Nanonets-OCR2-3B: Failed to load - {str(e)}")


-# Load Dots.OCR - will be moved to GPU when needed
+
+
+# Load Dots.OCR - UPDATED with snapshot_download and device_map="auto"
 try:
-    MODEL_PATH_D = "strangervisionhf/dots.ocr-base-fix"
-    processor_d = AutoProcessor.from_pretrained(MODEL_PATH_D, trust_remote_code=True)
+    MODEL_ID_D = "rednote-hilab/dots.ocr"
+    model_path_d = os.path.join(CACHE_DIR, "dots-ocr-local")
+
+    # Download and cache model locally
+    snapshot_download(
+        repo_id=MODEL_ID_D,
+        local_dir=model_path_d,
+        local_dir_use_symlinks=False,  # Avoid symlink issues on HF Spaces
+        allow_patterns=["*.json", "*.bin", "*.safetensors", "*.txt"]
+    )
+
+    processor_d = AutoProcessor.from_pretrained(
+        model_path_d,
+        trust_remote_code=True
+    )
+
     model_d = AutoModelForCausalLM.from_pretrained(
-        MODEL_PATH_D,
+        model_path_d,
         attn_implementation="flash_attention_2",
         torch_dtype=torch.bfloat16,
+        device_map="auto",  # Better memory management
         trust_remote_code=True
     ).eval()
     print("✓ Dots.OCR loaded")
@@ -95,6 +129,10 @@ except Exception as e:
     model_d = None
     processor_d = None
     print(f"✗ Dots.OCR: Failed to load - {str(e)}")
+    import traceback
+    traceback.print_exc()
+
+


 # Load olmOCR-2-7B-1025
@@ -104,7 +142,8 @@ try:
     model_m = Qwen2_5_VLForConditionalGeneration.from_pretrained(
         MODEL_ID_M,
         trust_remote_code=True,
-        torch_dtype=torch.float16
+        torch_dtype=torch.float16,
+        device_map="auto"
     ).eval()
     print("✓ olmOCR-2-7B-1025 loaded")
 except Exception as e:
@@ -113,6 +152,8 @@ except Exception as e:
     print(f"✗ olmOCR-2-7B-1025: Failed to load - {str(e)}")


+
+
 @spaces.GPU
 def generate_image(model_name: str, text: str, image: Image.Image,
                    max_new_tokens: int, temperature: float, top_p: float,
@@ -120,10 +161,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
     """
     Generates responses using the selected model for image input.
     Yields raw text and Markdown-formatted text.
-
     This function is decorated with @spaces.GPU to ensure it runs on GPU
     when available in Hugging Face Spaces.
-
     Args:
         model_name: Name of the OCR model to use
         text: Prompt text for the model
@@ -133,48 +172,52 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         top_p: Nucleus sampling parameter
         top_k: Top-k sampling parameter
         repetition_penalty: Penalty for repeating tokens
-
     Yields:
         tuple: (raw_text, markdown_text)
     """
     # Device will be cuda when @spaces.GPU decorator activates
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

+
     # Select model and processor based on model_name
     if model_name == "olmOCR-2-7B-1025":
         if model_m is None:
             yield "olmOCR-2-7B-1025 is not available.", "olmOCR-2-7B-1025 is not available."
             return
         processor = processor_m
-        model = model_m.to(device)
+        model = model_m
     elif model_name == "Nanonets-OCR2-3B":
         if model_x is None:
             yield "Nanonets-OCR2-3B is not available.", "Nanonets-OCR2-3B is not available."
             return
         processor = processor_x
-        model = model_x.to(device)
+        model = model_x
     elif model_name == "Chandra-OCR":
         if model_v is None:
             yield "Chandra-OCR is not available.", "Chandra-OCR is not available."
             return
         processor = processor_v
-        model = model_v.to(device)
+        model = model_v
     elif model_name == "Dots.OCR":
         if model_d is None:
             yield "Dots.OCR is not available.", "Dots.OCR is not available."
             return
         processor = processor_d
-        model = model_d.to(device)
+        model = model_d
     else:
         yield "Invalid model selected.", "Invalid model selected."
         return


+
+
     if image is None:
         yield "Please upload an image.", "Please upload an image."
         return


+
+
     try:
         # Prepare messages in chat format
         messages = [{
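Note: the branches above now bind the module-level models directly instead of calling .to(device). This is the required counterpart of device_map="auto": an Accelerate-dispatched model is already placed (possibly across several devices), and recent transformers versions raise an error if it is moved again with .to(). The same dispatch could be written as a lookup table; a sketch assuming the (model, processor) globals from the loading section (the helper itself is illustrative, not part of this commit):

```python
# Illustrative helper, not part of this commit.
MODEL_REGISTRY = {
    "olmOCR-2-7B-1025": (model_m, processor_m),
    "Nanonets-OCR2-3B": (model_x, processor_x),
    "Chandra-OCR":      (model_v, processor_v),
    "Dots.OCR":         (model_d, processor_d),
}

def resolve_model(name):
    model, processor = MODEL_REGISTRY.get(name, (None, None))
    if model is None:
        return None, None  # caller yields the "... is not available." message
    # No .to(device): device_map="auto" already placed the weights.
    return model, processor
```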
@@ -185,6 +228,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             ]
         }]

+
         # Apply chat template with fallback
         try:
             prompt_full = processor.apply_chat_template(
@@ -198,6 +242,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             prompt_full = f"{text}"


+
+
         # Process inputs
         inputs = processor(
             text=[prompt_full],
@@ -207,6 +253,8 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         ).to(device)


+
+
         # Setup streaming generation
         streamer = TextIteratorStreamer(
             processor.tokenizer if hasattr(processor, 'tokenizer') else processor,
@@ -214,6 +262,7 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             skip_special_tokens=True
         )

+
         generation_kwargs = {
             **inputs,
             "streamer": streamer,
@@ -225,10 +274,12 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             "repetition_penalty": repetition_penalty,
         }

+
         # Start generation in separate thread
         thread = Thread(target=model.generate, kwargs=generation_kwargs)
         thread.start()

+
         # Stream the results
         buffer = ""
         for new_text in streamer:
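Note: generation itself is the standard transformers streaming pattern: model.generate() blocks, so it runs on a worker Thread and feeds a TextIteratorStreamer that the function body drains as an iterator, yielding partial output to Gradio as it arrives. Condensed to its essentials (assuming model, processor, and inputs as prepared above; skip_prompt keeps the echoed prompt out of the stream):

```python
from threading import Thread
from transformers import TextIteratorStreamer

tokenizer = processor.tokenizer if hasattr(processor, "tokenizer") else processor
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# generate() runs in the worker thread; iterating the streamer in the main
# thread yields decoded text chunks until generation finishes.
thread = Thread(target=model.generate,
                kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256})
thread.start()

buffer = ""
for piece in streamer:
    buffer += piece   # yield buffer to the UI here
thread.join()
```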
@@ -237,9 +288,11 @@ def generate_image(model_name: str, text: str, image: Image.Image,
             time.sleep(0.01)
             yield buffer, buffer

+
         # Ensure thread completes
         thread.join()

+
     except Exception as e:
         error_msg = f"Error during generation: {str(e)}"
         print(f"Full error: {e}")
@@ -248,10 +301,13 @@ def generate_image(model_name: str, text: str, image: Image.Image,
         yield error_msg, error_msg


+
+
 # Example usage for Gradio interface
 if __name__ == "__main__":
     import gradio as gr

+
     # Determine available models
     available_models = []
     if model_m is not None:
@@ -267,16 +323,20 @@ if __name__ == "__main__":
         available_models.append("Dots.OCR")
         print(" Added: Dots.OCR")

+
     if not available_models:
         print("ERROR: No models were loaded successfully!")
         exit(1)

+
     print(f"\n✓ Available models for dropdown: {', '.join(available_models)}")

+
     with gr.Blocks(title="Multi-Model OCR") as demo:
         gr.Markdown("# 🔍 Multi-Model OCR Application")
         gr.Markdown("Upload an image and select a model to extract text. Models run on GPU via Hugging Face Spaces.")

+
         with gr.Row():
             with gr.Column():
                 model_selector = gr.Dropdown(
@@ -291,6 +351,7 @@ if __name__ == "__main__":
                     lines=2
                 )

+
                 with gr.Accordion("Advanced Settings", open=False):
                     max_tokens = gr.Slider(
                         minimum=1,
@@ -328,20 +389,24 @@ if __name__ == "__main__":
                         label="Repetition Penalty"
                     )

+
                 submit_btn = gr.Button("Extract Text", variant="primary")

+
             with gr.Column():
                 output_text = gr.Textbox(label="Extracted Text", lines=20)
                 output_markdown = gr.Markdown(label="Formatted Output")

+
         gr.Markdown("""
         ### Available Models:
         - **olmOCR-2-7B-1025**: Allen AI's OCR model
         - **Nanonets-OCR2-3B**: Nanonets OCR model
         - **Chandra-OCR**: Datalab OCR model
-        - **Dots.OCR**: Stranger Vision OCR model
+        - **Dots.OCR**: Stranger Vision OCR model (Updated)
         """)

+
         submit_btn.click(
             fn=generate_image,
             inputs=[
@@ -357,5 +422,6 @@ if __name__ == "__main__":
             outputs=[output_text, output_markdown]
         )

+
     # Launch with share=True for Hugging Face Spaces
-    demo.launch(share=True)
+    demo.launch(share=True)
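Note: generate_image is a generator, and Gradio streams a generator callback automatically: each yielded (raw, markdown) pair updates the two bound outputs, so no extra streaming setup is needed beyond yielding. A minimal self-contained version of the wiring (hypothetical demo, not from this commit); also, on Hugging Face Spaces share=True is generally unnecessary, since Spaces already serves a public URL, and recent Gradio versions ignore it there with a warning:

```python
import gradio as gr

def stream_demo(prompt):
    # Generator callback: every yield pushes an update to both outputs.
    buffer = ""
    for word in prompt.split():
        buffer += word + " "
        yield buffer, buffer

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Prompt")
    out_text = gr.Textbox(label="Raw")
    out_md = gr.Markdown(label="Formatted")
    gr.Button("Go").click(fn=stream_demo, inputs=[inp], outputs=[out_text, out_md])

demo.launch()
```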
 
 