Upload folder using huggingface_hub
Browse files- classification/config.json +30 -0
- classification/configuration_neuroclr.py +59 -0
- classification/export_classification_to_hf.py +109 -0
- classification/model.safetensors +3 -0
- classification/modeling_neuroclr.py +301 -0
- pretraining/config.json +19 -0
- pretraining/configuration_neuroclr.py +31 -0
- pretraining/export_pretraining_to_hf.py +61 -0
- pretraining/model.safetensors +3 -0
- pretraining/modeling_neuroclr.py +79 -0
- upload_to_hf.py +12 -0
classification/config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TSlength": 128,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NeuroCLRForSequenceClassification"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
|
| 8 |
+
"AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification"
|
| 9 |
+
},
|
| 10 |
+
"base_filters": 256,
|
| 11 |
+
"downsample_gap": 6,
|
| 12 |
+
"freeze_encoder": true,
|
| 13 |
+
"groups": 32,
|
| 14 |
+
"increasefilter_gap": 12,
|
| 15 |
+
"kernel_size": 16,
|
| 16 |
+
"model_type": "neuroclr",
|
| 17 |
+
"n_block": 48,
|
| 18 |
+
"n_rois": 200,
|
| 19 |
+
"nhead": 2,
|
| 20 |
+
"nlayer": 2,
|
| 21 |
+
"normalize_input": true,
|
| 22 |
+
"pooling": "flatten",
|
| 23 |
+
"projector_out1": 128,
|
| 24 |
+
"projector_out2": 64,
|
| 25 |
+
"stride": 2,
|
| 26 |
+
"torch_dtype": "float32",
|
| 27 |
+
"transformers_version": "4.36.2",
|
| 28 |
+
"use_bn": true,
|
| 29 |
+
"use_do": true
|
| 30 |
+
}
|
classification/configuration_neuroclr.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# classification/configuration_neuroclr.py
|
| 2 |
+
from transformers import PretrainedConfig
|
| 3 |
+
|
| 4 |
+
class NeuroCLRConfig(PretrainedConfig):
|
| 5 |
+
model_type = "neuroclr"
|
| 6 |
+
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
# Encoder / SSL
|
| 10 |
+
TSlength: int = 128,
|
| 11 |
+
nhead: int = 4,
|
| 12 |
+
nlayer: int = 4,
|
| 13 |
+
projector_out1: int = 256,
|
| 14 |
+
projector_out2: int = 128,
|
| 15 |
+
pooling: str = "flatten", # input is [B,1,128]
|
| 16 |
+
normalize_input: bool = True,
|
| 17 |
+
|
| 18 |
+
# Classification
|
| 19 |
+
n_rois: int = 200,
|
| 20 |
+
num_labels: int = 2,
|
| 21 |
+
|
| 22 |
+
# ResNet1D head hyperparams
|
| 23 |
+
base_filters: int = 256,
|
| 24 |
+
kernel_size: int = 16,
|
| 25 |
+
stride: int = 2,
|
| 26 |
+
groups: int = 32,
|
| 27 |
+
n_block: int = 48,
|
| 28 |
+
downsample_gap: int = 6,
|
| 29 |
+
increasefilter_gap: int = 12,
|
| 30 |
+
use_bn: bool = True,
|
| 31 |
+
use_do: bool = True,
|
| 32 |
+
|
| 33 |
+
**kwargs
|
| 34 |
+
):
|
| 35 |
+
super().__init__(**kwargs)
|
| 36 |
+
|
| 37 |
+
# Encoder
|
| 38 |
+
self.TSlength = TSlength
|
| 39 |
+
self.nhead = nhead
|
| 40 |
+
self.nlayer = nlayer
|
| 41 |
+
self.projector_out1 = projector_out1
|
| 42 |
+
self.projector_out2 = projector_out2
|
| 43 |
+
self.pooling = pooling
|
| 44 |
+
self.normalize_input = normalize_input
|
| 45 |
+
|
| 46 |
+
# Classification
|
| 47 |
+
self.n_rois = n_rois
|
| 48 |
+
self.num_labels = num_labels
|
| 49 |
+
|
| 50 |
+
# ResNet1D head
|
| 51 |
+
self.base_filters = base_filters
|
| 52 |
+
self.kernel_size = kernel_size
|
| 53 |
+
self.stride = stride
|
| 54 |
+
self.groups = groups
|
| 55 |
+
self.n_block = n_block
|
| 56 |
+
self.downsample_gap = downsample_gap
|
| 57 |
+
self.increasefilter_gap = increasefilter_gap
|
| 58 |
+
self.use_bn = use_bn
|
| 59 |
+
self.use_do = use_do
|
classification/export_classification_to_hf.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from configuration_neuroclr import NeuroCLRConfig
|
| 3 |
+
from modeling_neuroclr import NeuroCLRForSequenceClassification
|
| 4 |
+
|
| 5 |
+
# -------- EDIT THESE PATHS + nhead if needed ----------
|
| 6 |
+
PRETRAIN_CKPT = ""
|
| 7 |
+
HEAD_CKPT = ""
|
| 8 |
+
OUT_DIR = "."
|
| 9 |
+
|
| 10 |
+
CFG = dict(
|
| 11 |
+
# encoder MUST match the pretrained export
|
| 12 |
+
TSlength=128,
|
| 13 |
+
nhead=2, # change if needed
|
| 14 |
+
nlayer=2, # we confirmed this from your pretraining ckpt
|
| 15 |
+
projector_out1=128,
|
| 16 |
+
projector_out2=64,
|
| 17 |
+
pooling="flatten",
|
| 18 |
+
normalize_input=True,
|
| 19 |
+
|
| 20 |
+
# classification
|
| 21 |
+
n_rois=200,
|
| 22 |
+
num_labels=2,
|
| 23 |
+
freeze_encoder=True, # encoder frozen by default
|
| 24 |
+
|
| 25 |
+
# ResNet1D head (your exact settings)
|
| 26 |
+
base_filters=256,
|
| 27 |
+
kernel_size=16,
|
| 28 |
+
stride=2,
|
| 29 |
+
groups=32,
|
| 30 |
+
n_block=48,
|
| 31 |
+
downsample_gap=6,
|
| 32 |
+
increasefilter_gap=12,
|
| 33 |
+
use_bn=True,
|
| 34 |
+
use_do=True,
|
| 35 |
+
)
|
| 36 |
+
# -----------------------------------------------------
|
| 37 |
+
|
| 38 |
+
def load_model_state_dict(path):
|
| 39 |
+
ckpt = torch.load(path, map_location="cpu")
|
| 40 |
+
if isinstance(ckpt, dict):
|
| 41 |
+
if "model_state_dict" in ckpt:
|
| 42 |
+
return ckpt["model_state_dict"]
|
| 43 |
+
if "state_dict" in ckpt:
|
| 44 |
+
return ckpt["state_dict"]
|
| 45 |
+
return ckpt
|
| 46 |
+
return ckpt
|
| 47 |
+
|
| 48 |
+
def remap_encoder(sd):
|
| 49 |
+
# pretraining ckpt keys: transformer_encoder.* and projector.*
|
| 50 |
+
new = {}
|
| 51 |
+
for k, v in sd.items():
|
| 52 |
+
k2 = k.replace("module.", "")
|
| 53 |
+
if k2.startswith("transformer_encoder.") or k2.startswith("projector."):
|
| 54 |
+
new["encoder." + k2] = v
|
| 55 |
+
return new
|
| 56 |
+
|
| 57 |
+
def remap_head(sd):
|
| 58 |
+
# head ckpt keys likely start with first_block_conv.*, basicblock_list.*, dense.* etc.
|
| 59 |
+
new = {}
|
| 60 |
+
for k, v in sd.items():
|
| 61 |
+
k2 = k.replace("module.", "")
|
| 62 |
+
|
| 63 |
+
head_prefixes = (
|
| 64 |
+
"first_block_conv.", "first_block_bn.", "first_block_relu.",
|
| 65 |
+
"basicblock_list.", "final_bn.", "final_relu.", "dense."
|
| 66 |
+
)
|
| 67 |
+
if k2.startswith(head_prefixes):
|
| 68 |
+
new["head." + k2] = v
|
| 69 |
+
|
| 70 |
+
# If your checkpoint already has head.* then keep it
|
| 71 |
+
elif k2.startswith("head."):
|
| 72 |
+
new[k2] = v
|
| 73 |
+
|
| 74 |
+
return new
|
| 75 |
+
|
| 76 |
+
def main():
|
| 77 |
+
config = NeuroCLRConfig(**CFG)
|
| 78 |
+
|
| 79 |
+
# Enables HF auto-classes loading from this folder
|
| 80 |
+
config.auto_map = {
|
| 81 |
+
"AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
|
| 82 |
+
"AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
|
| 83 |
+
}
|
| 84 |
+
|
| 85 |
+
model = NeuroCLRForSequenceClassification(config)
|
| 86 |
+
|
| 87 |
+
# 1) Load encoder weights from pretraining ckpt
|
| 88 |
+
enc_sd_raw = load_model_state_dict(PRETRAIN_CKPT)
|
| 89 |
+
enc_sd = remap_encoder(enc_sd_raw)
|
| 90 |
+
|
| 91 |
+
# 2) Load head weights from classification ckpt
|
| 92 |
+
head_sd_raw = load_model_state_dict(HEAD_CKPT)
|
| 93 |
+
head_sd = remap_head(head_sd_raw)
|
| 94 |
+
|
| 95 |
+
# 3) Merge and load
|
| 96 |
+
merged = {}
|
| 97 |
+
merged.update(enc_sd)
|
| 98 |
+
merged.update(head_sd)
|
| 99 |
+
|
| 100 |
+
missing, unexpected = model.load_state_dict(merged, strict=False)
|
| 101 |
+
print("Missing:", missing)
|
| 102 |
+
print("Unexpected:", unexpected)
|
| 103 |
+
|
| 104 |
+
# Save to HF folder
|
| 105 |
+
model.save_pretrained(OUT_DIR, safe_serialization=True)
|
| 106 |
+
print("Saved HF classification model to:", OUT_DIR)
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
main()
|
classification/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75a30b12cd8b5d195b93b305693b83543dcf8b758d5a0fe5aec8e5e968c777fe
|
| 3 |
+
size 268265544
|
classification/modeling_neuroclr.py
ADDED
|
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 5 |
+
|
| 6 |
+
from transformers import PreTrainedModel
|
| 7 |
+
from configuration_neuroclr import NeuroCLRConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# --------------------------
|
| 11 |
+
# SSL Encoder (per-ROI)
|
| 12 |
+
# --------------------------
|
| 13 |
+
class NeuroCLR(nn.Module):
|
| 14 |
+
def __init__(self, config: NeuroCLRConfig):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
encoder_layer = TransformerEncoderLayer(
|
| 18 |
+
d_model=config.TSlength,
|
| 19 |
+
dim_feedforward=2 * config.TSlength,
|
| 20 |
+
nhead=config.nhead,
|
| 21 |
+
batch_first=True,
|
| 22 |
+
)
|
| 23 |
+
self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
|
| 24 |
+
|
| 25 |
+
self.projector = nn.Sequential(
|
| 26 |
+
nn.Linear(config.TSlength, config.projector_out1),
|
| 27 |
+
nn.BatchNorm1d(config.projector_out1),
|
| 28 |
+
nn.ReLU(),
|
| 29 |
+
nn.Linear(config.projector_out1, config.projector_out2),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
self.normalize_input = config.normalize_input
|
| 33 |
+
self.pooling = config.pooling
|
| 34 |
+
self.TSlength = config.TSlength
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
# x: [B, 1, 128]
|
| 38 |
+
if self.normalize_input:
|
| 39 |
+
x = F.normalize(x, dim=-1)
|
| 40 |
+
|
| 41 |
+
x = self.transformer_encoder(x) # [B, 1, 128]
|
| 42 |
+
|
| 43 |
+
if self.pooling == "flatten":
|
| 44 |
+
h = x.reshape(x.shape[0], -1) # [B, 128]
|
| 45 |
+
elif self.pooling == "mean":
|
| 46 |
+
h = x.mean(dim=1)
|
| 47 |
+
elif self.pooling == "last":
|
| 48 |
+
h = x[:, -1, :]
|
| 49 |
+
else:
|
| 50 |
+
raise ValueError(f"Unknown pooling='{self.pooling}'")
|
| 51 |
+
|
| 52 |
+
if h.shape[1] != self.TSlength:
|
| 53 |
+
raise ValueError(f"h dim {h.shape[1]} != TSlength {self.TSlength}")
|
| 54 |
+
|
| 55 |
+
z = self.projector(h)
|
| 56 |
+
|
| 57 |
+
return h, z
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# --------------------------
|
| 61 |
+
# Your ResNet1D head (verbatim)
|
| 62 |
+
# --------------------------
|
| 63 |
+
class MyConv1dPadSame(nn.Module):
|
| 64 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
|
| 65 |
+
super().__init__()
|
| 66 |
+
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups)
|
| 67 |
+
|
| 68 |
+
self.kernel_size = kernel_size
|
| 69 |
+
self.stride = stride
|
| 70 |
+
|
| 71 |
+
def forward(self, x):
|
| 72 |
+
in_dim = x.shape[-1]
|
| 73 |
+
out_dim = (in_dim + self.stride - 1) // self.stride
|
| 74 |
+
p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
|
| 75 |
+
pad_left = p // 2
|
| 76 |
+
pad_right = p - pad_left
|
| 77 |
+
x = F.pad(x, (pad_left, pad_right), "constant", 0)
|
| 78 |
+
return self.conv(x)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class MyMaxPool1dPadSame(nn.Module):
|
| 82 |
+
def __init__(self, kernel_size):
|
| 83 |
+
super().__init__()
|
| 84 |
+
self.kernel_size = kernel_size
|
| 85 |
+
self.stride = 1
|
| 86 |
+
self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)
|
| 87 |
+
|
| 88 |
+
def forward(self, x):
|
| 89 |
+
in_dim = x.shape[-1]
|
| 90 |
+
out_dim = (in_dim + self.stride - 1) // self.stride
|
| 91 |
+
p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
|
| 92 |
+
pad_left = p // 2
|
| 93 |
+
pad_right = p - pad_left
|
| 94 |
+
x = F.pad(x, (pad_left, pad_right), "constant", 0)
|
| 95 |
+
return self.max_pool(x)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class BasicBlock(nn.Module):
|
| 99 |
+
def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
|
| 100 |
+
super().__init__()
|
| 101 |
+
|
| 102 |
+
self.in_channels = in_channels
|
| 103 |
+
self.out_channels = out_channels
|
| 104 |
+
self.downsample = downsample
|
| 105 |
+
self.use_bn = use_bn
|
| 106 |
+
self.use_do = use_do
|
| 107 |
+
self.is_first_block = is_first_block
|
| 108 |
+
|
| 109 |
+
conv_stride = stride if downsample else 1
|
| 110 |
+
|
| 111 |
+
self.bn1 = nn.BatchNorm1d(in_channels)
|
| 112 |
+
self.relu1 = nn.ReLU()
|
| 113 |
+
self.do1 = nn.Dropout(p=0.75)
|
| 114 |
+
self.conv1 = MyConv1dPadSame(in_channels, out_channels, kernel_size, stride=conv_stride, groups=groups)
|
| 115 |
+
|
| 116 |
+
self.bn2 = nn.BatchNorm1d(out_channels)
|
| 117 |
+
self.relu2 = nn.ReLU()
|
| 118 |
+
self.do2 = nn.Dropout(p=0.75)
|
| 119 |
+
self.conv2 = MyConv1dPadSame(out_channels, out_channels, kernel_size, stride=1, groups=groups)
|
| 120 |
+
|
| 121 |
+
self.max_pool = MyMaxPool1dPadSame(kernel_size=conv_stride)
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
identity = x
|
| 125 |
+
|
| 126 |
+
out = x
|
| 127 |
+
if not self.is_first_block:
|
| 128 |
+
if self.use_bn:
|
| 129 |
+
out = self.bn1(out)
|
| 130 |
+
out = self.relu1(out)
|
| 131 |
+
if self.use_do:
|
| 132 |
+
out = self.do1(out)
|
| 133 |
+
out = self.conv1(out)
|
| 134 |
+
|
| 135 |
+
if self.use_bn:
|
| 136 |
+
out = self.bn2(out)
|
| 137 |
+
out = self.relu2(out)
|
| 138 |
+
if self.use_do:
|
| 139 |
+
out = self.do2(out)
|
| 140 |
+
out = self.conv2(out)
|
| 141 |
+
|
| 142 |
+
if self.downsample:
|
| 143 |
+
identity = self.max_pool(identity)
|
| 144 |
+
|
| 145 |
+
if self.out_channels != self.in_channels:
|
| 146 |
+
identity = identity.transpose(-1, -2)
|
| 147 |
+
ch1 = (self.out_channels - self.in_channels) // 2
|
| 148 |
+
ch2 = self.out_channels - self.in_channels - ch1
|
| 149 |
+
identity = F.pad(identity, (ch1, ch2), "constant", 0)
|
| 150 |
+
identity = identity.transpose(-1, -2)
|
| 151 |
+
|
| 152 |
+
out += identity
|
| 153 |
+
return out
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class ResNet1D(nn.Module):
|
| 157 |
+
def __init__(
|
| 158 |
+
self,
|
| 159 |
+
in_channels,
|
| 160 |
+
base_filters,
|
| 161 |
+
kernel_size,
|
| 162 |
+
stride,
|
| 163 |
+
groups,
|
| 164 |
+
n_block,
|
| 165 |
+
n_classes,
|
| 166 |
+
downsample_gap=2,
|
| 167 |
+
increasefilter_gap=4,
|
| 168 |
+
use_bn=True,
|
| 169 |
+
use_do=True,
|
| 170 |
+
verbose=False
|
| 171 |
+
):
|
| 172 |
+
super().__init__()
|
| 173 |
+
self.verbose = verbose
|
| 174 |
+
self.n_block = n_block
|
| 175 |
+
self.kernel_size = kernel_size
|
| 176 |
+
self.stride = stride
|
| 177 |
+
self.groups = groups
|
| 178 |
+
self.use_bn = use_bn
|
| 179 |
+
self.use_do = use_do
|
| 180 |
+
self.downsample_gap = downsample_gap
|
| 181 |
+
self.increasefilter_gap = increasefilter_gap
|
| 182 |
+
|
| 183 |
+
self.first_block_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size=self.kernel_size, stride=1)
|
| 184 |
+
self.first_block_bn = nn.BatchNorm1d(base_filters)
|
| 185 |
+
self.first_block_relu = nn.ReLU()
|
| 186 |
+
out_channels = base_filters
|
| 187 |
+
|
| 188 |
+
self.basicblock_list = nn.ModuleList()
|
| 189 |
+
for i_block in range(self.n_block):
|
| 190 |
+
is_first_block = (i_block == 0)
|
| 191 |
+
downsample = (i_block % self.downsample_gap == 1)
|
| 192 |
+
|
| 193 |
+
if is_first_block:
|
| 194 |
+
in_ch = base_filters
|
| 195 |
+
out_ch = in_ch
|
| 196 |
+
else:
|
| 197 |
+
in_ch = int(base_filters * 2 ** ((i_block - 1) // self.increasefilter_gap))
|
| 198 |
+
if (i_block % self.increasefilter_gap == 0) and (i_block != 0):
|
| 199 |
+
out_ch = in_ch * 2
|
| 200 |
+
else:
|
| 201 |
+
out_ch = in_ch
|
| 202 |
+
|
| 203 |
+
block = BasicBlock(
|
| 204 |
+
in_channels=in_ch,
|
| 205 |
+
out_channels=out_ch,
|
| 206 |
+
kernel_size=self.kernel_size,
|
| 207 |
+
stride=self.stride,
|
| 208 |
+
groups=self.groups,
|
| 209 |
+
downsample=downsample,
|
| 210 |
+
use_bn=self.use_bn,
|
| 211 |
+
use_do=self.use_do,
|
| 212 |
+
is_first_block=is_first_block,
|
| 213 |
+
)
|
| 214 |
+
self.basicblock_list.append(block)
|
| 215 |
+
out_channels = out_ch
|
| 216 |
+
|
| 217 |
+
self.final_bn = nn.BatchNorm1d(out_channels)
|
| 218 |
+
self.final_relu = nn.ReLU(inplace=True)
|
| 219 |
+
self.dense = nn.Linear(out_channels, n_classes)
|
| 220 |
+
|
| 221 |
+
def forward(self, x):
|
| 222 |
+
out = self.first_block_conv(x)
|
| 223 |
+
if self.use_bn:
|
| 224 |
+
out = self.first_block_bn(out)
|
| 225 |
+
out = self.first_block_relu(out)
|
| 226 |
+
|
| 227 |
+
for block in self.basicblock_list:
|
| 228 |
+
out = block(out)
|
| 229 |
+
|
| 230 |
+
if self.use_bn:
|
| 231 |
+
out = self.final_bn(out)
|
| 232 |
+
out = self.final_relu(out)
|
| 233 |
+
out = out.mean(-1)
|
| 234 |
+
out = self.dense(out)
|
| 235 |
+
return out
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# --------------------------
|
| 239 |
+
# HF model: encoder + ResNet1D head
|
| 240 |
+
# --------------------------
|
| 241 |
+
class NeuroCLRForSequenceClassification(PreTrainedModel):
|
| 242 |
+
"""
|
| 243 |
+
Expected input x: [B, 200, 128]
|
| 244 |
+
- runs encoder per ROI: [B,1,128] -> h_r [B,128]
|
| 245 |
+
- stacks into H: [B,200,128]
|
| 246 |
+
- feeds ResNet1D: [B,200,128] -> logits
|
| 247 |
+
"""
|
| 248 |
+
config_class = NeuroCLRConfig
|
| 249 |
+
base_model_prefix = "neuroclr"
|
| 250 |
+
|
| 251 |
+
def __init__(self, config: NeuroCLRConfig):
|
| 252 |
+
super().__init__(config)
|
| 253 |
+
|
| 254 |
+
self.encoder = NeuroCLR(config)
|
| 255 |
+
|
| 256 |
+
# Freeze the encoder
|
| 257 |
+
for p in self.encoder.parameters():
|
| 258 |
+
p.requires_grad = False
|
| 259 |
+
|
| 260 |
+
self.head = ResNet1D(
|
| 261 |
+
in_channels=config.n_rois,
|
| 262 |
+
base_filters=config.base_filters,
|
| 263 |
+
kernel_size=config.kernel_size,
|
| 264 |
+
stride=config.stride,
|
| 265 |
+
groups=config.groups,
|
| 266 |
+
n_block=config.n_block,
|
| 267 |
+
n_classes=config.num_labels,
|
| 268 |
+
downsample_gap=config.downsample_gap,
|
| 269 |
+
increasefilter_gap=config.increasefilter_gap,
|
| 270 |
+
use_bn=config.use_bn,
|
| 271 |
+
use_do=config.use_do,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
self.post_init()
|
| 275 |
+
|
| 276 |
+
def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
|
| 277 |
+
# x: [B, 200, 128]
|
| 278 |
+
if x.ndim != 3 or x.shape[1] != self.config.n_rois or x.shape[2] != self.config.TSlength:
|
| 279 |
+
raise ValueError(
|
| 280 |
+
f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
|
| 281 |
+
)
|
| 282 |
+
|
| 283 |
+
B, R, L = x.shape
|
| 284 |
+
|
| 285 |
+
# Encode each ROI independently (ROI-wise SSL)
|
| 286 |
+
hs = []
|
| 287 |
+
for r in range(R):
|
| 288 |
+
xr = x[:, r, :].unsqueeze(1) # [B,1,128]
|
| 289 |
+
with torch.no_grad():
|
| 290 |
+
h, _ = self.encoder(xr)
|
| 291 |
+
# h, _ = self.encoder(xr) # h: [B,128]
|
| 292 |
+
hs.append(h.unsqueeze(1)) # [B,1,128]
|
| 293 |
+
|
| 294 |
+
H = torch.cat(hs, dim=1) # [B,200,128]
|
| 295 |
+
|
| 296 |
+
logits = self.head(H) # head expects [B,200,128]
|
| 297 |
+
loss = None
|
| 298 |
+
if labels is not None:
|
| 299 |
+
loss = nn.CrossEntropyLoss()(logits, labels)
|
| 300 |
+
|
| 301 |
+
return {"loss": loss, "logits": logits}
|
pretraining/config.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"TSlength": 128,
|
| 3 |
+
"architectures": [
|
| 4 |
+
"NeuroCLRModel"
|
| 5 |
+
],
|
| 6 |
+
"auto_map": {
|
| 7 |
+
"AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
|
| 8 |
+
"AutoModel": "modeling_neuroclr.NeuroCLRModel"
|
| 9 |
+
},
|
| 10 |
+
"model_type": "neuroclr",
|
| 11 |
+
"nhead": 2,
|
| 12 |
+
"nlayer": 2,
|
| 13 |
+
"normalize_input": true,
|
| 14 |
+
"pooling": "flatten",
|
| 15 |
+
"projector_out1": 128,
|
| 16 |
+
"projector_out2": 64,
|
| 17 |
+
"torch_dtype": "float32",
|
| 18 |
+
"transformers_version": "4.36.2"
|
| 19 |
+
}
|
pretraining/configuration_neuroclr.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import PretrainedConfig
|
| 2 |
+
|
| 3 |
+
class NeuroCLRConfig(PretrainedConfig):
|
| 4 |
+
model_type = "neuroclr"
|
| 5 |
+
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
TSlength: int = 128,
|
| 9 |
+
nhead: int = 2,
|
| 10 |
+
nlayer: int = 2,
|
| 11 |
+
projector_out1: int = 128,
|
| 12 |
+
projector_out2: int = 64,
|
| 13 |
+
|
| 14 |
+
# classification
|
| 15 |
+
num_labels: int = 2,
|
| 16 |
+
|
| 17 |
+
# pooling to avoid flatten dimension mismatch
|
| 18 |
+
pooling: str = "flatten", # "mean" recommended; "flatten" only if seq_len==1
|
| 19 |
+
|
| 20 |
+
normalize_input: bool = True,
|
| 21 |
+
**kwargs
|
| 22 |
+
):
|
| 23 |
+
super().__init__(**kwargs)
|
| 24 |
+
self.TSlength = TSlength
|
| 25 |
+
self.nhead = nhead
|
| 26 |
+
self.nlayer = nlayer
|
| 27 |
+
self.projector_out1 = projector_out1
|
| 28 |
+
self.projector_out2 = projector_out2
|
| 29 |
+
self.num_labels = num_labels
|
| 30 |
+
self.pooling = pooling
|
| 31 |
+
self.normalize_input = normalize_input
|
pretraining/export_pretraining_to_hf.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from configuration_neuroclr import NeuroCLRConfig
|
| 3 |
+
from modeling_neuroclr import NeuroCLRModel
|
| 4 |
+
|
| 5 |
+
# ---- EDIT these to match your training ----
|
| 6 |
+
CFG = dict(
|
| 7 |
+
TSlength=128,
|
| 8 |
+
nhead=2,
|
| 9 |
+
nlayer=2,
|
| 10 |
+
projector_out1=128,
|
| 11 |
+
projector_out2=64,
|
| 12 |
+
pooling="flatten", # because input is [B,1,128]
|
| 13 |
+
normalize_input=True,
|
| 14 |
+
)
|
| 15 |
+
CKPT_PATH = ""
|
| 16 |
+
OUT_DIR = "." # saves into pretraining/ folder
|
| 17 |
+
# ------------------------------------------
|
| 18 |
+
|
| 19 |
+
def remap_state_dict(sd):
|
| 20 |
+
new_sd = {}
|
| 21 |
+
for k, v in sd.items():
|
| 22 |
+
k2 = k.replace("module.", "") # if DDP ever used
|
| 23 |
+
if k2.startswith("transformer_encoder.") or k2.startswith("projector."):
|
| 24 |
+
new_sd["neuroclr." + k2] = v
|
| 25 |
+
else:
|
| 26 |
+
# keep anything else as-is (usually none)
|
| 27 |
+
new_sd[k2] = v
|
| 28 |
+
return new_sd
|
| 29 |
+
|
| 30 |
+
def main():
|
| 31 |
+
config = NeuroCLRConfig(**CFG)
|
| 32 |
+
|
| 33 |
+
# This enables AutoModel loading from this folder
|
| 34 |
+
config.auto_map = {
|
| 35 |
+
"AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
|
| 36 |
+
"AutoModel": "modeling_neuroclr.NeuroCLRModel",
|
| 37 |
+
}
|
| 38 |
+
|
| 39 |
+
model = NeuroCLRModel(config)
|
| 40 |
+
|
| 41 |
+
ckpt = torch.load(CKPT_PATH, map_location="cpu")
|
| 42 |
+
|
| 43 |
+
# Your checkpoint uses model_state_dict
|
| 44 |
+
if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
|
| 45 |
+
sd = ckpt["model_state_dict"]
|
| 46 |
+
elif isinstance(ckpt, dict) and "state_dict" in ckpt:
|
| 47 |
+
sd = ckpt["state_dict"]
|
| 48 |
+
else:
|
| 49 |
+
sd = ckpt
|
| 50 |
+
|
| 51 |
+
sd = remap_state_dict(sd)
|
| 52 |
+
|
| 53 |
+
missing, unexpected = model.load_state_dict(sd, strict=False)
|
| 54 |
+
print("Missing:", missing)
|
| 55 |
+
print("Unexpected:", unexpected)
|
| 56 |
+
|
| 57 |
+
model.save_pretrained(OUT_DIR, safe_serialization=True)
|
| 58 |
+
print("Saved HF pretraining model to:", OUT_DIR)
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
main()
|
pretraining/model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f2a85ac990c09ae2debb3796dd0161d7c8f7c14213e62fb917c481f35296279
|
| 3 |
+
size 1164680
|
pretraining/modeling_neuroclr.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
| 5 |
+
|
| 6 |
+
from transformers import PreTrainedModel
|
| 7 |
+
from configuration_neuroclr import NeuroCLRConfig
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class NeuroCLR(nn.Module):
|
| 11 |
+
"""
|
| 12 |
+
Transformer expects x: [B, S, TSlength] because d_model = TSlength.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, config: NeuroCLRConfig):
|
| 15 |
+
super().__init__()
|
| 16 |
+
|
| 17 |
+
encoder_layer = TransformerEncoderLayer(
|
| 18 |
+
d_model=config.TSlength,
|
| 19 |
+
dim_feedforward=2 * config.TSlength,
|
| 20 |
+
nhead=config.nhead,
|
| 21 |
+
batch_first=True,
|
| 22 |
+
)
|
| 23 |
+
self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
|
| 24 |
+
|
| 25 |
+
self.projector = nn.Sequential(
|
| 26 |
+
nn.Linear(config.TSlength, config.projector_out1),
|
| 27 |
+
nn.BatchNorm1d(config.projector_out1),
|
| 28 |
+
nn.ReLU(),
|
| 29 |
+
nn.Linear(config.projector_out1, config.projector_out2),
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
self.normalize_input = config.normalize_input
|
| 33 |
+
self.pooling = config.pooling
|
| 34 |
+
self.TSlength = config.TSlength
|
| 35 |
+
|
| 36 |
+
def forward(self, x: torch.Tensor):
|
| 37 |
+
# x: [B, S, TSlength]
|
| 38 |
+
if self.normalize_input:
|
| 39 |
+
x = F.normalize(x, dim=-1)
|
| 40 |
+
|
| 41 |
+
x = self.transformer_encoder(x) # [B, S, TSlength]
|
| 42 |
+
|
| 43 |
+
# Make h shape always [B, TSlength]
|
| 44 |
+
if self.pooling == "mean":
|
| 45 |
+
h = x.mean(dim=1) # [B, TSlength]
|
| 46 |
+
elif self.pooling == "last":
|
| 47 |
+
h = x[:, -1, :] # [B, TSlength]
|
| 48 |
+
elif self.pooling == "flatten":
|
| 49 |
+
# ONLY valid if S == 1
|
| 50 |
+
h = x.reshape(x.shape[0], -1)
|
| 51 |
+
if h.shape[1] != self.TSlength:
|
| 52 |
+
raise ValueError(
|
| 53 |
+
f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
|
| 54 |
+
f"Got h dim {h.shape[1]} vs TSlength {self.TSlength}."
|
| 55 |
+
)
|
| 56 |
+
else:
|
| 57 |
+
raise ValueError(f"Unknown pooling='{self.pooling}'. Use 'mean', 'last', or 'flatten'.")
|
| 58 |
+
|
| 59 |
+
z = self.projector(h)
|
| 60 |
+
|
| 61 |
+
return h, z
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class NeuroCLRModel(PreTrainedModel):
|
| 65 |
+
"""
|
| 66 |
+
Loads with:
|
| 67 |
+
AutoModel.from_pretrained(..., trust_remote_code=True)
|
| 68 |
+
"""
|
| 69 |
+
config_class = NeuroCLRConfig
|
| 70 |
+
base_model_prefix = "neuroclr"
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: NeuroCLRConfig):
|
| 73 |
+
super().__init__(config)
|
| 74 |
+
self.neuroclr = NeuroCLR(config)
|
| 75 |
+
self.post_init()
|
| 76 |
+
|
| 77 |
+
def forward(self, x: torch.Tensor, **kwargs):
|
| 78 |
+
h, z = self.neuroclr(x)
|
| 79 |
+
return {"h": h, "z": z}
|
upload_to_hf.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from huggingface_hub import create_repo, upload_folder
|
| 2 |
+
|
| 3 |
+
REPO_ID = "SaeedLab/NeuroCLR"
|
| 4 |
+
# create_repo(REPO_ID, repo_type="model", exist_ok=True)
|
| 5 |
+
|
| 6 |
+
upload_folder(
|
| 7 |
+
repo_id=REPO_ID,
|
| 8 |
+
repo_type="model",
|
| 9 |
+
folder_path=".", # uploads pretraining/ and classification/
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
print("Uploaded to:", REPO_ID)
|