---
language: en
license: mit
tags:
- text-classification
- nlp
- transformers
- bert
- routing
- vision-task-classifier
model_name: ICM
base_model: bert-base-uncased
pipeline_tag: text-classification
datasets:
- synthetic
tasks:
- text-classification
library_name: transformers
---

# Task Classification Model (ICM)

## Model Description

A BERT-based sequence classification model that routes computer vision questions to the appropriate specialized module. It assigns each question to one of four task categories: VQA, Captioning, Grounding, or Geometry.

- **Repository:** beingamanforever/ICM
- **Base Model:** bert-base-uncased
- **Task:** 4-way sequence classification

## Labels

| ID | Label | Description |
|---|---|---|
| 0 | vqa | Visual Question Answering ("What color is the car?") |
| 1 | captioning | Image Description ("Describe the sunset.") |
| 2 | grounding | Object Localization ("Find the person in the image.") |
| 3 | geometry | Spatial/Metric Queries ("Calculate the area of the red box.") |
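
The label table corresponds to the `id2label` / `label2id` dictionaries that a Hugging Face model config carries. A minimal sketch, assuming the checkpoint's config uses exactly these lowercase label strings (inspect `model.config.id2label` to confirm the actual values):

```python
# Label mapping transcribed from the table above. Assumption: the
# checkpoint's config uses the same lowercase label names.
ID2LABEL = {0: "vqa", 1: "captioning", 2: "grounding", 3: "geometry"}

# Inverse mapping, label string -> class index.
LABEL2ID = {label: idx for idx, label in ID2LABEL.items()}
```

When fine-tuning from `bert-base-uncased`, these dictionaries can be passed as the `id2label` and `label2id` keyword arguments of `AutoModelForSequenceClassification.from_pretrained` (together with `num_labels=4`) so that the saved config carries readable label names.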

## Architecture

A BERT-Base encoder followed by a 3-layer MLP classifier applied to the [CLS] token representation:

- Layer 1: Linear(768 → 256) + ReLU + Dropout(0.1)
- Layer 2: Linear(256 → 128) + ReLU + Dropout(0.1)
- Layer 3: Linear(128 → 4)
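
The classifier head described above can be sketched in PyTorch as follows. This is a hypothetical reconstruction from the layer list, not the repository's code: the class name `MLPHead` is illustrative, and it assumes the head reads the raw [CLS] vector (position 0 of BERT's `last_hidden_state`). For reference, the stock `BertForSequenceClassification` head that `AutoModelForSequenceClassification` builds for BERT is a single dropout + linear layer over the pooled [CLS] output, so this sketch describes the head listed above, not necessarily the class the Usage snippet instantiates.

```python
import torch
from torch import nn


class MLPHead(nn.Module):
    """Hypothetical 3-layer classifier over the [CLS] token, per the list above."""

    def __init__(self, hidden_size: int = 768, num_labels: int = 4, dropout: float = 0.1):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 256),  # Layer 1
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),          # Layer 2
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_labels),   # Layer 3: one logit per task label
        )

    def forward(self, last_hidden_state: torch.Tensor) -> torch.Tensor:
        # Select the [CLS] token: (batch, seq_len, hidden) -> (batch, hidden)
        cls = last_hidden_state[:, 0, :]
        return self.mlp(cls)


if __name__ == "__main__":
    head = MLPHead().eval()
    fake_hidden = torch.randn(2, 16, 768)  # stand-in for BERT's last_hidden_state
    print(head(fake_hidden).shape)  # torch.Size([2, 4])
    print(sum(p.numel() for p in head.parameters()))  # 230276
```

In a full model, `last_hidden_state` would come from a base encoder such as `AutoModel.from_pretrained("bert-base-uncased")`. The head itself adds about 230K parameters on top of BERT-Base.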

## Training

| Hyperparameter | Value |
|---|---|
| Samples | 1,600 (400 per class) |
| Epochs | 5 |
| Learning Rate | 2e-5 |
| Batch Size | 32 |
| Optimizer | AdamW |
| Loss | Cross Entropy |
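
With these settings, the length of training can be worked out directly. A minimal sketch, assuming all 1,600 samples are used for training (the card does not mention a validation split) and that a final partial batch counts as one optimizer step; the `TrainConfig` name is illustrative:

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainConfig:
    # Values copied from the hyperparameter table above.
    num_samples: int = 1_600
    epochs: int = 5
    learning_rate: float = 2e-5
    batch_size: int = 32

    @property
    def steps_per_epoch(self) -> int:
        # A partial final batch still triggers one optimizer step.
        return math.ceil(self.num_samples / self.batch_size)

    @property
    def total_steps(self) -> int:
        return self.steps_per_epoch * self.epochs


if __name__ == "__main__":
    cfg = TrainConfig()
    print(cfg.steps_per_epoch, cfg.total_steps)  # 50 250
```

Under these assumptions the run is short: 50 steps per epoch and 250 optimizer steps in total. In PyTorch, the listed optimizer would be created as `torch.optim.AdamW(model.parameters(), lr=cfg.learning_rate)`.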

**Data:** Synthetic questions from four class-balanced JSON files (`vqa_qs.json`, `captioning_qs.json`, `grounding_qs.json`, `geometry_qs.json`).
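
The four files can be combined into one labeled dataset by assigning each file the label ID from the table in the Labels section. A minimal sketch: the internal format of the files is not documented here, so this assumes each one is a JSON array of question strings, and `load_dataset` is a hypothetical helper name:

```python
import json
from pathlib import Path

# File -> label ID, following the label table. Assumption: each file is a
# JSON array of question strings.
FILES = {
    "vqa_qs.json": 0,
    "captioning_qs.json": 1,
    "grounding_qs.json": 2,
    "geometry_qs.json": 3,
}


def load_dataset(data_dir: str) -> list[tuple[str, int]]:
    """Return (question, label_id) pairs from the four per-class files."""
    examples: list[tuple[str, int]] = []
    for filename, label_id in FILES.items():
        with open(Path(data_dir) / filename, encoding="utf-8") as f:
            questions = json.load(f)
        examples.extend((q, label_id) for q in questions)
    return examples
```

With 400 questions per file, this yields the 1,600 balanced examples listed in the training table. Shuffle the pairs before batching, since they come out grouped by class.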

## Usage

```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model_name = "beingamanforever/ICM"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

questions = [
    "What is the distance between the two trees?",
    "Describe what the child is wearing.",
    "Is the traffic light green?",
    "Box the location of the blue umbrella.",
]

# Tokenize all questions as one padded batch
inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)

# Inference only: skip gradient tracking
with torch.no_grad():
    logits = model(**inputs).logits

# Highest-scoring class index per question
predictions = torch.argmax(logits, dim=-1)

for q, pred in zip(questions, predictions):
    print(f"{q} → {model.config.id2label[pred.item()]}")
```

## Limitations

- **Synthetic Training Data:** May not generalize to complex real-world queries.
- **Text-Only:** Processes questions without image context.
- **Domain Scope:** Optimized for vision task routing, not general NLP classification.

## Intended Use

- Automatic query routing in multimodal AI pipelines
- VQA dataset analysis and taxonomy studies
- Educational demonstrations of vision task classification
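
For query routing, the predicted label can serve as a key into a table of downstream handlers. A minimal sketch: the four handler functions are hypothetical placeholders for whatever VQA, captioning, grounding, and geometry modules a pipeline provides (real ones would also take the image), and `route` expects the label string produced by `model.config.id2label` in the Usage example:

```python
from typing import Callable


# Hypothetical downstream modules; replace with real implementations.
def answer_vqa(question: str) -> str:
    return f"[vqa] {question}"


def caption_image(question: str) -> str:
    return f"[captioning] {question}"


def ground_object(question: str) -> str:
    return f"[grounding] {question}"


def solve_geometry(question: str) -> str:
    return f"[geometry] {question}"


# One handler per label in the label table.
HANDLERS: dict[str, Callable[[str], str]] = {
    "vqa": answer_vqa,
    "captioning": caption_image,
    "grounding": ground_object,
    "geometry": solve_geometry,
}


def route(question: str, label: str) -> str:
    """Dispatch a question to the handler registered for its predicted label."""
    try:
        handler = HANDLERS[label]
    except KeyError:
        raise ValueError(f"No handler registered for label {label!r}") from None
    return handler(question)


if __name__ == "__main__":
    print(route("Is the traffic light green?", "vqa"))  # [vqa] Is the traffic light green?
```

Raising on an unknown label, rather than silently falling back to one module, makes a mismatch between the checkpoint's `id2label` values and the handler table visible immediately.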