# pong/src/config.py
# From commit 8746765 ("Add Neural Pong application files"), by chrisxx.
from dataclasses import dataclass, field
from typing import List, Optional
import yaml
from omegaconf import OmegaConf
@dataclass
class TransformerConfig:
    """Architecture hyper-parameters for the transformer model.

    Field meanings beyond what the names state are not established in this
    file; confirm against the model code that consumes this config.
    """

    model_id: Optional[str] = None
    width: int = 24  # presumably input frame width — TODO confirm
    height: int = 24  # presumably input frame height — TODO confirm
    T: int = 1000  # NOTE(review): meaning (e.g. number of timesteps) not visible here
    in_channels: int = 3
    n_window: int = 7
    # patch_size / n_heads / d_model used to be declared twice; in a dataclass
    # the later declaration wins, so these are the values that were actually
    # in effect. Field order is unchanged, so positional construction still works.
    patch_size: int = 1
    n_heads: int = 12
    d_model: int = 384
    n_blocks: int = 12
    bidirectional: bool = True
    nocompile: bool = False
    checkpoint: Optional[str] = None
@dataclass
class TrainingConfig:
    """Hyper-parameters for the training loop.

    Meanings noted below are inferred from field names only; confirm them
    against the training script that consumes this config.
    """

    # Optimizer settings. The two learning rates and the beta pair are
    # presumably consumed by an Adam-style optimizer / LR schedule — TODO confirm.
    lr1: float = 2e-3
    lr2: float = 3e-5
    betas: tuple = (0.9, 0.95)
    weight_decay: float = 1e-5

    # Run length and batching.
    max_steps: int = 26_000
    batch_size: int = 32
    noclip: bool = False  # presumably disables gradient clipping — confirm

    # Data settings; presumably clip length and frame rate — TODO confirm units.
    duration: int = 1
    fps: int = 7
    in_channels: int = 3

    debug: bool = False
@dataclass
class WANDBConfig:
    """Experiment-logging settings (presumably for Weights & Biases — confirm).

    ``project`` and ``run_name`` default to ``None``, so they are annotated
    ``Optional[str]`` rather than ``str``.
    """

    name: str = "toy-wm"
    project: Optional[str] = None
    run_name: Optional[str] = None
@dataclass
class Config:
    """Top-level configuration grouping the model, training and logging sections."""

    model: TransformerConfig
    train: TrainingConfig
    wandb: WANDBConfig

    @classmethod
    def from_yaml(cls, path):
        """Load the configuration from a YAML file.

        The YAML document must be a mapping whose top-level keys are exactly
        ``model``, ``train`` and ``wandb``. A missing or extra top-level key
        makes ``cls(**cfg)`` raise ``TypeError``.

        Args:
            path: Path of the YAML file to read.

        Returns:
            The result of ``OmegaConf.structured`` applied to a ``Config``
            instance, i.e. an OmegaConf ``DictConfig``, not a ``Config`` object.
        """
        # YAML files are conventionally UTF-8; don't depend on the platform's
        # locale-dependent default text encoding.
        with open(path, encoding="utf-8") as f:
            raw_cfg = yaml.safe_load(f)
        cfg = OmegaConf.create(raw_cfg)
        # NOTE(review): ``**cfg`` hands each section over as a plain DictConfig,
        # not as a TransformerConfig / TrainingConfig / WANDBConfig instance, so
        # dataclass defaults may not be filled in for keys omitted from the
        # YAML — confirm, or merge the YAML onto OmegaConf.structured schemas.
        return OmegaConf.structured(cls(**cfg))