| import argparse |
|
|
|
|
| def parse_args(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--config_path", type=str, |
| default='', |
| help="Path to config file" |
| ) |
| parser.add_argument( |
| "--optimal_transport_method", |
| type=str, |
| default="exact", |
| help="Use optimal transport in CFM training", |
| ) |
| parser.add_argument( |
| "--split_ratios", |
| nargs=2, |
| type=float, |
| default=[0.9, 0.1], |
| help="Split ratios for training/validation data in CFM training", |
| ) |
| parser.add_argument( |
| "--accelerator", type=str, default="cpu", help="Training accelerator" |
| ) |
| parser.add_argument("--date", type=str) |
| parser.add_argument("--seed", default=2, type=int) |
| parser.add_argument("--device", default="cuda:1", type=str) |
| parser.add_argument("--molecule", default="aldp", type=str) |
| parser.add_argument('--wandb', action='store_true', default=False) |
| parser.add_argument('--unseen', action='store_true', default=False) |
| parser.add_argument('--run_name', default=None, type=str) |
| |
| parser.add_argument("--save_dir", default="", type=str) |
| parser.add_argument("--root_dir", default="", type=str) |
| |
| parser.add_argument("--bias", default="force", type=str) |
| |
| parser.add_argument("--start_state", default="c5", type=str) |
| parser.add_argument("--end_state", default="c7ax", type=str) |
| parser.add_argument("--num_steps", default=100, type=int) |
| |
| parser.add_argument("--sigma", default=0.1, type=float) |
| parser.add_argument("--num_samples", default=16, type=int) |
| parser.add_argument("--temperature", default=300, type=float) |
| parser.add_argument("--friction", default=2.0, type=float) |
| parser.add_argument("--rbf", action='store_true', default=False) |
| parser.add_argument("--use_delta_to_target", action='store_true', default=False) |
| parser.add_argument("--use_gnn", action='store_true', default=False) |
| |
| parser.add_argument("--start_temperature", default=600, type=float) |
| parser.add_argument("--end_temperature", default=300, type=float) |
| parser.add_argument("--num_rollouts", default=1000, type=int) |
| parser.add_argument("--trains_per_rollout", default=1000, type=int) |
| parser.add_argument("--log_z_lr", default=1e-3, type=float) |
| parser.add_argument("--policy_lr", default=1e-4, type=float) |
| parser.add_argument("--batch_size", default=64, type=int) |
| parser.add_argument("--buffer_size", default=1000, type=int) |
| parser.add_argument("--max_grad_norm", default=1, type=int) |
| parser.add_argument("--control_variate", default="global", type=str) |
| parser.add_argument("--self_normalize", action='store_true', default=False) |
| |
| parser.add_argument("--objective", default="ce", type=str) |
| parser.add_argument("--vel_conditioned", action='store_true', default=False) |
| parser.add_argument("--dir_only", action='store_true', default=False) |
| |
| |
| parser.add_argument("--num_particles", default=16, type=int) |
| |
| parser.add_argument("--kT", type=float, default=0.0) |
| |
| parser = datasets_parser(parser) |
|
|
| |
| parser = metric_parser(parser) |
|
|
| return parser.parse_args() |
|
|
|
|
| def datasets_parser(parser): |
| parser.add_argument("--dim", type=int, default=50, help="Dimension of data") |
|
|
| parser.add_argument( |
| "--data_type", |
| type=str, |
| default="tahoe", |
| help="Type of data, now wither scrna or one of toys", |
| ) |
| parser.add_argument( |
| "--data_name", |
| type=str, |
| default="tahoe", |
| help="Path to the dataset", |
| ) |
| return parser |
|
|
|
|
| def metric_parser(parser): |
| parser.add_argument( |
| "--n_centers", |
| type=int, |
| default=300, |
| help="Number of centers for RBF network", |
| ) |
| parser.add_argument( |
| "--kappa", |
| type=float, |
| default=1.5, |
| help="Kappa parameter for RBF network", |
| ) |
| parser.add_argument( |
| "--rho", |
| type=float, |
| default=-2.75, |
| help="Rho parameter in Riemanian Velocity Calculation", |
| ) |
| parser.add_argument( |
| "--velocity_metric", |
| type=str, |
| default="rbf", |
| help="Metric for velocity calculation", |
| ) |
| parser.add_argument( |
| "--gamma", |
| nargs="+", |
| type=float, |
| default=0.2, |
| help="Gamma parameter in Riemanian Velocity Calculation", |
| ) |
| parser.add_argument( |
| "--metric_epochs", |
| type=int, |
| default=200, |
| help="Number of epochs for metric learning", |
| ) |
| parser.add_argument( |
| "--metric_patience", |
| type=int, |
| default=25, |
| help="Patience for metric learning", |
| ) |
| parser.add_argument( |
| "--metric_lr", |
| type=float, |
| default=1e-2, |
| help="Learning rate for metric learning", |
| ) |
| parser.add_argument( |
| "--alpha_metric", |
| type=float, |
| default=1.0, |
| help="Alpha parameter for metric learning", |
| ) |
| return parser |
|
|
|
|