pmukhop committed on
Commit
38cd852
·
0 Parent(s):

Initial walrus commit

Files changed (4)
  1. README.md +134 -0
  2. extended_config.yaml +328 -0
  3. walrus.pt +3 -0
  4. walrus.safetensors +3 -0
README.md ADDED
@@ -0,0 +1,134 @@
+ ---
+ tags:
+ - walrus
+ - foundation-model
+ - physics
+ - continuum-dynamics
+ - transformer
+ - PDE
+ datasets:
+ - polymathic-ai/shear_flow
+ - polymathic-ai/gray_scott_reaction_diffusion
+ - polymathic-ai/active_matter
+ - polymathic-ai/turbulent_radiative_layer_2D
+ - polymathic-ai/supernova_explosion_64
+ - polymathic-ai/turbulence_gravity_cooling
+ - polymathic-ai/rayleigh_benard
+ - polymathic-ai/planetswe
+ - polymathic-ai/acoustic_scattering_inclusions
+ - polymathic-ai/MHD_64
+ - polymathic-ai/rayleigh_taylor_instability
+ - polymathic-ai/acoustic_scattering_discontinuous
+ - polymathic-ai/acoustic_scattering_maze
+ - polymathic-ai/helmholtz_staircase
+ - polymathic-ai/viscoelastic_instability
+ - BGLab/FlowBench
+ license: mit
+ ---
+
+ # Walrus: A Cross-Domain Foundation Model for Continuum Dynamics
+
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+ [![GitHub Repo](https://img.shields.io/badge/GitHub-Repo-blue?logo=github)](https://github.com/PolymathicAI/walrus)
+ [![arXiv](https://img.shields.io/badge/arXiv-2511.15684-b31b1b.svg)](https://arxiv.org/abs/2511.15684)
+
+ Walrus is a large-scale **physics foundation model** capable of modeling a broad range of continuum dynamical systems.
+
+ Walrus is trained jointly on **19 diverse physical datasets** spanning:
+ - astrophysics
+ - geoscience
+ - rheology
+ - plasma physics
+ - acoustics
+ - classical fluids
+
+ These systems cover diverse boundary conditions and physical parameterizations. The model is optimized to serve as a **general-purpose surrogate** for physical simulation and a **strong initialization** for downstream fine-tuning on new PDE systems.
+
+ ---
+
+ # Model Description
+
+ Walrus is a **1.3B-parameter space–time Transformer** trained autoregressively to predict the temporal evolution of physical fields. A simulation snapshot at time t is written as u(t).
+
+ We define the difference between two consecutive snapshots as:
+
+ Δu(t+1) = u(t+1) − u(t)
+
+ Given a short history of τ snapshots:
+
+ U(t) = [u(t − τ + 1), ..., u(t)]
+
+ the model M predicts the next state via:
+
+ u(t+1) ≈ u(t) + M(U(t))
+
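+ A minimal sketch of this update rule in PyTorch (the model call signature and the [batch, time, field, space] tensor layout are illustrative, not the exact Walrus API):
+
+ ```python
+ import torch
+
+ def rollout(model, history, n_steps):
+     """Autoregressive rollout. history: [B, T, C, H, W] with T = tau."""
+     preds = []
+     for _ in range(n_steps):
+         delta = model(history)           # M(U(t)) predicts Δu(t+1)
+         u_next = history[:, -1] + delta  # u(t+1) ≈ u(t) + M(U(t))
+         preds.append(u_next)
+         # slide the window: drop the oldest snapshot, append the prediction
+         history = torch.cat([history[:, 1:], u_next.unsqueeze(1)], dim=1)
+     return torch.stack(preds, dim=1)
+ ```
+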
+ ### Key architectural components
+
+ - **Adaptive-compute patch embedding**
+   - Token count automatically balanced across resolutions
+   - Enables mixing 2D and 3D datasets efficiently
+
+ - **Patch jittering** (see the second sketch after this list)
+   - A harmonic-analysis–motivated augmentation technique
+   - Reduces aliasing and spectral artifacts
+   - Improves long-horizon stability on 17 of the 19 pretraining datasets
+
+ - **Tensor-law–aware data augmentation**
+   - 2D data embedded into 3D through plane rotations
+   - Vector/tensor fields rotated with the correct physical transformations
+
+ - **Asymmetric normalization**
+   - Inputs are normalized by their RMS over space and time
+   - The predicted Δu is de-normalized using the RMS of the input differences Δu (sketched below)
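+
+ A sketch of the asymmetric normalization idea (the tensor layout and `eps` are illustrative; the actual logic lives in `walrus.trainer.normalization_strat`):
+
+ ```python
+ import torch
+
+ def normalize_inputs(u, eps=1e-6):
+     """u: [B, T, C, H, W]. RMS over space-time, per sample and per field."""
+     rms = u.pow(2).mean(dim=(1, 3, 4), keepdim=True).add(eps).sqrt()
+     return u / rms, rms
+
+ def denormalize_delta(delta_pred, history, eps=1e-6):
+     """Scale the predicted delta by the RMS of consecutive differences."""
+     diffs = history[:, 1:] - history[:, :-1]
+     delta_rms = diffs.pow(2).mean(dim=(1, 3, 4), keepdim=True).add(eps).sqrt()
+     return delta_pred * delta_rms.squeeze(1)  # per-sample, per-field scale
+ ```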
+
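+ Patch jittering can be pictured as a random per-axis shift of the patch grid before tokenization. A toy 2D illustration (assuming periodic boundaries, which is what makes `torch.roll` appropriate here; the real augmentation also handles non-periodic padding):
+
+ ```python
+ import torch
+
+ def jitter_patches(u, patch_size=8):
+     """u: [B, C, H, W]. Roll by a uniform per-axis offset in [0, patch_size)
+     so patch boundaries land differently each step, reducing grid-aligned
+     aliasing and spectral artifacts."""
+     off_h = int(torch.randint(0, patch_size, ()))
+     off_w = int(torch.randint(0, patch_size, ()))
+     return torch.roll(u, shifts=(off_h, off_w), dims=(-2, -1))
+ ```
+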
+ ---
+
+ # Pretraining Details
+
+ Walrus is pretrained on 19 physical datasets with:
+
+ - **Loss**: Per-field normalized L1 loss (sketched below)
+ - **Optimizer**: AdamW
+ - **Batching**: System-uniform hierarchical sampling
+ - **Time-striding**: Random stride (1–5) per training example
+ - **Patch jitter range**: Uniform per-axis random offset
+ - **Dimensional unification**: 2D fields embedded as thin 3D volumes
+
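+ A per-field normalized L1 loss might look like the following sketch (the reduction axes and normalization constant are illustrative, not the exact pretraining objective):
+
+ ```python
+ import torch
+
+ def per_field_normalized_l1(pred, target, eps=1e-6):
+     """pred, target: [B, C, H, W]. Normalizing per field (channel) keeps
+     fields with very different magnitudes on a comparable scale."""
+     abs_err = (pred - target).abs().mean(dim=(0, 2, 3))      # per-field MAE
+     scale = target.abs().mean(dim=(0, 2, 3)).clamp_min(eps)  # per-field magnitude
+     return (abs_err / scale).mean()
+ ```
+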
+ The model was pretrained on 96 **NVIDIA H100 GPUs** using HSDP (4 GPUs per shard group), with the sampling strategy matched to the distribution layout to minimize dead weight in each batch.
+
+ ---
+
+ # Intended Use
+
+ This pretrained checkpoint is suitable for:
+
+ ### ✔ Next-step prediction
+ ### ✔ Fast surrogate simulation
+ ### ✔ Autoregressive rollout of physical systems
+ ### ✔ Transfer learning to new physical settings
+
+ # Resources
+
+ - Paper: https://arxiv.org/pdf/2511.15684
+ - GitHub: https://github.com/PolymathicAI/walrus
+ - Tutorial: https://github.com/PolymathicAI/walrus/demo_notebooks
+
+ Note: the training code in the repository is closely coupled with tools from [the Well](https://github.com/PolymathicAI/the_well), so it is usually easiest to format your data to match that schema. If that is not possible, the tutorial shows how to use the model without Well-formatted data.
+
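+ To try the weights in this repository directly, something like the following sketch should work (it assumes `walrus.safetensors` holds a flat state dict and that you have already built the model architecture from the GitHub repo):
+
+ ```python
+ import torch
+ from safetensors.torch import load_file
+
+ # model = ...  # construct the Walrus architecture from the walrus repo first
+ state_dict = load_file("walrus.safetensors")  # or: torch.load("walrus.pt", map_location="cpu")
+ # model.load_state_dict(state_dict)
+ ```
+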
+ # Demonstrated downstream tasks
+
+ We demonstrate the strong performance of Walrus by fine-tuning it on a range of challenging downstream tasks, as shown in the paper.
+ The fine-tuned Walrus checkpoints for these downstream tasks are available at the following paths:
+
+ ### PDEGym CE-RM: https://huggingface.co/polymathic-ai/walrus_ft_CE-RM/tree/main
+ ### PDEBench CNS Turbulent: https://huggingface.co/polymathic-ai/walrus_ft_CNS3D_64_Turb/tree/main
+ ### PDEBench CNS Random: https://huggingface.co/polymathic-ai/walrus_ft_CNS3D_128_Rand/tree/main
+ ### FlowBench FPO Skeleton: https://huggingface.co/polymathic-ai/walrus_ft_flowbench_skelenton/tree/main
+ ### The Well Postmerger Neutron Star: https://huggingface.co/polymathic-ai/walrus_ft_post_neutron_star_merger/tree/main
+ ### The Well Convective Envelope RSG: https://huggingface.co/polymathic-ai/walrus_ft_convective_envelope_rsg/tree/main
+ ### PDEArena Conditioned Incompressible NS: https://huggingface.co/polymathic-ai/walrus_ft_pdearena_ins/tree/main
+ ### BubbleML 2.0 PoolBoil Subcooled: https://huggingface.co/polymathic-ai/walrus_ft_bubbleML_poolboil/tree/main
+
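+ Each of these is a standard Hugging Face model repository, so a checkpoint can be fetched programmatically with `huggingface_hub` (the filename below is hypothetical; check each repo's file listing for the actual name):
+
+ ```python
+ from huggingface_hub import hf_hub_download
+
+ # Download one fine-tuned checkpoint; the exact filename varies by repository.
+ path = hf_hub_download(
+     repo_id="polymathic-ai/walrus_ft_CE-RM",
+     filename="walrus.safetensors",  # hypothetical; see the repo's "Files" tab
+ )
+ ```
+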
+ Additional checkpoints not included in the Walrus collection on HF can be found [here](https://users.flatironinstitute.org/~polymathic/data/walrus_project_checkpoints/), though that endpoint can be a bit finicky.
+
+ More fine-tuning checkpoints will continue to be added to HF over time.
extended_config.yaml ADDED
@@ -0,0 +1,328 @@
+ data_workers: 10
+ name: Walrus-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002
+ automatic_setup: true
+ trainer:
+   _target_: walrus.trainer.Trainer
+   max_epoch: 200
+   val_frequency: 10
+   rollout_val_frequency: 10
+   short_validation_length: 20
+   max_rollout_steps: 200
+   num_time_intervals: 5
+   enable_amp: false
+   loss_fn:
+     _target_: the_well.benchmark.metrics.MAE
+   formatter:
+     _target_: hydra.utils.get_class
+     path: walrus.data.well_to_multi_transformer.ChannelsFirstWithTimeFormatter
+   revin:
+     _target_: walrus.trainer.normalization_strat.SamplewiseRevNormalization
+     _partial_: true
+   prediction_type: delta
+   grad_acc_steps: 4
+   image_validation: true
+   video_validation: true
+   gradient_log_level: 0
+   clip_gradient: 10
+   log_interval: 200
+   loss_multiplier: 100.0
+   lr_scheduler_per_step: false
+   skip_spectral_metrics: true
+ optimizer:
+   _target_: torch.optim.AdamW
+   weight_decay: 0.0001
+   eps: 1.0e-10
+   lr: 0.0002
+ lr_scheduler:
+   _target_: walrus.optim.schedulers.InverseSqrtLinearWarmupSqrtCooldown
+   warmup_epochs: 10
+   cooldown_epochs: 10
+   warmup_lr_factor: 0.1
+   cooldown_lr_factor: 0.001
+ model:
+   encoder:
+     _partial_: true
+     _target_: walrus.models.encoders.vstride_encoder.SpaceBagAdaptiveDVstrideEncoder
+     learned_pad: true
+     base_kernel_size1d:
+     - - 4
+       - 4
+     base_kernel_size2d:
+     - - 8
+       - 4
+     - - 8
+       - 4
+     base_kernel_size3d:
+     - - 8
+       - 4
+     - - 8
+       - 4
+     - - 8
+       - 4
+     groups: 12
+     kernel_scales_seq:
+     - - 2
+       - 2
+     - - 4
+       - 2
+     - - 4
+       - 4
+     - - 8
+       - 4
+     variable_downsample: true
+     variable_deterministic_ds: true
+     activation:
+       _partial_: true
+       _target_: torch.nn.SiLU
+   decoder:
+     _partial_: true
+     _target_: walrus.models.decoders.vstride_decoder.AdaptiveDVstrideDecoder
+     learned_pad: true
+     base_kernel_size1d:
+     - - 4
+       - 4
+     base_kernel_size2d:
+     - - 8
+       - 4
+     - - 8
+       - 4
+     base_kernel_size3d:
+     - - 8
+       - 4
+     - - 8
+       - 4
+     - - 8
+       - 4
+     groups: 12
+     activation:
+       _partial_: true
+       _target_: torch.nn.SiLU
+   processor:
+     space_mixing:
+       _partial_: true
+       _target_: walrus.models.spatial_blocks.full_attention.FullAttention
+       num_heads: 16
+       mlp_dim: null
+     time_mixing:
+       _partial_: true
+       _target_: walrus.models.temporal_blocks.axial_time_attention.AxialTimeAttention
+       num_heads: 16
+       bias_type: rel
+     channel_mixing:
+       _partial_: true
+       _target_: torch.nn.Identity
+     _partial_: true
+     _target_: walrus.models.spatiotemporal_blocks.space_time_split.SpaceTimeSplitBlock
+   norm_layer:
+     _partial_: true
+     _target_: walrus.models.shared_utils.normalization.RMSGroupNorm
+   _target_: walrus.models.IsotropicModel
+   hidden_dim: 1408
+   projection_dim: 48
+   intermediate_dim: 352
+   processor_blocks: 40
+   drop_path: 0.05
+   groups: 16
+   max_d: 3
+   static_axes: true
+   weight_tied_axes: false
+   causal_in_time: true
+   include_d:
+   - 2
+   - 3
+   override_dimensionality: 0
+   jitter_patches: true
+   gradient_checkpointing_freq: 2
+   use_periodic_fixed_jitter: true
+   input_field_drop: 0.0
+ data:
+   field_index_map_override:
+     closed_boundary: 0
+     open_boundary: 1
+     bias_correction: 2
+     pressure: 3
+     velocity_x: 4
+     velocity_y: 5
+     velocity_z: 6
+     zeros_like_density: 7
+     speed_of_sound: 8
+     concentration: 9
+     D_xx: 10
+     D_xy: 11
+     D_xz: 12
+     D_yx: 13
+     D_yy: 14
+     D_yz: 15
+     D_zx: 16
+     D_zy: 17
+     D_zz: 18
+     E_xx: 19
+     E_xy: 20
+     E_xz: 21
+     E_yx: 22
+     E_yy: 23
+     E_yz: 24
+     E_zx: 25
+     E_zy: 26
+     E_zz: 27
+     density: 28
+     energy: 29
+     velocity_r: 30
+     velocity_theta: 31
+     velocity_phi: 32
+     momentum_x: 33
+     momentum_y: 34
+     momentum_z: 35
+     pressure_re: 36
+     pressure_im: 37
+     mask: 38
+     magnetic_field_x: 39
+     magnetic_field_y: 40
+     magnetic_field_z: 41
+     A: 42
+     B: 43
+     height: 44
+     internal_energy: 45
+     temperature: 46
+     electron_fraction: 47
+     entropy: 48
+     magnetic_field_log_r: 49
+     magnetic_field_theta: 50
+     magnetic_field_phi: 51
+     velocity_log_r: 52
+     buoyancy: 53
+     tracer: 54
+     log10_density: 55
+     log10_temperature: 56
+     c_zz: 57
+     C_xx: 58
+     C_xy: 59
+     C_xz: 60
+     C_yx: 61
+     C_yy: 62
+     C_yz: 63
+     C_zx: 64
+     C_zy: 65
+     C_zz: 66
+   transform:
+     train:
+       _target_: the_well.data.augmentation.RandomRotation90
+       p: 1.0
+   well_base_path: /mnt/gpuxl/polymathic/the_well/datasets/
+   wandb_data_name: well_allmain_only
+   module_parameters:
+     _target_: walrus.data.MixedWellDataModule
+     batch_size: 2
+     n_steps_input: 6
+     n_steps_output: 1
+     min_dt_stride: 1
+     max_dt_stride: 5
+     max_samples: 2000
+   well_dataset_info:
+     active_matter:
+       include_filters: []
+       exclude_filters: []
+     planetswe:
+       include_filters: []
+       exclude_filters: []
+     acoustic_scattering_maze:
+       include_filters: []
+       exclude_filters: []
+       field_transforms:
+         density: torch.zeros_like
+     acoustic_scattering_inclusions:
+       include_filters: []
+       exclude_filters: []
+       field_transforms:
+         density: torch.zeros_like
+     acoustic_scattering_discontinuous:
+       include_filters: []
+       exclude_filters: []
+       field_transforms:
+         density: torch.zeros_like
+     euler_multi_quadrants_openBC:
+       include_filters: []
+       exclude_filters: []
+     euler_multi_quadrants_periodicBC:
+       include_filters: []
+       exclude_filters: []
+     gray_scott_reaction_diffusion:
+       include_filters: []
+       exclude_filters: []
+     rayleigh_benard:
+       include_filters: []
+       exclude_filters: []
+     shear_flow:
+       include_filters: []
+       exclude_filters: []
+     turbulent_radiative_layer_2D:
+       include_filters: []
+       exclude_filters: []
+     helmholtz_staircase:
+       include_filters: []
+       exclude_filters: []
+     viscoelastic_instability:
+       include_filters: []
+       exclude_filters: []
+     supernova_explosion_128:
+       include_filters: []
+       exclude_filters: []
+       step_downsample_factor: 0.5
+       batch_downsample_factor: 0.5
+       field_transforms:
+         density: torch.log10
+         temperature: torch.log10
+     turbulence_gravity_cooling:
+       include_filters: []
+       exclude_filters: []
+       step_downsample_factor: 0.5
+       batch_downsample_factor: 0.5
+       field_transforms:
+         density: torch.log10
+         temperature: torch.log10
+     turbulent_radiative_layer_3D:
+       include_filters: []
+       exclude_filters: []
+       step_downsample_factor: 0.5
+       batch_downsample_factor: 0.5
+       field_transforms:
+         density: torch.log10
+         temperature: torch.log10
+     MHD_64:
+       include_filters: []
+       exclude_filters: []
+       step_downsample_factor: 0.5
+       batch_downsample_factor: 0.5
+     rayleigh_taylor_instability:
+       include_filters: []
+       exclude_filters: []
+       step_downsample_factor: 0.5
+       batch_downsample_factor: 0.5
+     flowbench_FPO_NS_2D_512x128_harmonics:
+       include_filters: []
+       exclude_filters: []
+       path: /mnt/gpuxl/polymathic/WellFormattedExternalData/flowbench/flowbench_FPO_NS_2D_512x128_harmonics
+ auto_resume: true
+ folder_override: ''
+ checkpoint_override: ''
+ config_override:
+   validation_mode: false
+   frozen_components:
+   - model
+ distribution:
+   distribution_type: hsdp
+   local_size: 4
+ logger:
+   wandb: true
+   wandb_project_name: walrus_Training_Attempts
+ checkpoint:
+   _target_: walrus.trainer.checkpoints.CheckPointer
+   save_dir: /mnt/home/polymathic/ceph/walrus_logging/runs/Walrus_ft_major_v2-wella-delta-Isotr[Space-Adapt-]-AdamW-0.0002/0/checkpoints
+   load_checkpoint_path: null
+   coalesced_checkpoint_path: null
+   save_best: true
+   checkpoint_frequency: 20
+   align_fields: true
+   load_chkpt_after_finetuning_expansion: false
+ finetuning_mods: {}
+ experiment_dir: /mnt/home/polymathic/ceph/walrus_logging/runs
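
The `_target_`/`_partial_` keys above follow the Hydra instantiation convention, so individual blocks of this config can be materialized directly. A minimal sketch (assuming `hydra-core` and `omegaconf` are installed, the file is saved locally, and `optimizer` sits at the top level as reconstructed above):

```python
import torch
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("extended_config.yaml")

# Example: build the AdamW optimizer described by the config.
dummy = torch.nn.Linear(4, 4)  # stand-in for the real model's parameters
optimizer = instantiate(cfg.optimizer, params=dummy.parameters())
print(optimizer)  # AdamW with lr=0.0002, weight_decay=0.0001, eps=1e-10
```
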
walrus.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7c5338a8ca88cdc36f8479dc4fe136416fed0d0b82521380998d2a14c8a01c3f
+ size 5145064530
walrus.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d96dc428879c51a9d979f3d855cf2843ebb3e29790190fab34226db8aeec194
+ size 5144892644