TP + FP8 - NotImplementedError for certain operations #2629

Open
nathan-az opened this issue Apr 23, 2025 · 3 comments · May be fixed by pytorch/ao#2154

Comments

@nathan-az
Contributor

FP8 training is now supported (#2546), but it has issues with tensor parallelism, so the combination is currently gated. The MVP for this feature should include:

  • Plug-and-play support for enable_fp8_training when a tensor_parallel_plan is set (a rough sketch of what this amounts to is at the end of this comment)
  • Compatibility with torch.compile

This issue tracks the request and support for this combination. The scope of what needs to be done is not yet clear to me. @andrewor14 feel free to comment if there are other requirements for the MVP, or if you want to clarify the scope.
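
For context, a rough sketch of what the plug-and-play path boils down to under the hood (my assumption of the flow, not torchtune's recipe code): convert the linear layers to float8 training via torchao, then shard them with an FP8-aware TP plan. `apply_fp8_then_tp` is a hypothetical helper and assumes the process group is already initialized:

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8 import convert_to_float8_training


def apply_fp8_then_tp(model: torch.nn.Module, tp_size: int, fp8_tp_plan: dict) -> torch.nn.Module:
    """Hypothetical helper: float8 conversion first, then tensor parallelism."""
    # Swap eligible nn.Linear modules for float8 training variants (torchao API).
    convert_to_float8_training(model)
    # Build a 1-D device mesh and shard the converted linears with an FP8-aware plan
    # (e.g. one built from torchao's Float8ColwiseParallel / Float8RowwiseParallel styles).
    tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
    parallelize_module(model, tp_mesh, fp8_tp_plan)
    return model
```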

@andrewor14
Contributor

Thanks @nathan-az. Just to clarify: torch.compile is already supported; it's just not compatible with tensor parallelism yet.

nathan-az changed the title from "Support tensor parallelism with FP8 training" to "TP + FP8 - NotImplementedError for certain operations" on May 7, 2025
@nathan-az
Contributor Author

I've created a separate issue for FP8 + TP + compile support so the two can be tracked independently.

@nathan-az
Contributor Author

nathan-az commented May 30, 2025

@ebsmothers I haven't done any testing with Llama 4, but I think you may have fixed this one with 0d90675?

Config:

```yaml
batch_size: 8
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00030'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: models/llama_3_3_70b
clip_grad_norm: 100.0
compile: false
custom_sharded_layers: []
data_parallel_shard_dim: 1
data_parallel_replicate_dim: 1
tensor_parallel_dim: 8
enable_fp8_training: true
fp8_recipe_name: tensorwise
tensor_parallel_plan:
  _component_: torchtune.models.llama3._parallelism._fp8_llama_tp_plan
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: true
  train_on_input: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
activation_offloading_use_streams: false
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 5
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
  experiment_name: llama_3_3_70b_batch_sweep
  run_name: tp8_bs8_ga1
optimizer:
  _component_: torchao.optim.AdamW8bit
  lr: 4.0e-05
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 2048
  path: models/llama_3_3_70b/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b
```
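
As an aside on fp8_recipe_name: tensorwise above: my understanding (an assumption based on torchao's public API, not a trace through the torchtune recipe) is that the string maps to a torchao Float8LinearConfig recipe that drives the conversion, roughly:

```python
from torch import nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training


def convert_with_recipe(model: nn.Module, recipe_name: str = "tensorwise") -> nn.Module:
    # Map the recipe string (e.g. "tensorwise", "rowwise") to a torchao config
    # and swap eligible nn.Linear modules for float8 training variants in place.
    config = Float8LinearConfig.from_recipe_name(recipe_name)
    convert_to_float8_training(model, config=config)
    return model
```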

compile still complains, but with a new error. I'll update the compile issue, but if you or anyone else can repro that TP + FP8 works, I can put up a small PR with the following changes (the only changes I had to make to test):

  • ungate FP8 + TP in the recipe
  • update the FP8 TP plan to take a model arg (unused); see the sketch below

Then we can close this issue.
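
For the second change, a minimal sketch of the signature I have in mind, assuming the plan component stays a plain function returning a dict of module-path patterns to ParallelStyle objects. The Float8 parallel styles are torchao's; the module paths below are illustrative, not copied from torchtune:

```python
from typing import Optional

from torch import nn
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


def _fp8_llama_tp_plan(model: Optional[nn.Module] = None) -> dict:
    # `model` is accepted only so the recipe can call every tensor_parallel_plan
    # component with the same signature; it is intentionally unused here.
    return {
        "layers.*.attn.q_proj": Float8ColwiseParallel(),
        "layers.*.attn.k_proj": Float8ColwiseParallel(),
        "layers.*.attn.v_proj": Float8ColwiseParallel(),
        "layers.*.attn.output_proj": Float8RowwiseParallel(),
        "layers.*.mlp.w1": Float8ColwiseParallel(),
        "layers.*.mlp.w3": Float8ColwiseParallel(),
        "layers.*.mlp.w2": Float8RowwiseParallel(),
    }
```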
