
Sequence parallelism #2412

Merged
merged 44 commits into main from sequence-parallelism on Mar 21, 2025
Conversation


@djsaunde djsaunde commented Mar 13, 2025

Description

This PR implements sequence parallelism (SP) via ring-flash-attn. Specifically, its hf_adapter.py module is used to patch the transformers flash attention implementation with llama3_flash_attn_varlen_func, the SP implementation from the Llama 3 tech report. This technically isn't ring attention, but it is the most performant SP variant in most cases.
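
As a rough sketch of how this patching might be wired up (this is not the PR's actual code; the substitute_hf_flash_attn import path, its signature, and the heads_k_stride argument are assumptions about ring-flash-attn's hf_adapter API):

# Illustrative sketch: build SP process groups of size `sequence_parallel_degree`
# and patch HF flash attention via ring-flash-attn's hf_adapter.
import torch.distributed as dist
# Assumed export; may instead live under ring_flash_attn.adapters.hf_adapter.
from ring_flash_attn import substitute_hf_flash_attn

def setup_sequence_parallel(sequence_parallel_degree: int, heads_k_stride: int = 1):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % sequence_parallel_degree == 0

    # Partition ranks into contiguous SP groups, e.g. degree=2 on 4 GPUs -> [0,1], [2,3].
    sp_group = None
    for start in range(0, world_size, sequence_parallel_degree):
        ranks = list(range(start, start + sequence_parallel_degree))
        group = dist.new_group(ranks=ranks)
        if rank in ranks:
            sp_group = group

    # Patch transformers' flash attention with llama3_flash_attn_varlen_func so
    # attention is computed cooperatively across the SP group (assumed signature).
    substitute_hf_flash_attn(process_group=sp_group, heads_k_stride=heads_k_stride)
    return sp_group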

Since the batch API (the non-sample-packing case) is a special case of the varlen API (the sample-packing case), I think these changes should be sufficient to cover both cases, but this should be validated with tests.

Motivation and Context

SP is necessary for long-context post-training when a single sequence no longer fits in the VRAM of a single card and causes OOM. Users with more than one GPU can run longer-context post-training by enabling this option.

Attention is distributed across the GPUs according to the configured sequence_parallel_degree (e.g., if sequence_parallel_degree = 4, sequences are split into 4 equal-length chunks). Attention is computed on each sub-sequence, and inter-GPU communication then completes the full attention computation.
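
For illustration, a minimal sketch of that per-rank split (the function name and batch keys here are hypothetical, not the PR's actual collator):

# Each rank in an SP group keeps only its contiguous chunk of the (packed)
# sequence; position_ids are preserved so RoPE still sees global positions.
import torch

def split_for_sequence_parallel(batch: dict, sp_rank: int, sp_degree: int) -> dict:
    seq_len = batch["input_ids"].size(1)
    assert seq_len % sp_degree == 0, "sequence_len must be divisible by sequence_parallel_degree"
    chunk = seq_len // sp_degree
    start, end = sp_rank * chunk, (sp_rank + 1) * chunk

    position_ids = batch.get("position_ids")
    if position_ids is None:
        position_ids = torch.arange(seq_len).unsqueeze(0).expand_as(batch["input_ids"])

    return {
        "input_ids": batch["input_ids"][:, start:end],
        "position_ids": position_ids[:, start:end],  # keep global positions for this shard
        "labels": batch["labels"][:, start:end],
        "attention_mask": batch["attention_mask"][:, start:end],
    }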

How has this been tested?

pytest coverage (not super comprehensive) and functional tests.

Screenshots (if appropriate)

Types of changes

  • ring-flash-attn hf_adapter.py integration
  • Data collation changes (sequence splitting, position ID adjustment)
  • AxolotlTrainer sampler, dataloader changes
    • Refactor multipack sampler logic to helper method
    • DistributedSampler for SP case
      • Setting rank = SP group ID allows us to sample data according to SP group (see the sketch after this list)
    • Data loader (in the SP case) is not prepared for distributed training by the accelerator object
      • Distribution already handled by the DistributedSampler
  • Bonus: added random_init flag to load model without pretrained weights
  • Bonus: a bit of cleanup
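
As referenced in the sampler bullet above, here is a minimal sketch of the SP-aware DistributedSampler wiring (the helper name and contiguous group layout are assumptions; the PR's actual implementation may differ):

# Every rank in an SP group receives the same samples by treating each SP
# *group* as one data-parallel replica.
import torch.distributed as dist
from torch.utils.data import DistributedSampler

def build_sp_sampler(dataset, sequence_parallel_degree: int, shuffle: bool = True):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_sp_groups = world_size // sequence_parallel_degree
    sp_group_id = rank // sequence_parallel_degree  # ranks 0..d-1 -> group 0, etc.

    # num_replicas = number of SP groups, rank = SP group ID, so all ranks in a
    # group draw identical indices; the collator then shards each sequence.
    return DistributedSampler(
        dataset,
        num_replicas=num_sp_groups,
        rank=sp_group_id,
        shuffle=shuffle,
    )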

@djsaunde djsaunde self-assigned this Mar 13, 2025
@djsaunde djsaunde force-pushed the sequence-parallelism branch 2 times, most recently from 6c06776 to 6cb26bb on March 17, 2025 at 13:44

djsaunde commented Mar 17, 2025

Quick test for OOM alleviation:

Running on 1 x H100 SXM (SP disabled / not possible)

Config (note that sequence_len: 16384):

base_model: HuggingFaceTB/SmolLM2-1.7B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: false
strict: false

datasets:
  - path: teknium/GPT4-LLM-Cleaned
    type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

sequence_len: 16384
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_8bit
lr_scheduler: cosine
learning_rate: 0.0002

train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false

gradient_checkpointing: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  pad_token: "<|end_of_text|>"

sequence_parallel_degree: 1

seed: 0

random_init: true

This configuration OOMs during the backward pass on the second step of training. (Why not on the first step? This can be explained by the fact that optimizer states are lazily allocated, i.e., they don't exist in memory until the end of step 1 of training, which is likely what causes the OOM on step 2.)

Running on 2 x H100 SXM (SP enabled with SP degree = 2)

The above config remains unchanged except for the following change:

sequence_parallel_degree: 2

This configuration does not OOM, and training runs at ~1.6 s/iteration.

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.05              Driver Version: 560.35.05      CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:18:00.0 Off |                    0 |
| N/A   45C    P0            490W /  700W |   40826MiB /  81559MiB |     99%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:AB:00.0 Off |                    0 |
| N/A   47C    P0            493W /  700W |   40950MiB /  81559MiB |     99%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
+-----------------------------------------------------------------------------------------+

As an aside, this FFT (full fine-tune) seems to be taking up more memory than it should; I'm probably not configuring something properly (feel free to comment!). (Edit: memory is reduced by ~2x by enabling gradient checkpointing 🤦.)


djsaunde commented Mar 17, 2025

Similar to the above, but with 4 x H100s, 32K context length, and SP degree = 4:

+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 565.57.01              Driver Version: 565.57.01      CUDA Version: 12.7     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA H100 80GB HBM3          On  |   00000000:12:00.0 Off |                    0 |
| N/A   44C    P0            159W /  700W |   71108MiB /  81559MiB |      1%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA H100 80GB HBM3          On  |   00000000:43:00.0 Off |                    0 |
| N/A   45C    P0            169W /  700W |   71204MiB /  81559MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA H100 80GB HBM3          On  |   00000000:46:00.0 Off |                    0 |
| N/A   41C    P0            156W /  700W |   71204MiB /  81559MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
|   3  NVIDIA H100 80GB HBM3          On  |   00000000:86:00.0 Off |                    0 |
| N/A   40C    P0            163W /  700W |   70724MiB /  81559MiB |    100%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

@djsaunde djsaunde force-pushed the sequence-parallelism branch 3 times, most recently from 103facf to 9697d28 on March 21, 2025 at 15:41
@djsaunde djsaunde force-pushed the sequence-parallelism branch from ec30c92 to ce35b2a on March 21, 2025 at 16:37
@djsaunde djsaunde merged commit 23f0c51 into main Mar 21, 2025
8 of 11 checks passed