import gradio as gr
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    AutoModelForTokenClassification,
    pipeline,
)

# -----------------------------
# Load Your Classifier
# -----------------------------
CLASSIFIER_ID = "alemmrr/finbert-gics-sector-classifier"
NER_MODEL_ID = "Jean-Baptiste/roberta-large-ner-english"

tokenizer = AutoTokenizer.from_pretrained(CLASSIFIER_ID)
model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_ID)

# top_k=None asks the pipeline for a score for every label; device=-1 runs on CPU.
clf = pipeline(
    "text-classification",
    model=model,
    tokenizer=tokenizer,
    top_k=None,
    device=-1,
)

# -----------------------------
# Load NER Model (for auto-formatting)
# -----------------------------
# aggregation_strategy="simple" merges sub-word tokens into entity groups,
# each exposing "entity_group" and "word" (read below).
ner_pipeline = pipeline(
    "ner",
    model=NER_MODEL_ID,
    aggregation_strategy="simple",
)

# -----------------------------
# Helper: Format headline (Variant 3 Prefixing)
# -----------------------------
# Entity tags included in the prefix, in the order they appear in it
# (same as training Variant-3). Tags the NER model emits that are not listed
# here are ignored.
# NOTE(review): this NER model may never emit "GPE"; it is kept for parity
# with the training-time format — confirm.
ENTITY_TAGS = ("ORG", "LOC", "PER", "GPE")


def format_headline_variant3(headline):
    """Prefix a headline with its named entities in the Variant-3 format.

    Entities detected by ``ner_pipeline`` are grouped by tag (in
    ``ENTITY_TAGS`` order). Entities within a tag are joined with ``" | "``.
    The result looks like ``"[ORG] A | B [LOC] C [SEP] <headline>"``.
    Tags with no entities are omitted. If no entity of a listed tag is found,
    the headline is returned unchanged.

    Args:
        headline: Plain-text headline.

    Returns:
        The string that is fed to the sector classifier.
    """
    entity_buckets = {tag: [] for tag in ENTITY_TAGS}
    for ent in ner_pipeline(headline):
        bucket = entity_buckets.get(ent["entity_group"])
        if bucket is not None:
            # Words are kept verbatim (not stripped) so the text matches the
            # training-time formatting.
            # NOTE(review): RoBERTa-based NER may return words with a leading
            # space — confirm training used the same raw values.
            bucket.append(ent["word"])

    segments = [
        f"[{tag}] " + " | ".join(words)
        for tag, words in entity_buckets.items()
        if words
    ]
    if not segments:
        return headline
    # Same result as the original "+= segment + ' '" loop followed by strip().
    return " ".join(segments).strip() + " [SEP] " + headline


# -----------------------------
# Main Prediction Function
# -----------------------------
def _flatten_scores(outputs):
    """Unwrap a single-input ``[[{...}, ...]]`` result into ``[{...}, ...]``."""
    if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list):
        return outputs[0]
    return outputs


def predict(text):
    """Classify a headline into GICS sectors, returning every sector's score.

    The headline is auto-formatted with ``format_headline_variant3`` before
    classification.

    Args:
        text: Plain-text headline from the Gradio textbox.

    Returns:
        A list of ``{"label": str, "confidence": float}`` dicts. Confidence is
        a percentage rounded to 2 decimals. The list is sorted from most to
        least confident. Blank input returns an empty list.
    """
    if text is None or not text.strip():
        # Nothing to classify; avoid returning scores for an empty string.
        return []

    # A 3-line textbox easily carries stray leading/trailing newlines.
    formatted = format_headline_variant3(text.strip())

    # truncation=True keeps long (prefixed) inputs within the classifier's
    # maximum sequence length instead of raising.
    outputs = _flatten_scores(clf(formatted, truncation=True))

    scores = [
        {
            "label": o["label"],
            "confidence": round(float(o["score"]) * 100, 2),
        }
        for o in outputs
    ]
    return sorted(scores, key=lambda s: s["confidence"], reverse=True)


# -----------------------------
# Gradio Interface
# -----------------------------
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, label="Enter a financial headline (plain text)"),
    outputs=gr.JSON(label="All Sector Scores"),
    title="FinBERT GICS Sector Classifier (Auto-Formatted)",
    description=(
        "Enter a plain financial news headline. The app automatically applies NER tagging "
        "and prefixes the detected entities before classifying the sector."
    ),
)

if __name__ == "__main__":
    demo.launch()