adarsh8962 committed (verified)
Commit 2d6c974 · 1 Parent(s): f838b85

Update llava/serve/gradio_utils.py

Files changed (1):
  1. llava/serve/gradio_utils.py +52 -23
llava/serve/gradio_utils.py CHANGED
@@ -7,23 +7,24 @@ from llava.model.builder import load_pretrained_model
 from llava.utils import disable_torch_init
 
 
+# ==== memory-safe, de-hallucinating generation helpers ====
 import re
 import torch
 
-# ---------- Stable generation defaults (stop bracket loops) ----------
+# deterministic + anti-repeat defaults
 GEN_KW = dict(
-    do_sample=False,          # deterministic
+    do_sample=False,
     temperature=0.0,
     top_p=1.0,
-    repetition_penalty=1.15,  # break single-token loops like [[[[[
-    no_repeat_ngram_size=3,   # avoid short repeats
-    use_cache=False,          # lower VRAM on L4; fine on L40S too
+    repetition_penalty=1.15,   # breaks [[[ spam
+    no_repeat_ngram_size=3,    # avoids short loops
+    use_cache=False,           # reduces VRAM spikes on L4
 )
 
 def _big_gpu():
     try:
         return (torch.cuda.is_available()
-                and torch.cuda.get_device_properties(0).total_memory / 1024**3 >= 40)
+                and torch.cuda.get_device_properties(0).total_memory / 1024**3 >= 40)  # >=40GB = L40S/A100
     except Exception:
         return False

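For reference, these defaults map one-to-one onto standard Hugging Face generate() keyword arguments, so they can be exercised outside the Space. A minimal sketch, assuming any small transformers causal LM; the checkpoint name below is only a placeholder and is not part of this repo:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    GEN_KW = dict(do_sample=False, temperature=0.0, top_p=1.0,
                  repetition_penalty=1.15, no_repeat_ngram_size=3, use_cache=False)

    tok = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")          # placeholder checkpoint
    lm = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
    ids = tok("Frame 1:", return_tensors="pt").input_ids
    out = lm.generate(ids, max_new_tokens=32, **GEN_KW)                 # same kwargs as GEN_KW above
    print(tok.decode(out[0], skip_special_tokens=True))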
 
@@ -41,21 +42,22 @@ def build_framewise_prompt(T: int) -> str:
     )
 
 def keep_frame_lines(text: str, T: int) -> str:
-    """Keep only 'Frame i: ...' lines; ensure frames 1..T exist."""
+    """Keep only `Frame i: ...` lines; ensure frames 1..T exist."""
     lines = []
     for ln in text.splitlines():
-        m = re.match(r"^Frame\s+(\d+)\s*:\s*(.+)$", ln.strip())
+        m = re.match(r"^Frame\s+(\d+)\s*:\s*(.+)$", ln.strip())
         if not m:
             continue
         i = int(m.group(1))
-        body = " ".join(m.group(2).split()[:10])  # ≤10 words
+        body = " ".join(m.group(2).split()[:10])  # ≤10 words
         if 1 <= i <= T:
-            lines.append((i, f"Frame {i}: {body}"))
+            lines.append((i, f"Frame {i}: {body}"))
     have = {i for i,_ in lines}
     for i in range(1, T+1):
         if i not in have:
-            lines.append((i, f"Frame {i}: (no description)"))
-    return "\n".join(t for _, t in sorted(lines))
+            lines.append((i, f"Frame {i}: (no description)"))  # never leaves gaps
+    return "\n".join(t for _, t in sorted(lines))
+# ==== end helpers ====
 
 
 title_markdown = ("""
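A quick usage sketch of the new post-processor (assumes the updated module is importable from this repo; the sample text is made up): it keeps only well-formed frame lines, trims each to at most ten words, fills any missing index with a placeholder, and returns the lines sorted by frame number.

    from llava.serve.gradio_utils import keep_frame_lines

    raw = (
        "Sure! Here is what I see.\n"
        "Frame 2: a man opens a red car door\n"
        "Frame 1: a red car parked on the street\n"
        "random trailing chatter"
    )
    print(keep_frame_lines(raw, T=3))
    # Frame 1: a red car parked on the street
    # Frame 2: a man opens a red car door
    # Frame 3: (no description)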
@@ -168,26 +170,53 @@ class Chat:
         # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
         # print(input_ids, images_tensor[0][0].shape)
         with torch.inference_mode():
+            # infer how many frames actually went in (works for list-of-frames or tensors)
+            def _infer_T(imgs):
+                try:
+                    if isinstance(imgs, (list, tuple)) and len(imgs) > 0:
+                        first = imgs[0]
+                        if isinstance(first, (list, tuple)):
+                            return len(first)
+                        if hasattr(first, "shape"):
+                            shp = list(first.shape)
+                            if len(shp) >= 4:  # [T, C, H, W] or [1, T, C, H, W]
+                                return int(shp[0])
+                except Exception:
+                    pass
+                return 8  # safe default
+
+            _T = _infer_T(images_tensor)
+
+            # VRAM-aware cap: more frames → allow a few more tokens, but stay safe on L4
+            max_new_tokens = min(16 * max(1, _T), MAX_NEW_TOKENS_BIG if _big_gpu() else MAX_NEW_TOKENS_SMALL)
+
             output_ids = model.generate(
                 input_ids,
                 images=images_tensor,
-                do_sample=True,
-                temperature=temperature,
                 max_new_tokens=max_new_tokens,
-                # streamer=streamer,
-                use_cache=True,
-                stopping_criteria=[stopping_criteria])
-
+                **GEN_KW,  # <- deterministic + lower VRAM
+                stopping_criteria=[stopping_criteria],
+            )
+
+
         input_token_len = input_ids.shape[1]
         n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
         if n_diff_input_output > 0:
             print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
         outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
         outputs = outputs.strip()
-        if outputs.endswith(stop_str):
-            outputs = outputs[:-len(stop_str)]
-            outputs = outputs.strip()
-
-        print('response', outputs)
+        # If user asked about frames, force a clean "Frame i: ..." list
+        try:
+            _T = _infer_T(images_tensor)
+        except Exception:
+            _T = 8
+        if "frame" in prompt.lower():
+            cleaned = keep_frame_lines(outputs, _T)
+            if cleaned.strip():
+                outputs = cleaned
+
+        print("response", outputs)
         return outputs, state
+
+
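The frame-count inference and the token cap added above can be checked in isolation. A standalone sketch of the same logic; MAX_NEW_TOKENS_BIG and MAX_NEW_TOKENS_SMALL are defined elsewhere in gradio_utils.py, so the values below are placeholders for illustration only:

    import torch

    MAX_NEW_TOKENS_SMALL = 256   # placeholder value, not from this diff
    MAX_NEW_TOKENS_BIG = 512     # placeholder value, not from this diff

    def infer_t(imgs, default=8):
        # mirrors _infer_T: list of per-video tensors shaped [T, C, H, W]
        try:
            if isinstance(imgs, (list, tuple)) and imgs:
                first = imgs[0]
                if isinstance(first, (list, tuple)):
                    return len(first)
                if hasattr(first, "shape") and len(first.shape) >= 4:
                    return int(first.shape[0])
        except Exception:
            pass
        return default

    frames = [torch.zeros(8, 3, 224, 224)]           # one video, 8 frames
    t = infer_t(frames)
    cap = min(16 * max(1, t), MAX_NEW_TOKENS_BIG)    # ~16 new tokens per frame, hard-capped
    print(t, cap)                                    # 8 128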
 
 