YAML Metadata Warning: empty or missing YAML metadata in repo card

Check out the documentation for more information.

SenseVoice-Small-ko (Fine-tuned SenseVoiceSmall on EDIE dataset)

์ด ๋ฆฌํฌ์ง€ํ„ฐ๋ฆฌ๋Š” SenseVoiceSmall๋ฅผ ํ•œ๊ตญ์–ด ์Œ์„ฑ/๊ฐ์ •/์ด๋ฒคํŠธ ์ธ์‹์šฉ EDIE ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํŒŒ์ธํŠœ๋‹ํ•œ ๋ชจ๋ธ์ž…๋‹ˆ๋‹ค.

  • ๋ฒ ์ด์Šค ๋ชจ๋ธ: iic/SenseVoiceSmall
  • ํ…Œ์Šคํฌ: STT (ASR) + Emotion (SER) + Event (AED)
  • ์ฃผ์š” ๋ผ๋ฒจ:
    • ํ…์ŠคํŠธ ๋ผ๋ฒจ
    • ๊ฐ์ • ๋ผ๋ฒจ: <|HAPPY|>, <|SAD|>, <|ANGRY|>, <|NEUTRAL|>, <|FEARFUL|>, <|DISGUSTED|>, <|SURPRISED|>

0. ๋ชจ๋ธ ์ž…์ถœ๋ ฅ ํฌ๋ฉง

์ž…๋ ฅ

  • input: ๋‹จ์ผ wav ๊ฒฝ๋กœ ๋˜๋Š” ๊ฒฝ๋กœ ๋ฆฌ์ŠคํŠธ

์ถœ๋ ฅ ์˜ˆ์‹œ (AutoModel)

  • text: ์ธ์‹๋œ ํ…์ŠคํŠธ
  • language: ์–ธ์–ด ID (<|ko|> ๋“ฑ)
  • emo: ๊ฐ์ • ๋ผ๋ฒจ (<|HAPPY|>, <|SAD|> ๋“ฑ)
  • event: ์ด๋ฒคํŠธ ๋ผ๋ฒจ (<|Speech|>, <|BGM|> ๋“ฑ)

1. ์„ค์น˜

pip install -U "funasr>=1.2.7" torch

GPU๋ฅผ ์‚ฌ์šฉํ•  ๊ฒฝ์šฐ ์‚ฌ์ „์— CUDA ํ˜ธํ™˜ PyTorch๋ฅผ ์„ค์น˜ํ•ด ์ฃผ์„ธ์š”.

2. ๊ฐ„๋‹จํ•˜๊ฒŒ ๋ชจ๋ธ ์‚ฌ์šฉํ•˜๊ธฐ

FunASR์˜ AutoModel์„ ์ด์šฉํ•˜์—ฌ ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ชจ๋ธ ํ—ˆ๋ธŒ์—์„œ ๋ชจ๋ธ ๋ ˆํŒŒ์ง€ํ† ๋ฆฌ์˜ ๋ชจ๋ธ์„ ๋ฐ”๋กœ ๋กœ๋“œํ•ด์„œ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

#!/usr/bin/env python3
from pathlib import Path
import os
import argparse

from huggingface_hub import snapshot_download
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess

HF_REPO_ID = "AeiROBOT/SenseVoice-Small-ko"   # Hugging Face repo ID the fine-tuned model was uploaded to
LOCAL_DIR = "/home/khw/Workspace/SenseVoice/hf_models/SenseVoice-Small-ko"  # local directory the snapshot is downloaded into

# ----- SenseVoice token tables for the output parser -----
# Special tokens SenseVoice prepends to its transcript; they appear in the
# order language -> emotion -> event -> ITN flag and are stripped by
# parse_sensevoice_text() below.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {"<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>", "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>"}
EVENT_TOKENS = {"<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>", "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>"}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}  # with / without inverse text normalization


def _consume(prefixes, text: str):
    for p in prefixes:
        if text.startswith(p):
            return p, text[len(p):]
    return None, text


