|
|
import os |
|
|
import re |
|
|
import json |
|
|
import time |
|
|
import shutil |
|
|
from pathlib import Path |
|
|
from tempfile import NamedTemporaryFile |
|
|
from typing import Optional, Dict, Any, List |
|
|
|
|
|
import torch as t |
|
|
from torch import nn |
|
|
|
|
|
from ..models.dit_dforce import get_model |
|
|
from ..config import Config |
|
|
|
|
|
import yaml |
|
|
|
|
|
|
|
|
def load_model_from_config(config_path: str, checkpoint_path: str = None, strict: bool = True) -> nn.Module: |
|
|
print(f"loading {config_path}") |
|
|
cmodel = Config.from_yaml(config_path).model |
|
|
model = get_model(cmodel.height, cmodel.width, |
|
|
n_window=cmodel.n_window, |
|
|
patch_size=cmodel.patch_size, |
|
|
n_heads=cmodel.n_heads,d_model=cmodel.d_model, |
|
|
n_blocks=cmodel.n_blocks, |
|
|
T=cmodel.T, |
|
|
in_channels=cmodel.in_channels, |
|
|
bidirectional=cmodel.bidirectional) |
|
|
if checkpoint_path is None and cmodel.checkpoint is not None: |
|
|
checkpoint_path = cmodel.checkpoint |
|
|
if checkpoint_path is not None: |
|
|
state_dict = t.load(checkpoint_path, weights_only=False) |
|
|
if "model" in state_dict: |
|
|
state_dict = state_dict["model"] |
|
|
if "_orig_mod." in list(state_dict.keys())[0]: |
|
|
state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items() if k.startswith("_orig_mod.")} |
|
|
model.load_state_dict(state_dict, strict=strict) |
|
|
print('loaded state dict') |
|
|
return model |
|
|
|
|
|
|
|
|
|
|
|
class CheckpointManager: |
|
|
""" |
|
|
Manage top-K checkpoints by a metric. On each save: |
|
|
- Write a new checkpoint atomically |
|
|
- Keep only the top-K files by metric (max or min) |
|
|
- Delete files not in top-K |
|
|
- Maintain a small JSON index for quick reloads |
|
|
Also scans the directory on init to reconstruct state. |
|
|
|
|
|
Filenames are of the form: ckpt-step=<step>-metric=<metric>.pt |
|
|
""" |
|
|
|
|
|
CKPT_PATTERN = re.compile( |
|
|
r"^ckpt-step=(?P<step>\d+)-metric=(?P<metric>[+-]?\d+(?:\.\d+)?(?:e[+-]?\d+)?)\.pt$" |
|
|
) |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
dirpath: str | Path, |
|
|
k: int = 5, |
|
|
mode: str = "max", |
|
|
metric_name: str = "score", |
|
|
is_main_process: bool = True, |
|
|
index_filename: str = "ckpt_index.json", |
|
|
): |
|
|
self.dir = Path(dirpath) |
|
|
self.dir.mkdir(parents=True, exist_ok=True) |
|
|
assert mode in {"max", "min"} |
|
|
self.k = int(k) |
|
|
self.mode = mode |
|
|
self.metric_name = metric_name |
|
|
self.is_main = bool(is_main_process) |
|
|
self.index_path = self.dir / index_filename |
|
|
|
|
|
|
|
|
self.entries: List[Dict[str, Any]] = [] |
|
|
|
|
|
self._load_index() |
|
|
self._scan_and_merge() |
|
|
self._prune_and_persist() |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
def best(self) -> Optional[Dict[str, Any]]: |
|
|
return self.entries[0] if self.entries else None |
|
|
|
|
|
@property |
|
|
def paths(self) -> List[str]: |
|
|
return [e["path"] for e in self.entries] |
|
|
|
|
|
@property |
|
|
def should_save(self) -> bool: |
|
|
"""Use inside DDP loops to gate saving to rank-0 only.""" |
|
|
return self.is_main |
|
|
|
|
|
def save( |
|
|
self, |
|
|
*, |
|
|
metric: float, |
|
|
step: int, |
|
|
model: Optional[nn.Module] = None, |
|
|
optimizer: Optional[t.optim.Optimizer] = None, |
|
|
scheduler: Optional[Any] = None, |
|
|
extra: Optional[Dict[str, Any]] = None, |
|
|
state_dict: Optional[Dict[str, Any]] = None, |
|
|
) -> Dict[str, Any]: |
|
|
""" |
|
|
Save a checkpoint and keep only top-K by metric. |
|
|
|
|
|
Provide either `state_dict` or a `model` (optionally optimizer/scheduler). |
|
|
The saved file always contains: |
|
|
- 'model', 'optimizer', 'scheduler' (if provided) |
|
|
- 'step', metric_name, 'timestamp', 'manager' |
|
|
Returns info about the saved file and whether it made the top-K. |
|
|
""" |
|
|
if not self.should_save: |
|
|
return {"saved": False, "kept": False, "reason": "not main process"} |
|
|
|
|
|
if state_dict is None: |
|
|
state_dict = {} |
|
|
if model is not None: |
|
|
state_dict["model"] = model.state_dict() |
|
|
if optimizer is not None: |
|
|
state_dict["optimizer"] = optimizer.state_dict() |
|
|
if scheduler is not None: |
|
|
|
|
|
try: |
|
|
state_dict["scheduler"] = scheduler.state_dict() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
ts = time.time() |
|
|
filename = f"ckpt-step={int(step):06d}-metric={float(metric):.8f}.pt" |
|
|
fpath = self.dir / filename |
|
|
|
|
|
|
|
|
payload = { |
|
|
**state_dict, |
|
|
"step": int(step), |
|
|
self.metric_name: float(metric), |
|
|
"timestamp": ts, |
|
|
"manager": { |
|
|
"mode": self.mode, |
|
|
"k": self.k, |
|
|
"metric_name": self.metric_name, |
|
|
"filename": filename, |
|
|
}, |
|
|
} |
|
|
|
|
|
|
|
|
with NamedTemporaryFile(dir=self.dir, delete=False) as tmp: |
|
|
tmp_path = Path(tmp.name) |
|
|
try: |
|
|
t.save(payload, tmp_path) |
|
|
os.replace(tmp_path, fpath) |
|
|
finally: |
|
|
if tmp_path.exists(): |
|
|
try: |
|
|
tmp_path.unlink() |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
new_entry = { |
|
|
"path": str(fpath), |
|
|
"step": int(step), |
|
|
"metric": float(metric), |
|
|
"ts": ts, |
|
|
} |
|
|
self.entries.append(new_entry) |
|
|
kept = self._prune_and_persist() |
|
|
|
|
|
return {"saved": True, "kept": kept, "path": str(fpath), "best": self.best} |
|
|
|
|
|
|
|
|
|
|
|
def _sort_key(self, e: Dict[str, Any]): |
|
|
|
|
|
|
|
|
return ((-e["metric"], e["step"]) if self.mode == "max" else (e["metric"], e["step"])) |
|
|
|
|
|
def _load_index(self): |
|
|
if not self.index_path.exists(): |
|
|
self.entries = [] |
|
|
return |
|
|
try: |
|
|
data = json.loads(self.index_path.read_text()) |
|
|
entries = data.get("entries", []) |
|
|
|
|
|
self.entries = [e for e in entries if Path(e["path"]).exists()] |
|
|
|
|
|
for e in self.entries: |
|
|
e["metric"] = float(e["metric"]) |
|
|
e["step"] = int(e["step"]) |
|
|
e["ts"] = float(e.get("ts", time.time())) |
|
|
except Exception: |
|
|
|
|
|
self.entries = [] |
|
|
|
|
|
def _scan_and_merge(self): |
|
|
"""Scan directory for checkpoint files and merge with current entries.""" |
|
|
seen = {Path(e["path"]).name for e in self.entries} |
|
|
for p in self.dir.glob("ckpt-step=*-metric=*.pt"): |
|
|
name = p.name |
|
|
if name in seen: |
|
|
continue |
|
|
m = self.CKPT_PATTERN.match(name) |
|
|
if not m: |
|
|
continue |
|
|
step = int(m.group("step")) |
|
|
try: |
|
|
metric = float(m.group("metric")) |
|
|
except ValueError: |
|
|
continue |
|
|
self.entries.append( |
|
|
{"path": str(p), "step": step, "metric": metric, "ts": p.stat().st_mtime} |
|
|
) |
|
|
|
|
|
def _prune_and_persist(self) -> bool: |
|
|
"""Sort by metric, keep top-K, delete the rest. Return True if newest file is kept.""" |
|
|
if not self.entries: |
|
|
self._persist_index() |
|
|
return False |
|
|
|
|
|
|
|
|
self.entries.sort(key=self._sort_key) |
|
|
|
|
|
|
|
|
keep = self.entries[: self.k] |
|
|
drop = self.entries[self.k :] |
|
|
|
|
|
keep_paths = {e["path"] for e in keep} |
|
|
newest_path = max(self.entries, key=lambda e: e["ts"])["path"] |
|
|
newest_kept = newest_path in keep_paths |
|
|
|
|
|
|
|
|
for e in drop: |
|
|
try: |
|
|
Path(e["path"]).unlink(missing_ok=True) |
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
|
|
|
self.entries = keep |
|
|
self._persist_index() |
|
|
return newest_kept |
|
|
|
|
|
def _persist_index(self): |
|
|
data = { |
|
|
"k": self.k, |
|
|
"mode": self.mode, |
|
|
"metric_name": self.metric_name, |
|
|
"entries": self.entries, |
|
|
"updated_at": time.time(), |
|
|
} |
|
|
tmp = self.index_path.with_suffix(".json.tmp") |
|
|
tmp.write_text(json.dumps(data, indent=2)) |
|
|
os.replace(tmp, self.index_path) |
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
mgr = CheckpointManager("checkpoints", k=5, mode="max", metric_name="val_acc") |
|
|
|
|
|
model = nn.Linear(10, 2) |
|
|
opt = t.optim.AdamW(model.parameters(), lr=1e-3) |
|
|
|
|
|
|
|
|
for epoch in range(10): |
|
|
metric = 0.5 + 0.1 * t.rand(1).item() |
|
|
info = mgr.save(metric=metric, step=epoch, model=model, optimizer=opt) |
|
|
print( |
|
|
f"epoch {epoch:02d} metric={metric:.4f} saved={info['saved']} kept={info['kept']} " |
|
|
f"best_metric={mgr.best['metric'] if mgr.best else None:.4f}" |
|
|
) |
|
|
|
|
|
print("Top-K paths:", mgr.paths) |
|
|
print("Best:", mgr.best) |
|
|
|