| | import json, os, math, random |
| | from dataclasses import dataclass |
| | from typing import Dict, List, Any |
| |
|
| | import numpy as np |
| | from datasets import Dataset, DatasetDict |
| | from transformers import (AutoTokenizer, AutoModelForSequenceClassification, |
| | DataCollatorWithPadding, TrainingArguments, Trainer) |
| | import evaluate |
| | from sklearn.metrics import precision_recall_fscore_support |
| |
|
| | |
| | |
| | |
| | MODEL_NAME = "bert-base-uncased" |
| | LABELS = ["mentorship", "entrepreneurship", "startup success"] |
| | TEXT_FIELDS = ["original_text", "summary"] |
| | SEED = 42 |
| | HF_REPO_ID = "4hnk/theme-multilabel-model" |
| |
|
| | random.seed(SEED) |
| | np.random.seed(SEED) |
| |
|
| | |
| | |
| | |
| | |
| | DATA_PATH = "theme_response.json" |
| |
|
| | with open(DATA_PATH, "r", encoding="utf-8") as f: |
| | data = json.load(f)["knowledge_theme_training_data"] |
| |
|
def to_example(row: Dict[str, Any]) -> Dict[str, Any]:
    """Convert one raw annotation row into a {"text", "labels"} example.

    Concatenates the TEXT_FIELDS that are present and truthy into a single
    space-joined string, and encodes the row's "themes" list as a multi-hot
    vector over LABELS.

    The labels are floats, not ints: with
    problem_type="multi_label_classification" the model trains with
    BCEWithLogitsLoss, which requires float targets — integer labels make
    the loss raise a dtype error ("result type Long can't be cast ...")
    at the first training step.

    Args:
        row: one annotation record; may lack any of the optional fields.

    Returns:
        dict with "text" (stripped str) and "labels" (list[float], one
        0.0/1.0 entry per label in LABELS, in LABELS order).
    """
    text = " ".join(row.get(k, "") for k in TEXT_FIELDS if row.get(k))
    themes = row.get("themes", [])
    y = [1.0 if lbl in themes else 0.0 for lbl in LABELS]
    return {"text": text.strip(), "labels": y}
| |
|
# Keep only rows that actually have source text, then convert each to a
# {"text", "labels"} example.
examples = [to_example(r) for r in data if r.get("original_text")]
ds_full = Dataset.from_list(examples)

# Deterministic 80/20 train/validation split: seeded shuffle, then slice.
# NOTE(review): with fewer than 5 examples the validation split is empty
# (max(1, ...) only protects the train side) — per-epoch evaluation would
# then fail; confirm the dataset is large enough.
ds_full = ds_full.shuffle(seed=SEED)
n = len(ds_full)
n_train = max(1, int(0.8 * n))
ds = DatasetDict({
    "train": ds_full.select(range(n_train)),
    "validation": ds_full.select(range(n_train, n))
})
| |
|
| | |
| | |
| | |
# --- Tokenization -------------------------------------------------------
tok = AutoTokenizer.from_pretrained(MODEL_NAME)

def tokenize(batch):
    # Truncate to the model's max length; padding is deliberately deferred
    # to DataCollatorWithPadding so each batch pads only to its own longest
    # sequence instead of a global max.
    return tok(batch["text"], truncation=True)

# Tokenize both splits and drop the raw "text" column — the model consumes
# input_ids/attention_mask, and the "labels" column passes through unchanged.
ds = ds.map(tokenize, batched=True, remove_columns=["text"])
data_collator = DataCollatorWithPadding(tokenizer=tok)
| |
|
| | |
| | |
| | |
# --- Model --------------------------------------------------------------
# problem_type="multi_label_classification" switches the loss to
# BCEWithLogitsLoss (independent sigmoid per label) instead of softmax CE.
# NOTE(review): BCEWithLogitsLoss requires *float* targets — verify the
# dataset's "labels" are floats or training fails with a dtype error.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=len(LABELS),
    problem_type="multi_label_classification"
)
# Store readable label names on the config so pipelines and the Hub widget
# show theme names instead of LABEL_0..LABEL_2.
model.config.id2label = {i: l for i, l in enumerate(LABELS)}
model.config.label2id = {l: i for i, l in enumerate(LABELS)}
| |
|
| | |
| | |
| | |
# NOTE(review): this accuracy metric is loaded but never used below —
# compute_metrics relies on sklearn instead. Consider removing the load
# (it can trigger a network fetch) or actually wiring it into the metrics.
metric = evaluate.load("accuracy")
| |
|
def sigmoid(x):
    """Numerically stable elementwise logistic sigmoid, 1 / (1 + exp(-x)).

    The naive form overflows np.exp for large negative inputs (x < ~-709
    raises a RuntimeWarning and loses precision). Splitting on the sign of
    x keeps exp's argument <= 0 in both branches, so it never overflows.

    Args:
        x: scalar or array-like of logits.

    Returns:
        np.ndarray of float64 probabilities in [0, 1], same shape as x.
    """
    x = np.asarray(x, dtype=np.float64)
    out = np.empty_like(x)
    neg = x < 0
    # x >= 0: exp(-x) <= 1, safe.
    out[~neg] = 1.0 / (1.0 + np.exp(-x[~neg]))
    # x < 0: use the algebraically equivalent exp(x) / (1 + exp(x)).
    ex = np.exp(x[neg])
    out[neg] = ex / (1.0 + ex)
    return out
| |
|
def compute_metrics(eval_pred, threshold=0.5):
    """Micro- and macro-averaged precision/recall/F1 for multi-label eval.

    Logits are mapped to probabilities with a sigmoid and binarized at
    `threshold`; scores are computed against the multi-hot ground truth.

    Args:
        eval_pred: (logits, labels) pair as supplied by the Trainer.
        threshold: probability cutoff for predicting a label as present.

    Returns:
        dict with "micro/..." and "macro/..." precision, recall, and f1.
    """
    logits, labels = eval_pred
    preds = (sigmoid(logits) >= threshold).astype(int)

    scores = {}
    for avg in ("micro", "macro"):
        p, r, f1, _ = precision_recall_fscore_support(
            labels, preds, average=avg, zero_division=0
        )
        scores[f"{avg}/precision"] = p
        scores[f"{avg}/recall"] = r
        scores[f"{avg}/f1"] = f1
    return scores
| |
|
| | |
| | |
| | |
# --- Training configuration ---------------------------------------------
# Evaluate and checkpoint every epoch; keep the checkpoint with the best
# validation micro-F1 (Trainer auto-prefixes the metric name with "eval_").
# NOTE(review): `evaluation_strategy` was renamed to `eval_strategy` in
# newer transformers releases — confirm against the pinned version.
args = TrainingArguments(
    output_dir="./theme_model_outputs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=16,
    num_train_epochs=10,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="micro/f1",
    greater_is_better=True,
    push_to_hub=True,          # checkpoints are pushed to the repo below
    hub_model_id=HF_REPO_ID
)
| |
|
| | |
| | |
| | |
# --- Training ------------------------------------------------------------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    tokenizer=tok,                  # saved alongside checkpoints and pushes
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
# Final evaluation runs with the best checkpoint reloaded
# (load_best_model_at_end=True in the TrainingArguments above).
trainer.evaluate()

# Upload the final model, tokenizer, and auto-generated model card to
# HF_REPO_ID. Requires prior `huggingface-cli login` or an HF_TOKEN env var.
trainer.push_to_hub()
| |
|