falmuqhim committed (verified)
Commit c319d57 · 1 parent: fec9d9d

Upload folder using huggingface_hub

classification/config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "TSlength": 128,
+   "architectures": [
+     "NeuroCLRForSequenceClassification"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
+     "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification"
+   },
+   "base_filters": 256,
+   "downsample_gap": 6,
+   "freeze_encoder": true,
+   "groups": 32,
+   "increasefilter_gap": 12,
+   "kernel_size": 16,
+   "model_type": "neuroclr",
+   "n_block": 48,
+   "n_rois": 200,
+   "nhead": 2,
+   "nlayer": 2,
+   "normalize_input": true,
+   "pooling": "flatten",
+   "projector_out1": 128,
+   "projector_out2": 64,
+   "stride": 2,
+   "torch_dtype": "float32",
+   "transformers_version": "4.36.2",
+   "use_bn": true,
+   "use_do": true
+ }
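
The auto_map above is what lets the generic Auto classes resolve the custom code in this folder. A minimal loading sketch (assuming the repo has been cloned locally and transformers is installed):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("classification", trust_remote_code=True)
print(cfg.model_type, cfg.n_rois, cfg.TSlength)  # neuroclr 200 128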
classification/configuration_neuroclr.py ADDED
@@ -0,0 +1,59 @@
+ # classification/configuration_neuroclr.py
+ from transformers import PretrainedConfig
+
+ class NeuroCLRConfig(PretrainedConfig):
+     model_type = "neuroclr"
+
+     def __init__(
+         self,
+         # Encoder / SSL
+         TSlength: int = 128,
+         nhead: int = 4,
+         nlayer: int = 4,
+         projector_out1: int = 256,
+         projector_out2: int = 128,
+         pooling: str = "flatten",  # input is [B, 1, 128]
+         normalize_input: bool = True,
+
+         # Classification
+         n_rois: int = 200,
+         num_labels: int = 2,
+
+         # ResNet1D head hyperparams
+         base_filters: int = 256,
+         kernel_size: int = 16,
+         stride: int = 2,
+         groups: int = 32,
+         n_block: int = 48,
+         downsample_gap: int = 6,
+         increasefilter_gap: int = 12,
+         use_bn: bool = True,
+         use_do: bool = True,
+
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+
+         # Encoder
+         self.TSlength = TSlength
+         self.nhead = nhead
+         self.nlayer = nlayer
+         self.projector_out1 = projector_out1
+         self.projector_out2 = projector_out2
+         self.pooling = pooling
+         self.normalize_input = normalize_input
+
+         # Classification
+         self.n_rois = n_rois
+         self.num_labels = num_labels
+
+         # ResNet1D head
+         self.base_filters = base_filters
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.groups = groups
+         self.n_block = n_block
+         self.downsample_gap = downsample_gap
+         self.increasefilter_gap = increasefilter_gap
+         self.use_bn = use_bn
+         self.use_do = use_do
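
A quick round-trip sketch for this config class (run from inside classification/ so the import resolves; the values passed here match the shipped config.json rather than the class defaults):

from configuration_neuroclr import NeuroCLRConfig

config = NeuroCLRConfig(nhead=2, nlayer=2, projector_out1=128, projector_out2=64)
config.save_pretrained("/tmp/neuroclr_cfg")  # writes config.json
reloaded = NeuroCLRConfig.from_pretrained("/tmp/neuroclr_cfg")
assert reloaded.n_block == 48 and reloaded.base_filters == 256  # head defaults survive the round trip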
classification/export_classification_to_hf.py ADDED
@@ -0,0 +1,109 @@
+ import torch
+ from configuration_neuroclr import NeuroCLRConfig
+ from modeling_neuroclr import NeuroCLRForSequenceClassification
+
+ # -------- EDIT THESE PATHS (and nhead, if needed) ----------
+ PRETRAIN_CKPT = ""
+ HEAD_CKPT = ""
+ OUT_DIR = "."
+
+ CFG = dict(
+     # encoder settings MUST match the pretrained export
+     TSlength=128,
+     nhead=2,   # change if needed
+     nlayer=2,  # confirmed against the pretraining checkpoint
+     projector_out1=128,
+     projector_out2=64,
+     pooling="flatten",
+     normalize_input=True,
+
+     # classification
+     n_rois=200,
+     num_labels=2,
+     freeze_encoder=True,  # encoder frozen by default
+
+     # ResNet1D head hyperparameters (exactly as trained)
+     base_filters=256,
+     kernel_size=16,
+     stride=2,
+     groups=32,
+     n_block=48,
+     downsample_gap=6,
+     increasefilter_gap=12,
+     use_bn=True,
+     use_do=True,
+ )
+ # -----------------------------------------------------
+
+ def load_model_state_dict(path):
+     ckpt = torch.load(path, map_location="cpu")
+     if isinstance(ckpt, dict):
+         if "model_state_dict" in ckpt:
+             return ckpt["model_state_dict"]
+         if "state_dict" in ckpt:
+             return ckpt["state_dict"]
+         return ckpt
+     return ckpt
+
+ def remap_encoder(sd):
+     # pretraining ckpt keys: transformer_encoder.* and projector.*
+     new = {}
+     for k, v in sd.items():
+         k2 = k.replace("module.", "")
+         if k2.startswith("transformer_encoder.") or k2.startswith("projector."):
+             new["encoder." + k2] = v
+     return new
+
+ def remap_head(sd):
+     # head ckpt keys typically start with first_block_conv.*, basicblock_list.*, dense.*, etc.
+     new = {}
+     for k, v in sd.items():
+         k2 = k.replace("module.", "")
+
+         head_prefixes = (
+             "first_block_conv.", "first_block_bn.", "first_block_relu.",
+             "basicblock_list.", "final_bn.", "final_relu.", "dense."
+         )
+         if k2.startswith(head_prefixes):
+             new["head." + k2] = v
+
+         # if the checkpoint already uses head.*, keep it as-is
+         elif k2.startswith("head."):
+             new[k2] = v
+
+     return new
+
+ def main():
+     config = NeuroCLRConfig(**CFG)
+
+     # Enables HF auto-classes loading from this folder
+     config.auto_map = {
+         "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
+         "AutoModelForSequenceClassification": "modeling_neuroclr.NeuroCLRForSequenceClassification",
+     }
+
+     model = NeuroCLRForSequenceClassification(config)
+
+     # 1) Load encoder weights from the pretraining ckpt
+     enc_sd_raw = load_model_state_dict(PRETRAIN_CKPT)
+     enc_sd = remap_encoder(enc_sd_raw)
+
+     # 2) Load head weights from the classification ckpt
+     head_sd_raw = load_model_state_dict(HEAD_CKPT)
+     head_sd = remap_head(head_sd_raw)
+
+     # 3) Merge and load
+     merged = {}
+     merged.update(enc_sd)
+     merged.update(head_sd)
+
+     missing, unexpected = model.load_state_dict(merged, strict=False)
+     print("Missing:", missing)
+     print("Unexpected:", unexpected)
+
+     # Save to HF folder
+     model.save_pretrained(OUT_DIR, safe_serialization=True)
+     print("Saved HF classification model to:", OUT_DIR)
+
+ if __name__ == "__main__":
+     main()
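
After filling in PRETRAIN_CKPT and HEAD_CKPT and running this script from inside classification/, the exported folder should load through the auto classes. A verification sketch (the ~67M figure is inferred from the 268 MB float32 model.safetensors, not measured here):

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained(".", trust_remote_code=True)
print(sum(p.numel() for p in model.parameters()))  # roughly 67M parameters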
classification/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:75a30b12cd8b5d195b93b305693b83543dcf8b758d5a0fe5aec8e5e968c777fe
+ size 268265544
classification/modeling_neuroclr.py ADDED
@@ -0,0 +1,301 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+ from transformers import PreTrainedModel
+ from configuration_neuroclr import NeuroCLRConfig
+
+
+ # --------------------------
+ # SSL Encoder (per-ROI)
+ # --------------------------
+ class NeuroCLR(nn.Module):
+     def __init__(self, config: NeuroCLRConfig):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model=config.TSlength,
+             dim_feedforward=2 * config.TSlength,
+             nhead=config.nhead,
+             batch_first=True,
+         )
+         self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
+
+         self.projector = nn.Sequential(
+             nn.Linear(config.TSlength, config.projector_out1),
+             nn.BatchNorm1d(config.projector_out1),
+             nn.ReLU(),
+             nn.Linear(config.projector_out1, config.projector_out2),
+         )
+
+         self.normalize_input = config.normalize_input
+         self.pooling = config.pooling
+         self.TSlength = config.TSlength
+
+     def forward(self, x):
+         # x: [B, 1, 128]
+         if self.normalize_input:
+             x = F.normalize(x, dim=-1)
+
+         x = self.transformer_encoder(x)  # [B, 1, 128]
+
+         if self.pooling == "flatten":
+             h = x.reshape(x.shape[0], -1)  # [B, 128]
+         elif self.pooling == "mean":
+             h = x.mean(dim=1)
+         elif self.pooling == "last":
+             h = x[:, -1, :]
+         else:
+             raise ValueError(f"Unknown pooling='{self.pooling}'")
+
+         if h.shape[1] != self.TSlength:
+             raise ValueError(f"h dim {h.shape[1]} != TSlength {self.TSlength}")
+
+         z = self.projector(h)
+
+         return h, z
+
+
+ # --------------------------
+ # ResNet1D classification head (verbatim from training)
+ # --------------------------
+ class MyConv1dPadSame(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, groups=1):
+         super().__init__()
+         self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, groups=groups)
+
+         self.kernel_size = kernel_size
+         self.stride = stride
+
+     def forward(self, x):
+         in_dim = x.shape[-1]
+         out_dim = (in_dim + self.stride - 1) // self.stride
+         p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
+         pad_left = p // 2
+         pad_right = p - pad_left
+         x = F.pad(x, (pad_left, pad_right), "constant", 0)
+         return self.conv(x)
+
+
+ class MyMaxPool1dPadSame(nn.Module):
+     def __init__(self, kernel_size):
+         super().__init__()
+         self.kernel_size = kernel_size
+         self.stride = 1
+         self.max_pool = nn.MaxPool1d(kernel_size=kernel_size)
+
+     def forward(self, x):
+         in_dim = x.shape[-1]
+         out_dim = (in_dim + self.stride - 1) // self.stride
+         p = max(0, (out_dim - 1) * self.stride + self.kernel_size - in_dim)
+         pad_left = p // 2
+         pad_right = p - pad_left
+         x = F.pad(x, (pad_left, pad_right), "constant", 0)
+         return self.max_pool(x)
+
+
+ class BasicBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size, stride, groups, downsample, use_bn, use_do, is_first_block=False):
+         super().__init__()
+
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.downsample = downsample
+         self.use_bn = use_bn
+         self.use_do = use_do
+         self.is_first_block = is_first_block
+
+         conv_stride = stride if downsample else 1
+
+         self.bn1 = nn.BatchNorm1d(in_channels)
+         self.relu1 = nn.ReLU()
+         self.do1 = nn.Dropout(p=0.75)
+         self.conv1 = MyConv1dPadSame(in_channels, out_channels, kernel_size, stride=conv_stride, groups=groups)
+
+         self.bn2 = nn.BatchNorm1d(out_channels)
+         self.relu2 = nn.ReLU()
+         self.do2 = nn.Dropout(p=0.75)
+         self.conv2 = MyConv1dPadSame(out_channels, out_channels, kernel_size, stride=1, groups=groups)
+
+         self.max_pool = MyMaxPool1dPadSame(kernel_size=conv_stride)
+
+     def forward(self, x):
+         identity = x
+
+         out = x
+         if not self.is_first_block:
+             if self.use_bn:
+                 out = self.bn1(out)
+             out = self.relu1(out)
+             if self.use_do:
+                 out = self.do1(out)
+         out = self.conv1(out)
+
+         if self.use_bn:
+             out = self.bn2(out)
+         out = self.relu2(out)
+         if self.use_do:
+             out = self.do2(out)
+         out = self.conv2(out)
+
+         if self.downsample:
+             identity = self.max_pool(identity)
+
+         if self.out_channels != self.in_channels:
+             identity = identity.transpose(-1, -2)
+             ch1 = (self.out_channels - self.in_channels) // 2
+             ch2 = self.out_channels - self.in_channels - ch1
+             identity = F.pad(identity, (ch1, ch2), "constant", 0)
+             identity = identity.transpose(-1, -2)
+
+         out += identity
+         return out
+
+
+ class ResNet1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         base_filters,
+         kernel_size,
+         stride,
+         groups,
+         n_block,
+         n_classes,
+         downsample_gap=2,
+         increasefilter_gap=4,
+         use_bn=True,
+         use_do=True,
+         verbose=False
+     ):
+         super().__init__()
+         self.verbose = verbose
+         self.n_block = n_block
+         self.kernel_size = kernel_size
+         self.stride = stride
+         self.groups = groups
+         self.use_bn = use_bn
+         self.use_do = use_do
+         self.downsample_gap = downsample_gap
+         self.increasefilter_gap = increasefilter_gap
+
+         self.first_block_conv = MyConv1dPadSame(in_channels, base_filters, kernel_size=self.kernel_size, stride=1)
+         self.first_block_bn = nn.BatchNorm1d(base_filters)
+         self.first_block_relu = nn.ReLU()
+         out_channels = base_filters
+
+         self.basicblock_list = nn.ModuleList()
+         for i_block in range(self.n_block):
+             is_first_block = (i_block == 0)
+             downsample = (i_block % self.downsample_gap == 1)
+
+             if is_first_block:
+                 in_ch = base_filters
+                 out_ch = in_ch
+             else:
+                 in_ch = int(base_filters * 2 ** ((i_block - 1) // self.increasefilter_gap))
+                 if (i_block % self.increasefilter_gap == 0) and (i_block != 0):
+                     out_ch = in_ch * 2
+                 else:
+                     out_ch = in_ch
+
+             block = BasicBlock(
+                 in_channels=in_ch,
+                 out_channels=out_ch,
+                 kernel_size=self.kernel_size,
+                 stride=self.stride,
+                 groups=self.groups,
+                 downsample=downsample,
+                 use_bn=self.use_bn,
+                 use_do=self.use_do,
+                 is_first_block=is_first_block,
+             )
+             self.basicblock_list.append(block)
+             out_channels = out_ch
+
+         self.final_bn = nn.BatchNorm1d(out_channels)
+         self.final_relu = nn.ReLU(inplace=True)
+         self.dense = nn.Linear(out_channels, n_classes)
+
+     def forward(self, x):
+         out = self.first_block_conv(x)
+         if self.use_bn:
+             out = self.first_block_bn(out)
+         out = self.first_block_relu(out)
+
+         for block in self.basicblock_list:
+             out = block(out)
+
+         if self.use_bn:
+             out = self.final_bn(out)
+         out = self.final_relu(out)
+         out = out.mean(-1)
+         out = self.dense(out)
+         return out
+
+
+ # --------------------------
+ # HF model: encoder + ResNet1D head
+ # --------------------------
+ class NeuroCLRForSequenceClassification(PreTrainedModel):
+     """
+     Expected input x: [B, 200, 128]
+     - runs encoder per ROI: [B,1,128] -> h_r [B,128]
+     - stacks into H: [B,200,128]
+     - feeds ResNet1D: [B,200,128] -> logits
+     """
+     config_class = NeuroCLRConfig
+     base_model_prefix = "neuroclr"
+
+     def __init__(self, config: NeuroCLRConfig):
+         super().__init__(config)
+
+         self.encoder = NeuroCLR(config)
+
+         # Freeze the encoder
+         for p in self.encoder.parameters():
+             p.requires_grad = False
+
+         self.head = ResNet1D(
+             in_channels=config.n_rois,
+             base_filters=config.base_filters,
+             kernel_size=config.kernel_size,
+             stride=config.stride,
+             groups=config.groups,
+             n_block=config.n_block,
+             n_classes=config.num_labels,
+             downsample_gap=config.downsample_gap,
+             increasefilter_gap=config.increasefilter_gap,
+             use_bn=config.use_bn,
+             use_do=config.use_do,
+         )
+
+         self.post_init()
+
+     def forward(self, x: torch.Tensor, labels: torch.Tensor = None, **kwargs):
+         # x: [B, 200, 128]
+         if x.ndim != 3 or x.shape[1] != self.config.n_rois or x.shape[2] != self.config.TSlength:
+             raise ValueError(
+                 f"Expected x shape [B,{self.config.n_rois},{self.config.TSlength}] but got {tuple(x.shape)}"
+             )
+
+         B, R, L = x.shape
+
+         # Encode each ROI independently (ROI-wise SSL)
+         hs = []
+         for r in range(R):
+             xr = x[:, r, :].unsqueeze(1)  # [B, 1, 128]
+             with torch.no_grad():
+                 h, _ = self.encoder(xr)  # h: [B, 128]
+             # (this per-ROI loop could be batched as x.reshape(B * R, 1, L) for speed)
+             hs.append(h.unsqueeze(1))  # [B, 1, 128]
+
+         H = torch.cat(hs, dim=1)  # [B, 200, 128]
+
+         logits = self.head(H)  # head expects [B, 200, 128]
+         loss = None
+         if labels is not None:
+             loss = nn.CrossEntropyLoss()(logits, labels)
+
+         return {"loss": loss, "logits": logits}
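
A smoke test with random data (a sketch, run from inside classification/; the constructor arguments mirror the shipped config.json, and eval() matters because the head uses BatchNorm and Dropout(p=0.75)):

import torch
from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRForSequenceClassification

config = NeuroCLRConfig(nhead=2, nlayer=2, projector_out1=128, projector_out2=64)
model = NeuroCLRForSequenceClassification(config).eval()

x = torch.randn(2, config.n_rois, config.TSlength)  # [B, 200, 128]
with torch.no_grad():
    out = model(x, labels=torch.tensor([0, 1]))
print(out["logits"].shape)  # torch.Size([2, 2])
print(out["loss"])          # scalar cross-entropy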
pretraining/config.json ADDED
@@ -0,0 +1,19 @@
+ {
+   "TSlength": 128,
+   "architectures": [
+     "NeuroCLRModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
+     "AutoModel": "modeling_neuroclr.NeuroCLRModel"
+   },
+   "model_type": "neuroclr",
+   "nhead": 2,
+   "nlayer": 2,
+   "normalize_input": true,
+   "pooling": "flatten",
+   "projector_out1": 128,
+   "projector_out2": 64,
+   "torch_dtype": "float32",
+   "transformers_version": "4.36.2"
+ }
pretraining/configuration_neuroclr.py ADDED
@@ -0,0 +1,31 @@
+ from transformers import PretrainedConfig
+
+ class NeuroCLRConfig(PretrainedConfig):
+     model_type = "neuroclr"
+
+     def __init__(
+         self,
+         TSlength: int = 128,
+         nhead: int = 2,
+         nlayer: int = 2,
+         projector_out1: int = 128,
+         projector_out2: int = 64,
+
+         # classification
+         num_labels: int = 2,
+
+         # pooling choice avoids flatten dimension mismatches
+         pooling: str = "flatten",  # "mean" recommended; "flatten" only if seq_len == 1
+
+         normalize_input: bool = True,
+         **kwargs
+     ):
+         super().__init__(**kwargs)
+         self.TSlength = TSlength
+         self.nhead = nhead
+         self.nlayer = nlayer
+         self.projector_out1 = projector_out1
+         self.projector_out2 = projector_out2
+         self.num_labels = num_labels
+         self.pooling = pooling
+         self.normalize_input = normalize_input
pretraining/export_pretraining_to_hf.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ from configuration_neuroclr import NeuroCLRConfig
+ from modeling_neuroclr import NeuroCLRModel
+
+ # ---- EDIT these to match your training ----
+ CFG = dict(
+     TSlength=128,
+     nhead=2,
+     nlayer=2,
+     projector_out1=128,
+     projector_out2=64,
+     pooling="flatten",  # because input is [B, 1, 128]
+     normalize_input=True,
+ )
+ CKPT_PATH = ""
+ OUT_DIR = "."  # saves into the pretraining/ folder
+ # ------------------------------------------
+
+ def remap_state_dict(sd):
+     new_sd = {}
+     for k, v in sd.items():
+         k2 = k.replace("module.", "")  # strip DataParallel/DDP prefix if present
+         if k2.startswith("transformer_encoder.") or k2.startswith("projector."):
+             new_sd["neuroclr." + k2] = v
+         else:
+             # keep anything else as-is (usually nothing)
+             new_sd[k2] = v
+     return new_sd
+
+ def main():
+     config = NeuroCLRConfig(**CFG)
+
+     # This enables AutoModel loading from this folder
+     config.auto_map = {
+         "AutoConfig": "configuration_neuroclr.NeuroCLRConfig",
+         "AutoModel": "modeling_neuroclr.NeuroCLRModel",
+     }
+
+     model = NeuroCLRModel(config)
+
+     ckpt = torch.load(CKPT_PATH, map_location="cpu")
+
+     # training checkpoints store the weights under model_state_dict
+     if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
+         sd = ckpt["model_state_dict"]
+     elif isinstance(ckpt, dict) and "state_dict" in ckpt:
+         sd = ckpt["state_dict"]
+     else:
+         sd = ckpt
+
+     sd = remap_state_dict(sd)
+
+     missing, unexpected = model.load_state_dict(sd, strict=False)
+     print("Missing:", missing)
+     print("Unexpected:", unexpected)
+
+     model.save_pretrained(OUT_DIR, safe_serialization=True)
+     print("Saved HF pretraining model to:", OUT_DIR)
+
+ if __name__ == "__main__":
+     main()
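
Once exported, the folder should load through AutoModel. A sketch, assuming it is run from inside pretraining/ after the export succeeds:

import torch
from transformers import AutoModel

encoder = AutoModel.from_pretrained(".", trust_remote_code=True).eval()
x = torch.randn(8, 1, 128)  # [B, S=1, TSlength]
with torch.no_grad():
    out = encoder(x)
print(out["h"].shape, out["z"].shape)  # torch.Size([8, 128]) torch.Size([8, 64])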
pretraining/model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4f2a85ac990c09ae2debb3796dd0161d7c8f7c14213e62fb917c481f35296279
+ size 1164680
pretraining/modeling_neuroclr.py ADDED
@@ -0,0 +1,79 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import TransformerEncoder, TransformerEncoderLayer
+
+ from transformers import PreTrainedModel
+ from configuration_neuroclr import NeuroCLRConfig
+
+
+ class NeuroCLR(nn.Module):
+     """
+     Transformer expects x: [B, S, TSlength] because d_model = TSlength.
+     """
+     def __init__(self, config: NeuroCLRConfig):
+         super().__init__()
+
+         encoder_layer = TransformerEncoderLayer(
+             d_model=config.TSlength,
+             dim_feedforward=2 * config.TSlength,
+             nhead=config.nhead,
+             batch_first=True,
+         )
+         self.transformer_encoder = TransformerEncoder(encoder_layer, config.nlayer)
+
+         self.projector = nn.Sequential(
+             nn.Linear(config.TSlength, config.projector_out1),
+             nn.BatchNorm1d(config.projector_out1),
+             nn.ReLU(),
+             nn.Linear(config.projector_out1, config.projector_out2),
+         )
+
+         self.normalize_input = config.normalize_input
+         self.pooling = config.pooling
+         self.TSlength = config.TSlength
+
+     def forward(self, x: torch.Tensor):
+         # x: [B, S, TSlength]
+         if self.normalize_input:
+             x = F.normalize(x, dim=-1)
+
+         x = self.transformer_encoder(x)  # [B, S, TSlength]
+
+         # Make h shape always [B, TSlength]
+         if self.pooling == "mean":
+             h = x.mean(dim=1)  # [B, TSlength]
+         elif self.pooling == "last":
+             h = x[:, -1, :]  # [B, TSlength]
+         elif self.pooling == "flatten":
+             # ONLY valid if S == 1
+             h = x.reshape(x.shape[0], -1)
+             if h.shape[1] != self.TSlength:
+                 raise ValueError(
+                     f"pooling='flatten' requires seq_len==1 so h dim == TSlength. "
+                     f"Got h dim {h.shape[1]} vs TSlength {self.TSlength}."
+                 )
+         else:
+             raise ValueError(f"Unknown pooling='{self.pooling}'. Use 'mean', 'last', or 'flatten'.")
+
+         z = self.projector(h)
+
+         return h, z
+
+
+ class NeuroCLRModel(PreTrainedModel):
+     """
+     Loads with:
+         AutoModel.from_pretrained(..., trust_remote_code=True)
+     """
+     config_class = NeuroCLRConfig
+     base_model_prefix = "neuroclr"
+
+     def __init__(self, config: NeuroCLRConfig):
+         super().__init__(config)
+         self.neuroclr = NeuroCLR(config)
+         self.post_init()
+
+     def forward(self, x: torch.Tensor, **kwargs):
+         h, z = self.neuroclr(x)
+         return {"h": h, "z": z}
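
The pooling contract above is the easiest thing to trip over, so here is a small sketch of the two valid cases (direct instantiation with random weights, no checkpoint needed):

import torch
from configuration_neuroclr import NeuroCLRConfig
from modeling_neuroclr import NeuroCLRModel

# 'flatten' is only valid when S == 1
m1 = NeuroCLRModel(NeuroCLRConfig(pooling="flatten")).eval()
h, z = m1.neuroclr(torch.randn(4, 1, 128))   # h: [4, 128], z: [4, 64]

# 'mean' works for any sequence length S
m2 = NeuroCLRModel(NeuroCLRConfig(pooling="mean")).eval()
h, z = m2.neuroclr(torch.randn(4, 10, 128))  # h: [4, 128], z: [4, 64]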
upload_to_hf.py ADDED
@@ -0,0 +1,12 @@
+ from huggingface_hub import create_repo, upload_folder
+
+ REPO_ID = "SaeedLab/NeuroCLR"
+ # create_repo(REPO_ID, repo_type="model", exist_ok=True)  # run once if the repo doesn't exist yet
+
+ upload_folder(
+     repo_id=REPO_ID,
+     repo_type="model",
+     folder_path=".",  # uploads pretraining/ and classification/
+ )
+
+ print("Uploaded to:", REPO_ID)
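
Consumers can then pull either sub-model straight from the Hub. A sketch, assuming a transformers version whose from_pretrained accepts subfolder together with trust_remote_code; if that combination misbehaves, snapshot_download the repo and load each subfolder locally instead:

from transformers import AutoModel, AutoModelForSequenceClassification

encoder = AutoModel.from_pretrained(
    "SaeedLab/NeuroCLR", subfolder="pretraining", trust_remote_code=True
)
classifier = AutoModelForSequenceClassification.from_pretrained(
    "SaeedLab/NeuroCLR", subfolder="classification", trust_remote_code=True
)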