def parse_sensevoice_text(raw: str):
    """Split a SenseVoice output string into (language, emo, event, with_itn, text).

    Example:
        "<|ko|><|NEUTRAL|><|Speech|><|withitn|>์กฐ ๊ธˆ๋งŒ ์ƒ๊ฐ ์„ ํ•˜ ๋ฉด์„œ ์‚ด ๋ฉด ํ›จ์”ฌ ํŽธํ•  ๊ฑฐ์•ผ." ->
        {
          "language": "<|ko|>",
          "emo": "<|NEUTRAL|>",
          "event": "<|Speech|>",
          "with_itn": "<|withitn|>",
          "text": "์กฐ ๊ธˆ๋งŒ ์ƒ๊ฐ ์„ ํ•˜ ๋ฉด์„œ ์‚ด ๋ฉด ํ›จ์”ฌ ํŽธํ•  ๊ฑฐ์•ผ."
        }

    Tags are consumed in the fixed order language -> emotion -> event -> ITN;
    a missing tag yields None for that field.
    """
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}

    remainder = raw.strip()
    result = {}
    for field, token_set in (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    ):
        result[field], remainder = _consume(token_set, remainder)

    result["text"] = remainder.strip()
    return result


def parse_args():
    """Parse the CLI for the single-file inference demo.

    Returns:
        argparse.Namespace with one attribute, ``wav_file``.
    """
    p = argparse.ArgumentParser()
    # BUG FIX: the help string previously said "pretrained model name or local
    # directory", copied from another script; this option is the wav file to
    # transcribe (see main(), which passes it to model.generate(input=...)).
    p.add_argument(
        "--wav_file",
        default="dataset/wav_dataset/DISGUSTED/test_2025_12_12_040201.wav",
        help="์ถ”๋ก ํ•  wav ํŒŒ์ผ ๊ฒฝ๋กœ",
    )
    return p.parse_args()

def get_model():
    """Download the fine-tuned snapshot from the HF Hub and wrap it in a FunASR AutoModel.

    Returns:
        funasr.AutoModel loaded from the downloaded directory, using the
        repo's own model.py, an fsmn-vad front-end, and CUDA device 0.
    """
    local_path = snapshot_download(
        repo_id=HF_REPO_ID,
        repo_type="model",
        local_dir=LOCAL_DIR,
        # NOTE(review): local_dir_use_symlinks is deprecated in recent
        # huggingface_hub releases — confirm the pinned version still accepts it.
        local_dir_use_symlinks=False,
        token=os.environ.get("HUGGINGFACE_HUB_TOKEN"),  # needed because the repo is private
    )
    print("๋‹ค์šด๋กœ๋“œ ๊ฒฝ๋กœ:", local_path)

    # 2) Hand the local path to AutoModel
    model_dir = local_path  # or LOCAL_DIR

    model = AutoModel(
        model=model_dir,
        trust_remote_code=True,
        remote_code=str(Path(model_dir) / "model.py"),  # use the model.py shipped in the HF repo
        vad_model="fsmn-vad",
        vad_kwargs={"max_single_segment_time": 30000},  # cap single VAD segments at 30 s
        device="cuda:0",  # assumes a CUDA GPU is present — TODO confirm on target machine
    )

    return model

def main():
    """Download the model, transcribe one wav file, and print the parsed fields."""
    cli = parse_args()
    audio_path = cli.wav_file

    asr_model = get_model()

    results = asr_model.generate(
        input=audio_path,
        cache={},
        language="auto",   # or "ko" to force Korean
        use_itn=True,
        batch_size_s=60,
        merge_vad=True,
        merge_length_s=15,
    )

    raw_text = results[0]["text"]
    parsed = parse_sensevoice_text(raw_text)

    # ITN / rich-transcription post-processing on the bare text only.
    pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""

    print("=== Raw ===")
    print(raw_text)
    print("=== Parsed ===")
    print("lang   :", parsed["language"])
    print("emo    :", parsed["emo"])
    print("event  :", parsed["event"])
    print("withitn:", parsed["with_itn"])
    print("text   :", pretty_text)


