Spaces:
Sleeping
Sleeping
| from pathlib import Path | |
| from typing import Any, Optional, Union | |
| import torch | |
| from safetensors import safe_open | |
| from safetensors.torch import save_file | |
| from style_bert_vits2.logging import logger | |
| def load_safetensors( | |
| checkpoint_path: Union[str, Path], | |
| model: torch.nn.Module, | |
| for_infer: bool = False, | |
| ) -> tuple[torch.nn.Module, Optional[int]]: | |
| """ | |
| 指定されたパスから safetensors モデルを読み込み、モデルとイテレーションを返す。 | |
| Args: | |
| checkpoint_path (Union[str, Path]): モデルのチェックポイントファイルのパス | |
| model (torch.nn.Module): 読み込む対象のモデル | |
| for_infer (bool): 推論用に読み込むかどうかのフラグ | |
| Returns: | |
| tuple[torch.nn.Module, Optional[int]]: 読み込まれたモデルとイテレーション回数(存在する場合) | |
| """ | |
| tensors: dict[str, Any] = {} | |
| iteration: Optional[int] = None | |
| with safe_open(str(checkpoint_path), framework="pt", device="cpu") as f: # type: ignore | |
| for key in f.keys(): | |
| if key == "iteration": | |
| iteration = f.get_tensor(key).item() | |
| tensors[key] = f.get_tensor(key) | |
| if hasattr(model, "module"): | |
| result = model.module.load_state_dict(tensors, strict=False) | |
| else: | |
| result = model.load_state_dict(tensors, strict=False) | |
| for key in result.missing_keys: | |
| if key.startswith("enc_q") and for_infer: | |
| continue | |
| logger.warning(f"Missing key: {key}") | |
| for key in result.unexpected_keys: | |
| if key == "iteration": | |
| continue | |
| logger.warning(f"Unexpected key: {key}") | |
| if iteration is None: | |
| logger.info(f"Loaded '{checkpoint_path}'") | |
| else: | |
| logger.info(f"Loaded '{checkpoint_path}' (iteration {iteration})") | |
| return model, iteration | |
| def save_safetensors( | |
| model: torch.nn.Module, | |
| iteration: int, | |
| checkpoint_path: Union[str, Path], | |
| is_half: bool = False, | |
| for_infer: bool = False, | |
| ) -> None: | |
| """ | |
| モデルを safetensors 形式で保存する。 | |
| Args: | |
| model (torch.nn.Module): 保存するモデル | |
| iteration (int): イテレーション回数 | |
| checkpoint_path (Union[str, Path]): 保存先のパス | |
| is_half (bool): モデルを半精度で保存するかどうかのフラグ | |
| for_infer (bool): 推論用に保存するかどうかのフラグ | |
| """ | |
| if hasattr(model, "module"): | |
| state_dict = model.module.state_dict() | |
| else: | |
| state_dict = model.state_dict() | |
| keys = [] | |
| for k in state_dict: | |
| if "enc_q" in k and for_infer: | |
| continue | |
| keys.append(k) | |
| new_dict = ( | |
| {k: state_dict[k].half() for k in keys} | |
| if is_half | |
| else {k: state_dict[k] for k in keys} | |
| ) | |
| new_dict["iteration"] = torch.LongTensor([iteration]) | |
| logger.info(f"Saved safetensors to {checkpoint_path}") | |
| save_file(new_dict, checkpoint_path) | |