"""FastVLM Screenshot Explainer: a CPU-only Gradio demo that describes, extracts numbers
from, or answers questions about a fixed set of example images with apple/FastVLM-0.5B."""

import io
import time

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "apple/FastVLM-0.5B"
IMAGE_TOKEN_INDEX = -200  # placeholder id the model's remote code replaces with image features
DEVICE = "cpu"

SAMPLES = {
    "Cats on a couch (COCO)": "http://images.cocodataset.org/val2017/000000039769.jpg",
    "Chart – Blind wine tasting": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/wine_blind_taste.png",
    "Chart – Life expectancy (Africa vs Asia)": "https://huggingface.co/datasets/lytang/ChartMuseum/resolve/main/images/life-expectancy-africa-vs-asia.png",
    "Document page – example": "https://huggingface.co/datasets/hf-internal-testing/example-documents/resolve/main/jpeg_images/1.jpg",
}

TASK_PROMPTS = {
    "Explain": "Describe this image in detail.",
    "Extract numbers": (
        "Extract every number you can see with its label/context. "
        "Return a concise YAML list with fields: value, what_it_refers_to."
    ),
    "Write alt-text": (
        "Write high-quality alt-text (<=200 chars) that would help a blind user understand "
        "the key content and purpose of this image."
    ),
    "Ask a question…": None,  # free-form VQA; the actual question comes from the textbox
}

# FastVLM's multimodal wrapper ships as remote code on the Hub, hence trust_remote_code=True.
tok = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    device_map={"": DEVICE},
    trust_remote_code=True,
)
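
# Optional CPU tuning (a sketch, not part of the original app): on small shared CPU
# instances, pinning torch's intra-op thread count can make float32 latency more
# predictable, e.g.
#
#   import os
#   torch.set_num_threads(max(1, os.cpu_count() or 1))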


def _fetch_image(url: str) -> Image.Image:
    r = requests.get(url, timeout=20)
    r.raise_for_status()
    return Image.open(io.BytesIO(r.content)).convert("RGB")
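

# A cached variant (a sketch, not wired into the UI below): repeated clicks on the same
# sample are likely, so memoizing the download keeps the CPU-only demo responsive. The
# helper name and cache size are assumptions, not taken from the original app.
from functools import lru_cache


@lru_cache(maxsize=8)
def _fetch_image_cached(url: str) -> Image.Image:
    return _fetch_image(url)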


def _build_inputs(prompt: str):
    # Render the chat template around a literal "<image>" marker, then splice the
    # IMAGE_TOKEN_INDEX placeholder into the token ids at that position.
    messages = [{"role": "user", "content": f"<image>\n{prompt}"}]
    rendered = tok.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
    pre, post = rendered.split("<image>", 1)

    pre_ids = tok(pre, return_tensors="pt", add_special_tokens=False).input_ids
    post_ids = tok(post, return_tensors="pt", add_special_tokens=False).input_ids
    img_tok = torch.tensor([[IMAGE_TOKEN_INDEX]], dtype=pre_ids.dtype)

    input_ids = torch.cat([pre_ids, img_tok, post_ids], dim=1).to(model.device)
    attention_mask = torch.ones_like(input_ids, device=model.device)
    return input_ids, attention_mask
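

# Debug helper (a sketch, not used by the UI): shows the exact sequence _build_inputs
# produces, with a visible marker where the -200 image placeholder sits. The helper
# name is an assumption; handy when tweaking TASK_PROMPTS.
def _debug_prompt(prompt: str) -> str:
    input_ids, _ = _build_inputs(prompt)
    ids = input_ids[0].tolist()
    split = ids.index(IMAGE_TOKEN_INDEX)
    return tok.decode(ids[:split]) + "[IMAGE]" + tok.decode(ids[split + 1:])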


def _prepare_pixels(pil_image: Image.Image):
    # Preprocess with the image processor exposed by the model's vision tower (remote-code API).
    px = model.get_vision_tower().image_processor(images=pil_image, return_tensors="pt")["pixel_values"]
    return px.to(model.device, dtype=model.dtype)


@torch.inference_mode()
def run_inference(choice: str, task: str, user_q: str, max_new_tokens: int, temperature: float):
    try:
        img = _fetch_image(SAMPLES[choice])
    except Exception as e:
        return None, f"Could not load image: {e}", ""

    if task == "Ask a question…":
        prompt = user_q.strip() or "Answer questions about this image."
    else:
        prompt = TASK_PROMPTS[task]

    input_ids, attention_mask = _build_inputs(prompt)
    px = _prepare_pixels(img)

    t0 = time.perf_counter()
    out = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        images=px,
        max_new_tokens=int(max_new_tokens),
        # temperature only takes effect when sampling; fall back to greedy decoding at 0.
        do_sample=float(temperature) > 0,
        temperature=float(temperature) if float(temperature) > 0 else 1.0,
    )
    t1 = time.perf_counter()

    text = tok.decode(out[0], skip_special_tokens=True)

    # The LLaVA-style remote code generates from inputs_embeds, so `out` usually contains
    # only the new tokens; fall back to a length difference if the prompt is echoed back.
    gen_len = out.shape[-1]
    if out.shape[-1] > input_ids.shape[-1] and torch.equal(out[0, : input_ids.shape[-1]], input_ids[0]):
        gen_len = out.shape[-1] - input_ids.shape[-1]
    elapsed = t1 - t0
    meta = f"⏱️ {elapsed:.2f}s • new tokens: {gen_len} • ~{(gen_len / elapsed if elapsed > 0 else 0):.1f} tok/s • device: {DEVICE.upper()}"

    return img, text.strip(), meta
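

# Minimal programmatic smoke test (a sketch, not called anywhere in the app): exercises
# the full fetch -> prompt -> generate path without launching the UI. The sample choice
# and generation settings below are illustrative defaults, not from the original code.
def _smoke_test() -> None:
    _img, text, meta = run_inference(
        choice=next(iter(SAMPLES)),
        task="Explain",
        user_q="",
        max_new_tokens=64,
        temperature=0.0,
    )
    print(meta)
    print(text)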


with gr.Blocks(title="FastVLM Screenshot Explainer (CPU)") as demo:
    gr.Markdown(
        """
# ⚡ FastVLM Screenshot Explainer – CPU-only (no uploads)
Click an example image, pick a task, and go.
Model: **apple/FastVLM-0.5B** (research license).
"""
    )

    with gr.Row():
        choice = gr.Dropdown(
            label="Choose example image",
            choices=list(SAMPLES.keys()),
            value=list(SAMPLES.keys())[0],
        )
        task = gr.Radio(
            label="Task",
            choices=list(TASK_PROMPTS.keys()),
            value="Explain",
            info="“Ask a question…” enables free-form VQA.",
        )
    user_q = gr.Textbox(
        label="If asking a question, type it here",
        placeholder="e.g., What is the trend from 1950 to 2000?",
    )
    with gr.Accordion("Generation settings", open=False):
        max_new = gr.Slider(32, 256, value=128, step=8, label="max_new_tokens")
        temp = gr.Slider(0.0, 1.0, value=0.2, step=0.05, label="temperature")

    go = gr.Button("Explain / Answer", variant="primary")
    with gr.Row():
        img_out = gr.Image(label="Image", interactive=False)
        txt_out = gr.Textbox(label="Model output", lines=14)
    meta = gr.Markdown()

    go.click(run_inference, [choice, task, user_q, max_new, temp], [img_out, txt_out, meta])

    gr.Markdown(
        """
**Notes**
- Runs on CPU in float32. To try a GPU, set `DEVICE = "cuda"` and `torch_dtype=torch.float16` at the top of the script.
- Model loading and prompting follow the official model card's `trust_remote_code` API and `<image>` token handling.
- **License:** Apple AML Research License – *research & non-commercial use only*.
"""
    )
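
# Deployment note (an assumption about typical Space traffic, not from the original app):
# with several concurrent users, enabling Gradio's request queue keeps the CPU-bound
# generate() calls from overlapping, e.g. `demo.queue(max_size=16)` before `demo.launch()`.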


if __name__ == "__main__":
    demo.launch()