if __name__ == "__main__":
    main()

3. ํ•™์Šต ๋ฐ์ดํ„ฐ์…‹์œผ๋กœ ํ‰๊ฐ€ํ•˜๊ธฐ

#!/usr/bin/env python3
import os
import json
import argparse
import unicodedata
from pathlib import Path
from typing import List, Dict, Tuple, Optional

import torch
from funasr import AutoModel
from funasr.utils.postprocess_utils import rich_transcription_postprocess


# =======================
# SenseVoice token parser
# =======================
# Special tokens SenseVoice prepends to its transcript; parse_sensevoice_text()
# consumes them in the order language -> emotion -> event -> ITN flag.
LANG_TOKENS = {"<|zh|>", "<|en|>", "<|yue|>", "<|ja|>", "<|ko|>", "<|nospeech|>"}
EMO_TOKENS = {"<|HAPPY|>", "<|SAD|>", "<|ANGRY|>", "<|NEUTRAL|>", "<|FEARFUL|>", "<|DISGUSTED|>", "<|SURPRISED|>"}
EVENT_TOKENS = {"<|BGM|>", "<|Speech|>", "<|Applause|>", "<|Laughter|>", "<|Cry|>", "<|Sneeze|>", "<|Breath|>", "<|Cough|>"}
WITH_ITN_TOKENS = {"<|withitn|>", "<|woitn|>"}  # with / without inverse text normalization


def _consume(prefixes, text: str):
    for p in prefixes:
        if text.startswith(p):
            return p, text[len(p):]
    return None, text


def parse_sensevoice_text(raw: str) -> Dict[str, Optional[str]]:
    """Split a SenseVoice transcript into its leading tags and the clean text.

    Tags are stripped in the fixed order language -> emotion -> event -> ITN;
    any tag that is absent maps to None. The remaining text is whitespace-trimmed.
    """
    if not raw:
        return {"language": None, "emo": None, "event": None, "with_itn": None, "text": ""}

    remainder = raw.strip()
    fields: Dict[str, Optional[str]] = {}
    for name, tokens in (
        ("language", LANG_TOKENS),
        ("emo", EMO_TOKENS),
        ("event", EVENT_TOKENS),
        ("with_itn", WITH_ITN_TOKENS),
    ):
        fields[name], remainder = _consume(tokens, remainder)

    fields["text"] = remainder.strip()
    return fields


# =======================
# Text normalization & metrics
# =======================

def normalize_text(s: str, lower: bool, strip_punct: bool, strip_spaces: bool) -> str:
    """Normalize a string before text-metric comparison.

    Applies, in order: lower-casing, removal of Unicode punctuation (category
    "P*"), and removal of all whitespace — each only when its flag is set.
    None is treated as the empty string.
    """
    if s is None:
        return ""
    result = s.lower() if lower else s
    if strip_punct:
        kept = [ch for ch in result if not unicodedata.category(ch).startswith("P")]
        result = "".join(kept)
    if strip_spaces:
        result = "".join(result.split())
    return result


def _levenshtein(a: List[str], b: List[str]) -> int:
    n, m = len(a), len(b)
    if n == 0:
        return m
    if m == 0:
        return n
    prev = list(range(m + 1))
    for i in range(1, n + 1):
        curr = [i] + [0] * m
        ai = a[i - 1]
        for j in range(1, m + 1):
            cost = 0 if ai == b[j - 1] else 1
            curr[j] = min(
                prev[j] + 1,
                curr[j - 1] + 1,
                prev[j - 1] + cost,
            )
        prev = curr
    return prev[m]


