---
language: en
license: mit
tags:
  - text-classification
  - nlp
  - transformers
  - bert
  - routing
  - vision-task-classifier
model_name: ICM
base_model: bert-base-uncased
pipeline_tag: text-classification
datasets:
  - synthetic
tasks:
  - text-classification
library_name: transformers
---

# Task Classification Model (ICM)

## Model Description

A BERT-based sequence classification model that routes computer vision questions to the appropriate specialized module. It classifies each question into one of four task categories: VQA, captioning, grounding, and geometry.

- **Repository:** beingamanforever/ICM
- **Base model:** bert-base-uncased
- **Task:** 4-way sequence classification

## Labels

| ID | Label        | Description                                                    |
|----|--------------|----------------------------------------------------------------|
| 0  | `vqa`        | Visual question answering ("What color is the car?")           |
| 1  | `captioning` | Image description ("Describe the sunset.")                     |
| 2  | `grounding`  | Object localization ("Find the person in the image.")          |
| 3  | `geometry`   | Spatial/metric queries ("Calculate the area of the red box.")  |
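
The same mapping should be recoverable from the checkpoint's config, so routing code need not hardcode this table. A quick check, assuming the config ships with `id2label` populated as above:

```python
from transformers import AutoConfig

# Loads only the config; no model weights are downloaded.
config = AutoConfig.from_pretrained("beingamanforever/ICM")
print(config.id2label)
# Expected, if the config matches the table above:
# {0: "vqa", 1: "captioning", 2: "grounding", 3: "geometry"}
```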

## Architecture

BERT-Base encoder + 3-layer MLP classifier on the `[CLS]` token (see the sketch after this list):

- Layer 1: Linear(768 → 256) + ReLU + Dropout(0.1)
- Layer 2: Linear(256 → 128) + ReLU + Dropout(0.1)
- Layer 3: Linear(128 → 4)
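
A minimal PyTorch sketch of this architecture. The layer sizes come from the list above; the class and attribute names are illustrative, not the checkpoint's actual module names:

```python
import torch.nn as nn
from transformers import BertModel

class TaskClassifier(nn.Module):
    """BERT-Base encoder with a 3-layer MLP head on the [CLS] token."""

    def __init__(self, num_labels: int = 4):
        super().__init__()
        self.encoder = BertModel.from_pretrained("bert-base-uncased")
        self.head = nn.Sequential(
            nn.Linear(768, 256), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(256, 128), nn.ReLU(), nn.Dropout(0.1),
            nn.Linear(128, num_labels),
        )

    def forward(self, input_ids, attention_mask=None):
        out = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls = out.last_hidden_state[:, 0]  # [CLS] token embedding
        return self.head(cls)  # logits over the 4 task labels
```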

## Training

| Hyperparameter | Value                  |
|----------------|------------------------|
| Samples        | 1,600 (400 per class)  |
| Epochs         | 5                      |
| Learning rate  | 2e-5                   |
| Batch size     | 32                     |
| Optimizer      | AdamW                  |
| Loss           | Cross-entropy          |

**Data:** synthetic questions drawn from balanced JSON files (`vqa_qs.json`, `captioning_qs.json`, `grounding_qs.json`, `geometry_qs.json`).
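
A hedged sketch of a training loop matching these hyperparameters. The file names and class balance come from above; the JSON layout (a flat list of question strings per file) is an assumption, and a stock `AutoModelForSequenceClassification` head stands in here for the 3-layer MLP described earlier:

```python
import json
import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Assumption: each file holds a flat JSON list of question strings for one class.
files = ["vqa_qs.json", "captioning_qs.json", "grounding_qs.json", "geometry_qs.json"]
texts, labels = [], []
for label, path in enumerate(files):
    with open(path) as f:
        questions = json.load(f)
    texts += questions
    labels += [label] * len(questions)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4
)

enc = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
loader = DataLoader(
    TensorDataset(enc["input_ids"], enc["attention_mask"], torch.tensor(labels)),
    batch_size=32, shuffle=True,
)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
model.train()
for epoch in range(5):
    for input_ids, attention_mask, batch_labels in loader:
        optimizer.zero_grad()
        out = model(input_ids=input_ids, attention_mask=attention_mask,
                    labels=batch_labels)
        out.loss.backward()  # cross-entropy, computed internally by the model
        optimizer.step()
```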

## Usage

```python
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

model_name = "beingamanforever/ICM"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

questions = [
    "What is the distance between the two trees?",
    "Describe what the child is wearing.",
    "Is the traffic light green?",
    "Box the location of the blue umbrella.",
]

# Tokenize the batch and run a single forward pass without gradients.
inputs = tokenizer(questions, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = model(**inputs).logits
    predictions = torch.argmax(logits, dim=-1)

# Map predicted class IDs back to task labels.
for q, pred in zip(questions, predictions):
    print(f"{q} → {model.config.id2label[pred.item()]}")
```
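
When routing, a confidence score is often more useful than the bare argmax, for example to fall back to a default module below a threshold. A small extension reusing `logits` and `questions` from the snippet above:

```python
import torch.nn.functional as F

probs = F.softmax(logits, dim=-1)  # per-question probabilities over the 4 labels
for q, p in zip(questions, probs):
    conf, idx = p.max(dim=-1)
    label = model.config.id2label[idx.item()]
    print(f"{q} → {label} ({conf.item():.1%})")
```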

## Limitations

- **Synthetic training data:** may not generalize to complex real-world queries
- **Text-only:** processes questions without any image context
- **Domain scope:** optimized for vision task routing, not general NLP classification

## Intended Use

- Automatic query routing in multimodal AI pipelines
- VQA dataset analysis and taxonomy studies
- Educational demonstrations of vision task classification