| import sys |
| import os |
| import argparse |
| import copy |
| import time |
| import json |
|
|
| import torch.nn as nn |
| import wandb |
| from pytorch_lightning import Trainer |
| from pytorch_lightning.loggers import WandbLogger |
| from torchcfm.optimal_transport import OTPlanSampler |
|
|
| from parsers import parse_args |
| from train_utils import load_config, merge_config, generate_group_string, dataset_name2datapath, create_callbacks |
| from src.branchsbm import BranchSBM |
| from src.branch_flow_net_train import FlowNetTrainCell, FlowNetTrainLidar |
| from src.branch_flow_net_test import ( |
| FlowNetTestLidar, FlowNetTestMouse, FlowNetTestClonidine, FlowNetTestTrametinib, FlowNetTestVeres |
| ) |
| from src.branch_interpolant_train import BranchInterpolantTrain |
| from src.branch_growth_net_train import GrowthNetTrain, GrowthNetTrainCell, GrowthNetTrainLidar, SequentialGrowthNetTrain |
| from src.networks.flow_mlp import VelocityNet |
| from src.networks.growth_mlp import GrowthNet |
| from src.networks.interpolant_mlp import GeoPathMLP |
| from src.utils import set_seed |
| from src.ema import EMA |
| from src.geo_metrics.metric_factory import DataManifoldMetric |
| from dataloaders.mouse_data import WeightedBranchedCellDataModule, SingleBranchCellDataModule |
| from dataloaders.three_branch_data import ThreeBranchTahoeDataModule |
| from dataloaders.clonidine_v2_data import ClonidineV2DataModule |
| from dataloaders.clonidine_single_branch import ClonidineSingleBranchDataModule |
| from dataloaders.trametinib_single import TrametinibSingleBranchDataModule |
| from dataloaders.lidar_data import WeightedBranchedLidarDataModule |
| from dataloaders.lidar_data_single import LidarSingleDataModule |
| from dataloaders.veres_leiden_data import WeightedBranchedVeresDataModule |
|
|
def _build_datamodule(args: argparse.Namespace):
    """Instantiate the data module selected by ``args.data_name``.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset name.
    """
    datamodule_classes = {
        "lidar": WeightedBranchedLidarDataModule,
        "lidarsingle": LidarSingleDataModule,
        "mouse": WeightedBranchedCellDataModule,
        "mousesingle": SingleBranchCellDataModule,
        "clonidine50D": ClonidineV2DataModule,
        "clonidine100D": ClonidineV2DataModule,
        "clonidine150D": ClonidineV2DataModule,
        "clonidine50Dsingle": ClonidineSingleBranchDataModule,
        "trametinib": ThreeBranchTahoeDataModule,
        "trametinibsingle": TrametinibSingleBranchDataModule,
        "veres": WeightedBranchedVeresDataModule,
    }
    try:
        datamodule_cls = datamodule_classes[args.data_name]
    except KeyError:
        # Previously an unknown name fell through every branch and surfaced
        # later as an UnboundLocalError on ``datamodule``.
        raise ValueError(
            f"Unknown data_name {args.data_name!r}; expected one of "
            f"{sorted(datamodule_classes)}"
        ) from None
    return datamodule_cls(args=args)


