Abstract
The Llama-3.1-8B-Instruct-w16a16-1node model is a Turkish legal instruction-tuned variant of Meta’s Llama-3.1-8B-Instruct, trained fully in bfloat16 precision using the default Tensorwise quantization scaling recipe. This model serves as the baseline reference for the “FSDP2 with Float8 Precision for Faster Training” study, where BF16 FSDP2 performance was compared against multiple FP8 mixed-precision configurations. The model was fine-tuned on the newmindai/EuroHPC-Legal dataset (Turkish legal multi-domain Q/A) using pure BF16 for both forward and backward computation, providing a stable, high-precision foundation against which FP8 speedups, efficiency gains, and convergence behavior were benchmarked. Trained on a single EuroHPC/BSC-class node with four NVIDIA H100 GPUs, this configuration achieved consistent convergence and served as the control setup for evaluating all Float8 recipe variants.
Experiment Context
This model was trained as part of our study comparing FSDP2 with bfloat16 precision against FSDP2 with FP8 mixed precision (bf16-fp8).
We used meta-llama/Llama-3.1-8B-Instruct as the base model. It was loaded with `torch_dtype=torch.bfloat16` and wrapped with a single root-level `fully_shard` call; bfloat16 was also used for all forward/backward computation.
```python
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed._composable.fsdp import fully_shard

# Build a 1-D device mesh spanning all ranks on the node.
mesh_device_type = "cuda" if use_cuda else "cpu"
mesh = DeviceMesh(mesh_device_type, list(range(world_size)))

fsdp_kwargs = {
    "mesh": mesh,
    # Free unsharded parameters after forward and re-gather them in
    # backward, trading communication for lower peak memory.
    "reshard_after_forward": True,
}
model = fully_shard(model, **fsdp_kwargs)
```
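For context, `use_cuda` and `world_size` in the snippet above come from the standard `torchrun` process-group setup. A minimal sketch (the variable names match the snippet; the rest is illustrative):

```python
import os
import torch
import torch.distributed as dist

# torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE in the environment.
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
```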
Base Model Technical Specifications
- Parameters: 8 Billion
- Architecture Family: Llama 3.1
- Maximum Position Embeddings: 131,072
- Attention Heads: 32 (`num_attention_heads`)
- Key-Value Heads: 8 (`num_key_value_heads`)
- Hidden Layers: 32 (`num_hidden_layers`)
- Hidden Size: 4,096 (`hidden_size`)
- Intermediate Size: 14,336
- Vocabulary Size: 128,256
- Precision: bfloat16
- RoPE Scaling: type `llama3`, factor = 8.0
- RMS Norm Epsilon: 1e-05
- Activation: SiLU
Training Methodology
Training Configuration
- Model: `meta-llama/Llama-3.1-8B-Instruct`
- Sequence Length: 4,096 (`seq_len`)
- Epochs: 1
- Per-Device Micro Batch Size: 2
- Gradient Accumulation: 4
- GPUs: 4 (via `CUDA_VISIBLE_DEVICES=0,1,2,3`)
- dtype: `bf16` (`fp8=false`)
- Weights: bfloat16
- Activations: bfloat16
- Optimizer: AdamW (see the sketch after this list)
- Learning Rate: 2e-5
- Weight Decay: 0.01
- Betas: (0.9, 0.95)
- Epsilon: 1e-8
- LR Scheduler: Cosine; warmup = 10% (`warmup_ratio=0.1`), also `warmup_steps=100`
- Max Grad Norm: 1.0
- Gradient Checkpointing: Enabled
- Evaluation: every 5 steps (`eval_steps=5`, `eval_samples=1000`)
- Checkpointing: every 10 steps; keep last 5; select best by `eval_loss`
- Logging: every step to file; Weights & Biases in offline mode
- Seed: 100
- Distributed Training: `torch.distributed.run` (single node, multi-GPU) with FSDP2 (Fully Sharded Data Parallel v2)
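A minimal sketch of this optimizer/scheduler configuration (assuming `model` and `num_training_steps` are provided by the training script; the study's actual script may differ):

```python
import torch
from transformers import get_cosine_schedule_with_warmup

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=2e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.01,
)
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,  # warmup_steps=100 (~10% of total steps)
    num_training_steps=num_training_steps,
)
```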
Setups
- Precision: bfloat16 used as the data type for both storage and computation.
- Hardware: HPC (EuroHPC/BSC-class) node with 4 × NVIDIA H100 GPUs.
- Framework: PyTorch with `torchrun` for distributed training.
Dependencies
| Package | Version |
|---|---|
| Transformers | 4.57.1 |
| torch | 2.9.0+cu128 |
| accelerate | 0.14.1 |
| datasets | 4.3.0 |
| huggingface-hub | 0.36.0 |
| tensorboard | 2.20.0 |
| tensorboard-data-server | 0.7.2 |
| wandb | 0.22.1 |
Job Details
| model | Job ID | Runtime (mins) | Nodes | GPUs | Node-hour | GPU-hour | micro-batch | batch-size | gradient_accumulation | total_batch_size |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a8_rw | 31768103 | 115.75 | 1 | 4 | 1.929 | 7.716 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 31837629 | 109.00 | 1 | 4 | 1.816 | 7.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 31768031 | 64.00 | 1 | 4 | 1.066 | 4.266 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-tw | 31768074 | 138.75 | 1 | 4 | 2.312 | 9.250 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 31768093 | 123.75 | 1 | 4 | 2.062 | 8.250 | 2 | 2 | 4 | 32 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 31478433 | 31.75 | 4 | 4 | 2.117 | 8.467 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 31478468 | 39.75 | 4 | 4 | 2.650 | 10.600 | 4 | 4 | 8 | 512 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 4 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 31476914 | 22.00 | 8 | 4 | 2.933 | 11.733 | 4 | 8 | 8 | 1024 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 31476844 | 23.50 | 8 | 4 | 3.133 | 12.533 | 4 | 8 | 8 | 1024 |
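Node-hours and GPU-hours are derived from wall-clock runtime: e.g., for job 31768074, 138.75 min ÷ 60 ≈ 2.312 node-hours (1 node), and 2.312 × 4 GPUs ≈ 9.25 GPU-hours.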
Training Time Analysis
| Model | Training Time (mins) | Memory Allocated (avg %) | GPU Utilization (avg %) | Speedup vs BF16 |
|---|---|---|---|---|
| Llama-3.1-8B-Instruct_w16a16-tw | 138.75267 | 74.4189 | 56.6059 | baseline |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 123.75267 | 68.8982 | 97.5364 | 12.11% |
| Llama-3.1-8B-Instruct_w16a8_rw | 115.75364 | 69.6132 | 97.7689 | 19.87% |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 109.00364 | 69.4806 | 97.3312 | 27.33% |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 64.00328 | 68.8982 | 95.5661 | 116.82% |
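The speedup column is the relative wall-clock improvement over the BF16 baseline, computed as baseline time ÷ run time − 1: e.g., 138.75 ÷ 123.75 − 1 ≈ 12.1% for the 1-node FP8 run, and 138.75 ÷ 64.00 − 1 ≈ 117% for the mxtw recipe.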
Loss statistics for all models trained on 1, 4, and 8 nodes, covering both BF16 and BF16-FP8 configurations and the FP8 recipes:
| Model | Batch Size | Max Loss (train) | Min Loss (train) | Avg Loss (train) | Final Loss (train) | ± Std (train) | Max Loss (val) | Min Loss (val) | Avg Loss (val) | Final Loss (val) |
|---|---|---|---|---|---|---|---|---|---|---|
| Llama-3.1-8B-Instruct-w16a8-rw | 8 | 3.1682 | 0.5740 | 0.8118 | 0.6431 | 0.2746 | 1.0613 | 0.8394 | 0.8937 | 0.8394 |
| Llama-3.1-8B-Instruct_w16a8_rw_with_gw_hp | 8 | 3.1837 | 0.5763 | 0.8116 | 0.6420 | 0.2751 | 1.0599 | 0.8391 | 0.8933 | 0.8391 |
| Llama-3.1-8B-Instruct-w16a8-mxtw | 8 | 3.1983 | 0.5747 | 0.8115 | 0.6446 | 0.2758 | 1.0562 | 0.8384 | 0.8923 | 0.8384 |
| Llama-3.1-8B-Instruct-w16a16-tw | 8 | 3.1235 | 0.7203 | 0.9750 | 0.3344 | 0.7612 | 1.9113 | 0.8907 | 0.9831 | 0.1897 |
| Llama-3.1-8B-Instruct-w16a8-1node-bs8 | 8 | 3.1661 | 0.7261 | 0.9804 | 0.3374 | 0.7672 | 1.9230 | 0.8948 | 0.9867 | 0.1906 |
| Llama-3.1-8B-Instruct-w16a16-4nodes-bs32 | 32 | 3.2452 | 0.7414 | 0.9665 | 0.4844 | 0.7504 | 1.0538 | 0.8382 | 0.8844 | 0.0725 |
| Llama-3.1-8B-Instruct-w16a8-4nodes-bs32 | 32 | 3.2840 | 0.7478 | 0.9748 | 0.4905 | 0.7581 | 1.0701 | 0.8430 | 0.8922 | 0.0764 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs32 | 32 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs32 | 32 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
| Llama-3.1-8B-Instruct-w16a16-8nodes-bs64 | 64 | 3.2311 | 0.8448 | 1.1856 | 0.6434 | 0.8448 | 1.0257 | 0.8977 | 0.9460 | 0.0568 |
| Llama-3.1-8B-Instruct-w16a8-8nodes-bs64 | 64 | 3.3003 | 0.8473 | 1.1866 | 0.6481 | 0.8473 | 1.0203 | 0.8992 | 0.9445 | 0.0539 |
Implementation
GPU & Memory Usage Profiling
To visualize memory usage, GPU VRAM/utilization, and more, you can profile runs with the PyTorch Profiler or the NVIDIA Nsight Systems profiler.
Follow these steps to visualize with the PyTorch Profiler:
- Install the `tensorboard` and `tensorboard-data-server` versions listed in the Dependencies section.
- Visualize the PyTorch profiles by running the command below.
```
tensorboard --logdir="./Llama-3.1-8B-Instruct_w16a16" --port="6006"
```
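For reference, a minimal sketch of how such TensorBoard traces can be produced with `torch.profiler` (the schedule, step count, and log directory are illustrative assumptions; `model`, `optimizer`, and `dataloader` come from the training script):

```python
import torch
from torch.profiler import (ProfilerActivity, profile, schedule,
                            tensorboard_trace_handler)

# Profile a handful of training steps and emit TensorBoard traces.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=3),
    on_trace_ready=tensorboard_trace_handler("./Llama-3.1-8B-Instruct_w16a16"),
    profile_memory=True,
) as prof:
    for step, batch in enumerate(dataloader):
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        prof.step()  # advance the profiler schedule
        if step >= 5:
            break
```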
Follow these steps to visualize with Nsight Systems:
- Download the profile files to your local machine.
- Download the full version of the Nsight Systems profiler tool from: https://developer.nvidia.com/nsight-systems/get-started.
- Open the `.nsys-rep` profiles using the Nsight Systems GUI.
You can also get a summary of a report with these commands, depending on your purpose:
```
nsys stats --report cuda_gpu_kern_sum /path/to/lama3.1_fp16_baseline.nsys-rep
nsys stats --report cuda_gpu_trace /path/to/lama3.1_fp16_baseline.nsys-rep
nsys stats --report cuda_gpu_mem_size_sum /path/to/lama3.1_fp16_baseline.nsys-rep
```
Usage
Note: the final model has been saved in bfloat16 format. For inference, load the model in bfloat16 or float16 as shown below:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

model_name = "newmindai/Llama-3.1-8B-Instruct-w16a16-1node"
dtype = torch.bfloat16

tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto",
)

prompt = "Soru: Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz? Cevap:"
inputs = tok(prompt, return_tensors="pt").to(model.device)

with torch.no_grad():
    out = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False,  # greedy decoding
    )
print(tok.decode(out[0], skip_special_tokens=True))
```
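Since the base model is instruction-tuned, the tokenizer's chat template can also be used. A sketch, assuming the fine-tuned checkpoint keeps the Llama 3.1 chat format (reusing `tok` and `model` from above):

```python
messages = [
    {"role": "user",
     "content": "Kişisel Verilerin Korunması Kanunu uyarınca hangi durumlarda açık rıza aranmaz?"},
]
chat_inputs = tok.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

with torch.no_grad():
    out = model.generate(chat_inputs, max_new_tokens=256, do_sample=False)
print(tok.decode(out[0], skip_special_tokens=True))
```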
Ethical Considerations and Disclaimers
- Research & development purposes only; not a substitute for professional legal counsel.
- Users must ensure compliance with data protection and sector regulations.
- Potential biases may exist in domain data and model outputs.
Model & Data Card Metadata
- Total Parameters: 8,030,261,248
- Serialized Size (approx.): 16,060,522,496 bytes (8,030,261,248 parameters × 2 bytes per bfloat16 value)
- Config precision: bfloat16
- RoPE: llama3 scaling, factor 8.0
References and Citations
Base Model
```bibtex
@misc{meta_llama31_8b_instruct,
  title={Llama 3.1 8B Instruct},
  author={Meta AI},
  year={2024},
  howpublished={\url{https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct}}
}
```
Training Dataset
```bibtex
@misc{euro_hpc_legal,
  title={EuroHPC-Legal},
  author={newmindai},
  year={2025},
  howpublished={\url{https://huggingface.co/datasets/newmindai/EuroHPC-Legal}}
}
```