Spaces:
Paused
Paused
| # Necessary imports | |
| import sys | |
| from typing import Dict | |
| import torch | |
| from transformers import pipeline | |
| import gradio as gr | |
| # Local imports | |
| from src.logger import logging | |
| from src.exception import CustomExceptionHandling | |
# Build the zero-shot classification pipeline once, at import time, so every
# call to ZeroShotTextClassification reuses the same loaded model.
classifier = pipeline(
    task="zero-shot-classification",
    model="MoritzLaurer/ModernBERT-large-zeroshot-v2.0",
    # bfloat16 weights take half the memory of float32.
    torch_dtype=torch.bfloat16,
)
def ZeroShotTextClassification(
    text_input: str, candidate_labels: str, multi_label: bool
) -> Dict[str, float]:
    """
    Performs zero-shot classification on the given text input and candidate labels.

    Args:
        - text_input: The input text to classify.
        - candidate_labels: A comma-separated string of candidate labels.
          Surrounding whitespace is stripped, blank entries (e.g. from "a,,b"
          or a trailing comma) are ignored, and duplicates are kept only once
          (first occurrence wins).
        - multi_label: A boolean indicating whether to allow the model to choose
          multiple classes.

    Returns:
        Dictionary mapping each candidate label to its score. If the text is
        blank or no non-blank labels remain, a Gradio warning is shown and an
        empty dictionary is returned without calling the model.

    Raises:
        CustomExceptionHandling: Wraps any exception raised during classification.
    """
    try:
        # Clean the labels: strip whitespace, drop blanks, and de-duplicate
        # while preserving order. Duplicates would otherwise split probability
        # mass (multi_label=False) and then collapse in the result dict,
        # under-reporting that label's score.
        stripped = (label.strip() for label in (candidate_labels or "").split(","))
        labels = list(dict.fromkeys(label for label in stripped if label))

        # Validate the *cleaned* inputs so that "  " or ", ," is rejected too.
        # gr.Warning only displays a notice and does not stop execution, so
        # return explicitly instead of calling the model with invalid input.
        if not text_input or not text_input.strip() or not labels:
            gr.Warning("Please provide valid input and candidate labels")
            return {}

        # Log the classification attempt
        logging.info(f"Attempting classification with {len(labels)} labels")

        # Perform zero-shot classification
        hypothesis_template = "This text is about {}"
        prediction = classifier(
            text_input,
            labels,
            hypothesis_template=hypothesis_template,
            multi_label=multi_label,
        )

        # Return the classification results as a label -> score mapping
        logging.info("Classification completed successfully")
        return dict(zip(prediction["labels"], prediction["scores"]))

    # Handle exceptions that may occur during the process
    except Exception as e:
        # Custom exception handling
        raise CustomExceptionHandling(e, sys) from e