Sequence parallelism #2412
Conversation
Quick test for OOM alleviation:

Running on 1 x H100 SXM (SP disabled / not possible)

Config:

base_model: HuggingFaceTB/SmolLM2-1.7B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: teknium/GPT4-LLM-Cleaned
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
sequence_len: 16384
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: "<|end_of_text|>"
sequence_parallel_degree: 1
seed: 0
random_init: true

This configuration OOMs during the backward pass on the second step of training.

Running on 2 x H100 SXM (SP enabled with SP degree = 2)

The above config remains unchanged except for the following change:

sequence_parallel_degree: 2

This configuration does not OOM, and training runs at ~1.6s / iteration.

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05 Driver Version: 560.35.05 CUDA Version: 12.6 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA H100 80GB HBM3 On | 00000000:18:00.0 Off | 0 |
| N/A 45C P0 490W / 700W | 40826MiB / 81559MiB | 99% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
| 1 NVIDIA H100 80GB HBM3 On | 00000000:AB:00.0 Off | 0 |
| N/A 47C P0 493W / 700W | 40950MiB / 81559MiB | 99% Default |
| | | Disabled |
+-----------------------------------------+------------------------+----------------------+
+-----------------------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

As an aside, this FFT seems to be taking up more memory than it should; I'm probably not configuring something properly (feel free to comment!). (Edit: memory is reduced by ~2x by enabling gradient checkpointing 🤦.)
Similar to the above, but with 4 x H100s, 32K context length, and SP degree = 4.
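As a rough illustration of that setup (assumed values, not copied from this PR's test logs), the 4 x H100 / 32K run corresponds to overriding the quick-test config above roughly as follows; `gradient_checkpointing: true` reflects the memory fix mentioned in the aside:

```yaml
# Illustrative overrides for the 4 x H100, 32K-context run (assumed, not verbatim from the PR)
sequence_len: 32768            # 32K context
sequence_parallel_degree: 4    # split each sequence across 4 GPUs
gradient_checkpointing: true   # per the aside, roughly halves activation memory
```

With equal-length chunks, each GPU then holds an 8192-token slice of every 32768-token sequence (32768 / 4).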
Description
This PR implements sequence parallelism via ring-flash-attn. Specifically, their hf_adapter.py module is used to patch transformers flash attention with llama3_flash_attn_varlen_func, the SP implementation from the Llama 3 tech report. This technically isn't ring attention, but it is the most performant SP variant in most cases.

I think that since the batch API (non-sample-packing case) is a special case of the varlen API (sample-packing case), these changes should be sufficient to cover both cases, but this should be validated with tests.
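To make the mechanism concrete, here is a minimal, hedged sketch of the llama3-style idea: each rank keeps its contiguous chunk of the queries, all-gathers K/V across the SP group, and applies a position-offset causal mask. This is an illustration only, not the actual ring-flash-attn code (which uses flash-attention varlen kernels and more memory-efficient communication); the function and argument names here (`sp_causal_attention`, `sp_group`) are mine.

```python
# Minimal sketch of llama3-style sequence parallelism (illustration only; the PR
# itself relies on ring-flash-attn's llama3_flash_attn_varlen_func, not this code).
import torch
import torch.distributed as dist
import torch.nn.functional as F


def sp_causal_attention(q, k, v, sp_group):
    """q, k, v: [batch, heads, local_len, head_dim] -- this rank's sequence chunk."""
    world = dist.get_world_size(sp_group)
    rank = dist.get_rank(sp_group)
    local_len = q.shape[2]

    # Every rank gathers the full K/V; queries stay local to the rank's chunk.
    k_chunks = [torch.empty_like(k) for _ in range(world)]
    v_chunks = [torch.empty_like(v) for _ in range(world)]
    dist.all_gather(k_chunks, k, group=sp_group)
    dist.all_gather(v_chunks, v, group=sp_group)
    k_full = torch.cat(k_chunks, dim=2)
    v_full = torch.cat(v_chunks, dim=2)

    # Causal mask with a global offset: local query i sits at global position
    # rank * local_len + i, so it may only attend to keys up to that position.
    q_pos = torch.arange(local_len, device=q.device) + rank * local_len
    k_pos = torch.arange(world * local_len, device=q.device)
    causal = q_pos[:, None] >= k_pos[None, :]

    return F.scaled_dot_product_attention(q, k_full, v_full, attn_mask=causal)
```

In the packed (varlen) case, the real kernel additionally has to respect per-sample boundaries within each packed sequence, which is what the varlen API handles.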
Motivation and Context
SP is necessary for long-context post-training where a single sequence's activations exceed the VRAM of a single card and cause OOM. If a user has more than one GPU, they can run longer-context post-training by enabling this option.
Attention is distributed across the GPUs according to the configured sequence_parallel_degree (i.e., if sequence_parallel_degree = 4, then sequences are split into 4 equal-length chunks). Attention is computed on each sub-sequence, and inter-GPU communication then completes the attention computation.

How has this been tested?

pytest coverage (not super comprehensive) and functional tests.

Screenshots (if appropriate)
Types of changes
- ring-flash-attn hf_adapter.py integration
- AxolotlTrainer sampler, dataloader changes
- DistributedSampler for SP case
- random_init flag to load model without pretrained weights (see the sketch below)
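As a rough sketch of what a random_init-style path can look like (an assumed shape, not necessarily this PR's exact implementation; `load_model` is a hypothetical helper), the model is built from its config instead of from pretrained weights:

```python
# Hedged illustration of "random init" loading (assumed, not necessarily the PR's code).
from transformers import AutoConfig, AutoModelForCausalLM


def load_model(model_name: str, random_init: bool = False):
    config = AutoConfig.from_pretrained(model_name)
    if random_init:
        # Build the architecture with freshly initialized weights.
        return AutoModelForCausalLM.from_config(config)
    # Default path: load pretrained weights.
    return AutoModelForCausalLM.from_pretrained(model_name, config=config)
```

In the quick-test config above, random_init: true is what makes the SmolLM2-1.7B architecture start from randomly initialized weights rather than the pretrained checkpoint.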