---
language: en
tags:
- image-classification
- image-captioning
---

# Poster2Plot

An image captioning model that generates a movie/TV show plot from its poster. It produces decent plots, but they are by no means perfect. We are still working on improving the model.

## Live demo on Hugging Face Spaces: https://huggingface.co/spaces/deepklarity/poster2plot

# Model Details

The base model uses a Vision Transformer (ViT) as the image encoder and GPT-2 as the decoder.

We used the following models:

* Encoder: [google/vit-base-patch16-224-in21k](https://huggingface.co/google/vit-base-patch16-224-in21k)
* Decoder: [gpt2](https://huggingface.co/gpt2)

# Datasets

Publicly available IMDb datasets were used to train the model.

# How to use

## In PyTorch

```python
import torch
import re
import requests
from PIL import Image
from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel

# Matches a run of 2 or more full stops; everything after such a run in the
# generated text is discarded during post-processing.
regex_pattern = "[.]{2,}"


def post_process(text, pattern=None):
    """Clean up a generated plot string.

    Strips surrounding whitespace and truncates the text at the first run
    of two or more full stops (the model tends to pad its output with
    trailing "....").

    Args:
        text: Raw decoded text produced by the model.
        pattern: Optional regex override; defaults to the module-level
            ``regex_pattern``.

    Returns:
        The cleaned string, or the input unchanged if cleanup fails.
    """
    if pattern is None:
        pattern = regex_pattern
    try:
        text = text.strip()
        text = re.split(pattern, text)[0]
    except Exception as e:
        # Best-effort cleanup: report the problem and fall through with
        # whatever text we have so the caller still gets a result.
        print(e)
    return text


def predict(image, max_length=64, num_beams=4):
    """Generate a plot caption for a single poster image.

    Args:
        image: A PIL image of the poster.
        max_length: Maximum number of tokens to generate.
        num_beams: Beam-search width used during generation.

    Returns:
        The post-processed plot string.
    """
    inputs = feature_extractor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values.to(device)

    # Inference only: no gradients needed.
    with torch.no_grad():
        generated = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        )

    captions = tokenizer.batch_decode(generated.sequences, skip_special_tokens=True)
    return post_process(captions[0])


model_name_or_path = "deepklarity/poster2plot"
# Prefer the first GPU when available; otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load model.

model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# The feature extractor must match the encoder checkpoint the model was
# trained with, so it is resolved from the loaded model itself.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
# GPT-2 ships without a pad token; reuse EOS so padded batch decoding works.
if model.decoder.name_or_path == "gpt2":
    tokenizer.pad_token = tokenizer.eos_token

print("Loaded tokenizer")

# Example: fetch a poster image over HTTP and generate its plot.
url = "https://upload.wikimedia.org/wikipedia/en/2/26/Moana_Teaser_Poster.jpg"
with Image.open(requests.get(url, stream=True).raw) as image:
    pred = predict(image)

print(pred)

```