def cer(ref: str, hyp: str) -> float:
    """Character error rate: edit distance over reference length (min 1 to avoid /0)."""
    ref_chars = list(ref)
    hyp_chars = list(hyp)
    return _levenshtein(ref_chars, hyp_chars) / max(1, len(ref_chars))


def wer(ref: str, hyp: str) -> float:
    """Word error rate over whitespace-split tokens (min 1 reference word to avoid /0)."""
    ref_words = ref.split()
    hyp_words = hyp.split()
    return _levenshtein(ref_words, hyp_words) / max(1, len(ref_words))


def norm_emo(label: Optional[str]) -> str:
    """Normalize an emotion label to its bare upper-case name.

    Accepts either the token form ("<|happy|>") or a plain name ("happy");
    empty/None input maps to "".
    """
    if not label:
        return ""
    token = label.strip()
    wrapped = token.startswith("<|") and token.endswith("|>")
    if wrapped:
        token = token[2:-2]
    return token.upper()


# =======================
# IO & argparse
# =======================

def parse_args():
    """Build and parse the evaluation CLI.

    Defaults are hard-coded to one developer's local paths — override them on
    other machines. Returns an argparse.Namespace with model/data paths,
    device, batch size, language override, text-normalization flags, and the
    output JSONL path.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model-dir", default="/home/khw/Workspace/SenseVoice/outputs", help="finetune ์‚ฐ์ถœ๋ฌผ ๋””๋ ‰ํ„ฐ๋ฆฌ")
    p.add_argument("--jsonl", default="/home/khw/Workspace/SenseVoice/data/train.jsonl", help="์ž…๋ ฅ JSONL ๊ฒฝ๋กœ")
    p.add_argument("--base-audio-dir", default="/home/khw/Workspace/SenseVoice", help="source ์ƒ๋Œ€๊ฒฝ๋กœ์˜ ๊ธฐ์ค€ ๋””๋ ‰ํ„ฐ๋ฆฌ")
    p.add_argument("--remote-code", default="/home/khw/Workspace/SenseVoice/model.py", help="SenseVoice ๋ชจ๋ธ ๊ตฌํ˜„ ๊ฒฝ๋กœ")
    p.add_argument("--device", default=None, help="cuda:0 / cpu (๋ฏธ์ง€์ • ์‹œ ์ž๋™ ๊ฒฐ์ •)")
    p.add_argument("--batch-size", type=int, default=64, help="๋ฐฐ์น˜ ํฌ๊ธฐ(์งง์€ ์Œ์› ๋‹ค์ˆ˜ ๊ฐ€์ •)")
    # NOTE(review): --use-best-ckpt is declared but never read in this script;
    # prepare_checkpoint() always prefers model.pt.best — confirm intent.
    p.add_argument("--use-best-ckpt", action="store_true", help="model.pt.best๋ฅผ model.pt๋กœ ์‹ฌ๋ณผ๋ฆญ ๋งํฌ ์ƒ์„ฑ")
    p.add_argument("--lang", default="ko", choices=["auto", "zh", "en", "yue", "ja", "ko", "nospeech"], help="์–ธ์–ด ๊ฐ•์ œ ์„ค์ •. ๊ธฐ๋ณธ ko")
    p.add_argument("--lower", action="store_true", help="์ •๋ฐ€๋„ ๊ณ„์‚ฐ ์‹œ ์†Œ๋ฌธ์žํ™”")
    p.add_argument("--strip-punct", action="store_true", help="์ •๋ฐ€๋„ ๊ณ„์‚ฐ ์‹œ ๋ฌธ์žฅ๋ถ€ํ˜ธ ์ œ๊ฑฐ")
    p.add_argument("--strip-spaces", action="store_true", help="์ •๋ฐ€๋„ ๊ณ„์‚ฐ ์‹œ ๋ชจ๋“  ๊ณต๋ฐฑ ์ œ๊ฑฐ")
    p.add_argument("--out", default="/home/khw/Workspace/SenseVoice/results/preds_train.jsonl", help="์ถ”๋ก  ๊ฒฐ๊ณผ JSONL")
    return p.parse_args()


def _find_latest_epoch_ckpt(model_dir: Path) -> Optional[Path]:
    """model.pt.ep* ์ค‘์—์„œ ๊ฐ€์žฅ ํฐ epoch ๋ฒˆํ˜ธ๋ฅผ ๊ฐ€์ง„ ์ฒดํฌํฌ์ธํŠธ๋ฅผ ์ฐพ๋Š”๋‹ค."""
    candidates = []
    for p in model_dir.glob("model.pt.ep*"):
        name = p.name
        try:
            # ์ด๋ฆ„์—์„œ ์ˆซ์ž ๋ถ€๋ถ„๋งŒ ํŒŒ์‹ฑ: model.pt.ep50 -> 50
            ep_str = name.split("model.pt.ep", 1)[1]
            ep = int(ep_str)
            candidates.append((ep, p))
        except (IndexError, ValueError):
            # ํŒจํ„ด์ด ์•ˆ ๋งž์œผ๋ฉด ๋ฌด์‹œ
            continue

    if not candidates:
        return None

    candidates.sort(key=lambda x: x[0])  # epoch ์˜ค๋ฆ„์ฐจ์ˆœ ์ •๋ ฌ
    return candidates[-1][1]  # ๊ฐ€์žฅ ํฐ epoch


def prepare_checkpoint(model_dir: Path) -> Path:
    """Select a checkpoint inside *model_dir* and make ``model.pt`` point at it.

    Priority:
      1) model.pt.best
      2) model.pt.ep* with the largest epoch number
      3) model.pt (pre-existing file)

    Raises SystemExit when none of the three exists.

    If the selected file is not ``model.pt`` itself, ``model.pt`` is replaced
    by a symlink to it (or a copy when symlinking fails).

    Returns:
        Path of the chosen checkpoint file (not necessarily ``model.pt``).
    """
    best = model_dir / "model.pt.best"
    target = model_dir / "model.pt"  # the file AutoModel will ultimately load

    chosen: Optional[Path] = None

    # 1) model.pt.best has top priority
    if best.exists():
        chosen = best
        reason = "model.pt.best"
    else:
        # 2) the model.pt.ep* checkpoint from the latest epoch
        latest_ep = _find_latest_epoch_ckpt(model_dir)
        if latest_ep is not None:
            chosen = latest_ep
            reason = latest_ep.name
        # 3) fall back to an existing model.pt
        elif target.exists():
            chosen = target
            reason = "existing model.pt"
        else:
            reason = "(none)"

    if chosen is None:
        raise SystemExit(
            f"[fatal] No checkpoint found in {model_dir}. "
            f"Expected one of: model.pt.best, model.pt.ep*, model.pt. Program will exit."
        )

    # Point model.pt at the chosen checkpoint (link, falling back to copy)
    if chosen != target:
        # is_symlink() also catches a broken symlink, which exists() misses
        if target.exists() or target.is_symlink():
            try:
                target.unlink()
            except Exception as e:
                print(f"[warn] failed to remove existing {target}: {e}")

        try:
            # Symlink by relative name: both files live in the same directory
            target.symlink_to(chosen.name)
            print(f"[info] using checkpoint: {chosen.name} (linked as model.pt)")
        except Exception as e:
            # Some filesystems/permission setups cannot symlink, so copy instead
            print(f"[warn] symlink failed ({e}), will try to copy instead.")
            import shutil
            try:
                shutil.copy2(str(chosen), str(target))
                print(f"[info] using checkpoint: {chosen.name} (copied to model.pt)")
            except Exception as e2:
                raise SystemExit(
                    f"[fatal] failed to prepare checkpoint at {target}: {e2}. Program will exit."
                )
    else:
        print(f"[info] using checkpoint: {reason}")

    return chosen


def load_items(jsonl_path: Path) -> List[Dict]:
    """Load records from a JSONL file, skipping blank and malformed lines.

    Args:
        jsonl_path: path to a UTF-8 encoded JSONL file.

    Returns:
        List of parsed JSON objects in file order.
    """
    items: List[Dict] = []
    with jsonl_path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            try:
                items.append(json.loads(line))
            except json.JSONDecodeError as e:
                # Narrowed from a broad `except Exception`: only malformed JSON
                # is a "bad line"; genuine I/O or programming errors should surface.
                print(f"[warn] skip bad line: {e}")
    return items


def to_abs_paths(items: List[Dict], base_audio_dir: Path) -> Tuple[List[Dict], int]:
    """Annotate each item with an absolute ``abs_source`` path (in place).

    An item without a truthy "source" gets abs_source=None and counts as
    missing; an item whose resolved path does not exist also counts as missing.

    Returns:
        (the same list, number of missing items).
    """
    missing = 0
    for item in items:
        rel = item.get("source")
        if not rel:
            item["abs_source"] = None
            missing += 1
            continue
        abs_path = (base_audio_dir / rel).resolve()
        item["abs_source"] = str(abs_path)
        if not abs_path.exists():
            missing += 1
    return items, missing


def batched(iterable, n: int):
    """Yield consecutive lists of at most *n* items from *iterable*.

    The final chunk may be shorter; an empty iterable yields nothing.
    """
    pending = []
    for element in iterable:
        pending.append(element)
        if len(pending) >= n:
            yield pending
            pending = []
    if pending:
        yield pending


# =======================
# main
# =======================

def main():
    """Run batched inference over a JSONL manifest and report ASR/emotion metrics.

    Each JSONL item may carry "source" (audio path relative to
    --base-audio-dir), "target" (reference transcript), "emo_target"
    (reference emotion label), and "key" (sample ID). Per-sample predictions
    are written to --out as JSONL; CER/WER/exact-match and emotion accuracy
    are printed at the end.
    """
    args = parse_args()

    model_dir = Path(args.model_dir)
    jsonl_path = Path(args.jsonl)
    base_audio_dir = Path(args.base_audio_dir)

    # Checkpoint priority: model.pt.best > model.pt.ep* (max epoch) > model.pt
    ckpt = prepare_checkpoint(model_dir)
    print(f"[info] final checkpoint file: {ckpt}")

    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")

    # model.py (remote_code) is mandatory; exit immediately when absent.
    remote_code_path = Path(args.remote_code)
    if not remote_code_path.exists():
        raise SystemExit(
            f"[fatal] remote_code not found at {remote_code_path}. "
            f"Expected model.py for SenseVoice. Program will exit."
        )

    trust_remote = True

    model = AutoModel(
        model=str(model_dir),  # local directory only
        trust_remote_code=trust_remote,
        remote_code=str(remote_code_path),
        device=device,
        vad_model=None,
    )

    items = load_items(jsonl_path)
    items, _ = to_abs_paths(items, base_audio_dir)

    # Keep only items whose resolved audio file actually exists on disk.
    valid_items = [it for it in items if it.get("abs_source") and Path(it["abs_source"]).exists()]
    missing = len(items) - len(valid_items)
    if missing:
        print(f"[warn] {missing} items skipped due to missing files")

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    total = len(valid_items)
    print(f"[info] total inputs used: {total}, device: {device}, language: {args.lang}")
    if total == 0:
        print("[exit] No valid audio found. Check --base-audio-dir or 'source' paths.")
        # Still create an empty output file so downstream consumers find it.
        with out_path.open("w", encoding="utf-8") as wf:
            pass
        return

    # Metric accumulators
    exact_matches = 0
    cer_sum = 0.0
    wer_sum = 0.0
    text_pairs = 0

    emo_correct = 0
    emo_total = 0

    written = 0
    with out_path.open("w", encoding="utf-8") as wf:
        for batch in batched(valid_items, args.batch_size):
            wav_list = [b["abs_source"] for b in batch]

            try:
                res = model.generate(
                    input=wav_list,
                    cache={},
                    language=args.lang,
                    use_itn=True,
                    batch_size=len(wav_list),
                )
            except Exception as e:
                # Best-effort: skip the failed batch and keep evaluating the rest.
                print(f"[error] inference failed on batch starting key={batch[0].get('key')}: {e}")
                continue

            for it, r in zip(batch, res):
                raw_text = r.get("text", "") or ""
                parsed = parse_sensevoice_text(raw_text)
                pretty_text = rich_transcription_postprocess(parsed["text"]) if parsed["text"] else ""

                ref_text = it.get("target") or ""

                # Text metrics (only when a reference transcript exists)
                if ref_text:
                    nt_ref = normalize_text(ref_text, args.lower, args.strip_punct, args.strip_spaces)
                    nt_hyp = normalize_text(pretty_text, args.lower, args.strip_punct, args.strip_spaces)

                    if nt_ref == nt_hyp:
                        exact_matches += 1
                    cer_sum += cer(nt_ref, nt_hyp)
                    wer_sum += wer(nt_ref, nt_hyp)
                    text_pairs += 1

                # Emotion metrics (only when a reference emotion exists)
                tgt_emo_n = norm_emo(it.get("emo_target"))
                pred_emo_n = norm_emo(parsed["emo"])
                if tgt_emo_n:
                    emo_total += 1
                    if pred_emo_n == tgt_emo_n:
                        emo_correct += 1

                out_obj = {
                    "key": it.get("key"),
                    "audio": it.get("abs_source"),
                    "pred_raw": raw_text,
                    "pred_text": pretty_text,
                    "ref_text": ref_text,
                    "pred_language": parsed["language"],
                    "pred_emo": pred_emo_n or parsed["emo"] or "",
                    "ref_emo": tgt_emo_n or it.get("emo_target") or "",
                    "pred_event": parsed["event"] or "",
                    "with_itn": parsed["with_itn"] or "",
                }
                wf.write(json.dumps(out_obj, ensure_ascii=False) + "\n")

                # ===== Human-readable per-sample output =====
                idx = written + 1
                print("\n[{}] key={}".format(idx, it.get("key")))
                print("REF_TEXT :", ref_text)
                print("REF_EMO  :", tgt_emo_n or it.get("emo_target"))
                print("PRED_TEXT:", pretty_text)
                print("PRED_EMO :", pred_emo_n or parsed["emo"])  # showing the raw token is fine too
                print("PRED_EVT :", parsed["event"])  # also show the event tag
                print("-" * 80)

                written += 1

    # Summary output
    print("\n===== Summary =====")
    print(f"Samples inferred: {written}")
    if text_pairs > 0:
        exact_acc = exact_matches / text_pairs * 100.0
        avg_cer = cer_sum / text_pairs
        avg_wer = wer_sum / text_pairs
        print(f"Text pairs (with ref): {text_pairs}")
        print(f"- Exact match accuracy: {exact_acc:.2f}%")
        print(f"- Avg CER: {avg_cer:.4f}")
        print(f"- Avg WER: {avg_wer:.4f}")
    else:
        print("No text references found; text metrics skipped.")

    if emo_total > 0:
        emo_acc = emo_correct / emo_total * 100.0
        print(f"Emotion pairs: {emo_total}")
        print(f"- Emotion accuracy: {emo_acc:.2f}%")
    else:
        print("No emotion references found; emotion metrics skipped.")

    print(f"Results saved to: {out_path}")


if __name__ == "__main__":
    main()

4. ํ•™์Šต ํ›„ ํ—ˆ๊น…ํŽ˜์ด์Šค์— ๋ชจ๋ธ ์—…๋กœ๋“œ

upload_model_to_huggingface.py


#!/usr/bin/env python3
import os
from pathlib import Path

from huggingface_hub import HfApi, create_repo, upload_folder

# ===== User configuration =====
# Hugging Face model repo ID to create/upload to (example value)
REPO_ID = "AeiROBOT/SenseVoice-Small-ko"  # <-- change to the name you want

# Local folder to upload (training outputs)
MODEL_DIR = Path("/home/khw/Workspace/SenseVoice/outputs")

# Extra files (e.g. model.py for FunASR/SenseVoice) to upload alongside;
# can be skipped if already copied into the outputs directory
EXTRA_FILES = [
    Path("/home/khw/Workspace/SenseVoice/model.py"),  # comment out if absent
]


def main():
    """Create the Hugging Face repo (if needed) and upload MODEL_DIR to it.

    Requires the HUGGINGFACE_HUB_TOKEN environment variable; raises
    RuntimeError when it is missing. Epoch checkpoints (model.pt.ep*) are
    excluded from the upload.
    """
    # 1) Get the token (environment variable recommended):
    #    run `export HUGGINGFACE_HUB_TOKEN=hf_xxx` beforehand
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
    if token is None:
        raise RuntimeError(
            "HUGGINGFACE_HUB_TOKEN ํ™˜๊ฒฝ๋ณ€์ˆ˜๊ฐ€ ์„ค์ •๋˜์–ด ์žˆ์ง€ ์•Š์Šต๋‹ˆ๋‹ค. "
            "https://huggingface.co/settings/tokens ์—์„œ ํ† ํฐ์„ ๋งŒ๋“ค๊ณ ,\n"
            "export HUGGINGFACE_HUB_TOKEN=hf_xxx ๋กœ ์„ค์ •ํ•œ ๋’ค ๋‹ค์‹œ ์‹คํ–‰ํ•˜์„ธ์š”."
        )

    # FIX: removed the unused local `api = HfApi()` — create_repo/upload_folder
    # are called as module-level helpers and the client object was never used.

    # 2) Create the repository (exist_ok=True silently passes if it already exists)
    create_repo(
        repo_id=REPO_ID,
        token=token,
        private=True,   # keep the repo private
        exist_ok=True,
        repo_type="model",
    )

    # 3) Optionally copy extra files (model.py etc.) into MODEL_DIR so the repo
    #    root ends up with README.md, model.pt, config.yaml, configuration.json,
    #    model.py side by side
    for extra in EXTRA_FILES:
        if extra.is_file():
            target = MODEL_DIR / extra.name
            if not target.exists():
                print(f"[info] copy {extra} -> {target}")
                target.write_bytes(extra.read_bytes())
        else:
            print(f"[warn] extra file not found: {extra}")

    # 3-1) Model card (README): copy README_huggingface.md from the current
    #      working directory to MODEL_DIR/README.md — the HF model hub renders
    #      the repo root's README.md as the model card.
    readme_src = Path.cwd() / "README_huggingface.md"
    readme_dst = MODEL_DIR / "README.md"
    if readme_src.is_file():
        print(f"[info] copy {readme_src} -> {readme_dst}")
        readme_dst.write_text(readme_src.read_text(encoding="utf-8"), encoding="utf-8")
    else:
        print(f"[warn] README_huggingface.md not found in CWD: {Path.cwd()}")

    # 4) Upload the whole folder
    print(f"[info] uploading folder: {MODEL_DIR} -> {REPO_ID}")
    upload_folder(
        repo_id=REPO_ID,
        folder_path=str(MODEL_DIR),
        path_in_repo=".",         # upload straight into the repo root
        token=token,
        repo_type="model",
        ignore_patterns=[
            "model.pt.ep*",   # exclude intermediate checkpoints
            "*.pt.ep*",       # also exclude similarly named files
        ],
    )

    print("[done] uploaded to:", f"https://huggingface.co/{REPO_ID}")


if __name__ == "__main__":
    main()

Downloads last month
5
Inference Providers NEW
This model isn't deployed by any Inference Provider. ๐Ÿ™‹ Ask for provider support

Collection including AeiROBOT/SenseVoice-Small-ko