def _build_branch_networks(args: argparse.Namespace, num_branches: int):
    """Create the per-branch flow, GeoPath and growth networks.

    Networks are constructed branch by branch in the order flow, GeoPath,
    growth (the same order as before, so seeded initialisation is unchanged).
    Branch 0's growth net is built with ``negative=True``, all other branches
    with ``negative=False``. When ``args.ema_decay`` is not None, every
    network is wrapped in ``EMA``.

    Returns:
        ``(flow_nets, geopath_nets, growth_nets)``, three ``nn.ModuleList``
        objects of length ``num_branches``.
    """
    flow_nets = nn.ModuleList()
    geopath_nets = nn.ModuleList()
    growth_nets = nn.ModuleList()

    for branch_idx in range(num_branches):
        flow_net = VelocityNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_flow,
            activation=args.activation_flow,
            batch_norm=False,
        )
        geopath_net = GeoPathMLP(
            input_dim=args.dim,
            hidden_dims=args.hidden_dims_geopath,
            time_geopath=args.time_geopath,
            activation=args.activation_geopath,
            batch_norm=False,
        )
        growth_net = GrowthNet(
            dim=args.dim,
            hidden_dims=args.hidden_dims_growth,
            activation=args.activation_growth,
            batch_norm=False,
            negative=(branch_idx == 0),
        )

        if args.ema_decay is not None:
            flow_net = EMA(model=flow_net, decay=args.ema_decay)
            geopath_net = EMA(model=geopath_net, decay=args.ema_decay)
            growth_net = EMA(model=growth_net, decay=args.ema_decay)

        flow_nets.append(flow_net)
        geopath_nets.append(geopath_net)
        growth_nets.append(growth_net)

    return flow_nets, geopath_nets, growth_nets


def _make_trainer(args: argparse.Namespace, callbacks, run_name: str) -> Trainer:
    """Build the Trainer used by the flow, growth and joint phases.

    Each call attaches a fresh ``WandbLogger`` with ``version=run_name``.
    """
    return Trainer(
        max_epochs=args.epochs,
        callbacks=callbacks,
        check_val_every_n_epoch=args.check_val_every_n_epoch,
        accelerator=args.accelerator,
        logger=WandbLogger(version=run_name),
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
        num_sanity_val_steps=(0 if args.data_type == "image" else None),
    )


def _as_module_list(nets):
    """Wrap a tuple of networks in ``nn.ModuleList``; return anything else unchanged."""
    return nn.ModuleList(nets) if isinstance(nets, tuple) else nets


def _reload_best_growth_model(model, ckpt_path, args, default_cls, phase, **init_kwargs):
    """Reload the best checkpoint of a growth/joint phase, if one was saved.

    Args:
        model: The model trained in this phase; returned unchanged when
            ``ckpt_path`` is empty.
        ckpt_path: Best checkpoint path reported by the phase's first callback.
        args: Run arguments; ``args.sequential`` selects the class to load into.
        default_cls: Class used when ``args.sequential`` is false.
        phase: Label used in the progress message (``"growth"`` / ``"joint"``).
        **init_kwargs: Constructor arguments forwarded to ``load_from_checkpoint``.
    """
    if not ckpt_path:
        return model
    print(f"Loading best {phase} model from: {ckpt_path}")
    model_cls = SequentialGrowthNetTrain if args.sequential else default_cls
    return model_cls.load_from_checkpoint(ckpt_path, args=args, **init_kwargs)


def _build_test_model(
    args,
    *,
    flow_matcher,
    flow_nets,
    growth_nets,
    ot_sampler,
    skipped_time_points,
    data_manifold_metric,
    joint,
    fallback,
):
    """Pick the dataset-specific test model from ``args.data_name``.

    The name is matched case-insensitively by substring, in the order
    lidar, mouse, clonidine, trametinib, veres. If nothing matches,
    ``fallback`` (the trained model itself) is returned.
    """
    name = args.data_name.lower()
    growth_kwargs = dict(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=joint,
    )
    # The clonidine / trametinib testers take the flow matcher instead of
    # growth nets / manifold metric.
    flow_only_kwargs = dict(
        flow_matcher=flow_matcher,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )

    if "lidar" in name:
        return FlowNetTestLidar(**growth_kwargs)
    if "mouse" in name:
        return FlowNetTestMouse(**growth_kwargs)
    if "clonidine" in name:
        return FlowNetTestClonidine(**flow_only_kwargs)
    if "trametinib" in name:
        return FlowNetTestTrametinib(**flow_only_kwargs)
    if "veres" in name:
        return FlowNetTestVeres(**growth_kwargs)
    return fallback


def main(args: argparse.Namespace, seed: int, t_exclude: int) -> None:
    """Run the four-phase BranchSBM training pipeline for one seed.

    Phases, in order:
      1. GeoPath interpolant training, or loading from
         ``args.load_geopath_model_ckpt`` when that is set.
      2. Per-branch flow (velocity) network training.
      3. Growth network training (``joint=False``), then a test pass.
      4. Joint training (``joint=True``), then a test pass.

    Args:
        args: Parsed, config-merged run arguments. A shallow copy is taken, so
            attribute assignments made here (e.g. ``run_name``) do not leak
            back into the caller's namespace, which is reused across seeds.
        seed: Seed passed to ``set_seed``.
        t_exclude: Time point to skip, or ``None`` to skip none.

    Raises:
        ValueError: if ``args.data_name`` is not a supported dataset.
        RuntimeError: if GeoPath training saved no checkpoint to reload.
    """
    args = copy.copy(args)
    set_seed(seed)

    # NOTE(review): truthiness test, so t_exclude=0 behaves like None (no time
    # point skipped). Kept as-is; confirm 0 is never a real exclusion.
    skipped_time_points = [t_exclude] if t_exclude else []
    print("config path:")
    print(args.config_path)
    print("whiten")
    print(args.whiten)

    current_datetime = time.strftime("%m_%d_%H%M", time.localtime())
    run_name_with_datetime = f"{current_datetime}_{args.run_name}"
    args.run_name = run_name_with_datetime

    datamodule = _build_datamodule(args)
    # The data module, not args.branches, determines the branch count.
    branches = datamodule.num_branches
    print("number of branches:", branches)

    flow_nets, geopath_nets, growth_nets = _build_branch_networks(args, branches)

    ot_sampler = (
        OTPlanSampler(method=args.optimal_transport_method)
        if args.optimal_transport_method != "None"
        else None
    )

    wandb.init(
        project="branchsbm",
        name=run_name_with_datetime,
        config=vars(args),
        dir=args.working_dir,
    )

    # ---- Phase 1: GeoPath interpolant -------------------------------------
    flow_matcher_base = BranchSBM(
        geopath_nets=geopath_nets,
        sigma=args.sigma,
        alpha=int(args.branchsbm),
    )
    geopath_callbacks = create_callbacks(
        args, phase="geopath", data_type=args.data_type, run_id=wandb.run.id
    )
    data_manifold_metric = DataManifoldMetric(
        args=args,
        skipped_time_points=skipped_time_points,
        datamodule=datamodule,
    )
    geopath_model = BranchInterpolantTrain(
        flow_matcher=flow_matcher_base,
        skipped_time_points=skipped_time_points,
        ot_sampler=ot_sampler,
        args=args,
        data_manifold_metric=data_manifold_metric,
    )
    trainer = Trainer(
        max_epochs=args.epochs,
        callbacks=geopath_callbacks,
        accelerator=args.accelerator,
        logger=WandbLogger(version=run_name_with_datetime),
        num_sanity_val_steps=0,
        default_root_dir=args.working_dir,
        gradient_clip_val=(1.0 if args.data_type == "image" else None),
    )

    if args.load_geopath_model_ckpt:
        best_model_path = args.load_geopath_model_ckpt
    else:
        trainer.fit(geopath_model, datamodule=datamodule)
        best_model_path = geopath_callbacks[0].best_model_path
        if not best_model_path:
            raise RuntimeError(
                "GeoPath training saved no checkpoint (best_model_path is empty); "
                "check the geopath checkpoint callback or set load_geopath_model_ckpt."
            )

    geopath_model = BranchInterpolantTrain.load_from_checkpoint(best_model_path)
    flow_matcher_base.geopath_nets = geopath_model.geopath_nets

    # ---- Phase 2: flow networks --------------------------------------------
    flow_callbacks = create_callbacks(
        args,
        phase="flow",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    flow_train_cls = FlowNetTrainLidar if args.data_type == "lidar" else FlowNetTrainCell
    flow_train = flow_train_cls(
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
    )
    trainer = _make_trainer(args, flow_callbacks, run_name_with_datetime)
    trainer.fit(
        flow_train, datamodule=datamodule, ckpt_path=args.resume_flow_model_ckpt
    )
    if args.data_type == "lidar":
        trainer.test(flow_train, datamodule=datamodule)

    flow_nets = flow_train.flow_nets

    # ---- Phase 3: growth networks (flow nets passed in, joint=False) -------
    growth_callbacks = create_callbacks(
        args,
        phase="growth",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    growth_train_cls = GrowthNetTrainLidar if args.data_type == "lidar" else GrowthNetTrainCell
    growth_train = growth_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )
    trainer = _make_trainer(args, growth_callbacks, run_name_with_datetime)
    trainer.fit(growth_train, datamodule=datamodule, ckpt_path=None)

    # NOTE(review): this phase always trains growth_train_cls, yet with
    # args.sequential the best checkpoint is reloaded into
    # SequentialGrowthNetTrain (original behaviour, kept). Confirm that the
    # two classes have compatible state dicts.
    growth_train = _reload_best_growth_model(
        growth_train,
        growth_callbacks[0].best_model_path,
        args,
        growth_train_cls,
        "growth",
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        data_manifold_metric=data_manifold_metric,
        joint=False,
    )
    flow_nets = _as_module_list(growth_train.flow_nets)
    growth_nets = _as_module_list(growth_train.growth_nets)

    test_model = _build_test_model(
        args,
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        data_manifold_metric=data_manifold_metric,
        joint=False,
        fallback=growth_train,
    )
    trainer.test(test_model, datamodule=datamodule)

    # ---- Phase 4: joint flow + growth training (joint=True) ----------------
    joint_callbacks = create_callbacks(
        args,
        phase="joint",
        data_type=args.data_type,
        run_id=wandb.run.id,
        datamodule=datamodule,
    )
    joint_train_cls = SequentialGrowthNetTrain if args.sequential else growth_train_cls
    joint_train = joint_train_cls(
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        args=args,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )
    trainer = _make_trainer(args, joint_callbacks, run_name_with_datetime)
    trainer.fit(joint_train, datamodule=datamodule, ckpt_path=None)

    joint_train = _reload_best_growth_model(
        joint_train,
        joint_callbacks[0].best_model_path,
        args,
        growth_train_cls,
        "joint",
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        data_manifold_metric=data_manifold_metric,
        joint=True,
    )
    flow_nets = _as_module_list(joint_train.flow_nets)
    growth_nets = _as_module_list(joint_train.growth_nets)

    # Previously a stray ``test_model = joint_train`` after the dispatch
    # discarded the dataset-specific joint test model; it is removed here.
    test_model = _build_test_model(
        args,
        flow_matcher=flow_matcher_base,
        flow_nets=flow_nets,
        growth_nets=growth_nets,
        ot_sampler=ot_sampler,
        skipped_time_points=skipped_time_points,
        data_manifold_metric=data_manifold_metric,
        joint=True,
        fallback=joint_train,
    )
    trainer.test(test_model, datamodule=datamodule)

    wandb.finish()
| |
if __name__ == "__main__":
    cli_args = parse_args()
    # Merge the config file into a deep copy so the raw CLI namespace is left
    # untouched.
    run_args = copy.deepcopy(cli_args)
    if cli_args.config_path:
        run_args = merge_config(run_args, load_config(cli_args.config_path))

    run_args.group_name = generate_group_string()
    run_args.data_path = dataset_name2datapath(run_args.data_name, run_args.working_dir)

    for seed in run_args.seeds:
        excluded_times = run_args.t_exclude
        if not excluded_times:
            # No excluded time points: one run per seed, using the first gamma.
            run_args.seed_current = seed
            run_args.gamma_current = run_args.gammas[0]
            main(run_args, seed=seed, t_exclude=None)
            continue

        # One run per excluded time point, paired with the gamma at the same index.
        for idx, held_out in enumerate(excluded_times):
            run_args.t_exclude_current = held_out
            run_args.seed_current = seed
            run_args.gamma_current = run_args.gammas[idx]
            main(run_args, seed=seed, t_exclude=held_out)
|
|