TP + FP8 - NotImplementedError for certain operations #2629
Comments
Thanks @nathan-az. Just to clarify:
I've created a separate issue for the FP8 + TP + compile support so these can be tackled separately.
@ebsmothers I haven't done any testing with LLaMA-4, but I think you may have fixed this one with 0d90675? Config:
batch_size: 8
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_files:
    filename_format: model-{}-of-{}.safetensors
    max_filename: '00030'
  model_type: LLAMA3
  output_dir: ${output_dir}
  recipe_checkpoint: null
  checkpoint_dir: models/llama_3_3_70b
clip_grad_norm: 100.0
compile: false
custom_sharded_layers: []
data_parallel_shard_dim: 1
data_parallel_replicate_dim: 1
tensor_parallel_dim: 8
enable_fp8_training: true
fp8_recipe_name: tensorwise
tensor_parallel_plan:
  _component_: torchtune.models.llama3._parallelism._fp8_llama_tp_plan
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: true
  train_on_input: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
activation_offloading_use_streams: false
epochs: 1
fsdp_cpu_offload: false
gradient_accumulation_steps: 1
log_every_n_steps: 1
log_level: INFO
log_peak_memory_stats: true
loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
max_steps_per_epoch: 5
metric_logger:
  _component_: torchtune.training.metric_logging.MLFlowLogger
  experiment_name: llama_3_3_70b_batch_sweep
  run_name: tp8_bs8_ga1
optimizer:
  _component_: torchao.optim.AdamW8bit
  lr: 4.0e-05
optimizer_in_bwd: false
output_dir: outputs
resume_from_checkpoint: false
seed: 100
shuffle: true
tokenizer:
  max_seq_len: 2048
  path: models/llama_3_3_70b/original/tokenizer.model
  _component_: torchtune.models.llama3.llama3_tokenizer
model:
  _component_: torchtune.models.llama3_3.llama3_3_70b
If so, then we can close this issue.
FP8 training is now supported (#2546), but it currently has issues with tensor parallelism, so that combination is gated. The MVP for this feature should include:
- enable_fp8_training with a tensor_parallel_plan set
- torch.compile
This issue is to track the request. It's not clear to me what the scope of the work needed to support this is. @andrewor14 feel free to comment if there are other requirements for the MVP of this feature, or if you want to clarify the scope.
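For reference, here is a minimal sketch of how torchao float8 training is typically combined with a DTensor tensor-parallel plan outside of torchtune. This is my own illustration, not torchtune's recipe code: the toy module, dimensions, and plan keys are assumptions chosen for brevity, and Float8LinearConfig.from_recipe_name, Float8ColwiseParallel, and Float8RowwiseParallel come from recent torchao releases.

```python
# Minimal sketch (illustrative, not torchtune's recipe code): torchao float8
# training combined with a DTensor tensor-parallel plan.
# Launch with: torchrun --nproc_per_node 8 this_script.py
import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module
from torchao.float8 import Float8LinearConfig, convert_to_float8_training
from torchao.float8.float8_tensor_parallel import (
    Float8ColwiseParallel,
    Float8RowwiseParallel,
)


class ToyBlock(nn.Module):
    """Two linears standing in for an MLP projection pair (hypothetical toy model)."""

    def __init__(self, dim: int, hidden: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden, bias=False)
        self.w2 = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(torch.relu(self.w1(x)))


torch.manual_seed(0)  # keep inputs identical across ranks
tp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("tp",))
model = ToyBlock(dim=4096, hidden=14336).cuda().to(torch.bfloat16)

# 1. Swap nn.Linear modules for Float8Linear (tensorwise recipe, matching
#    fp8_recipe_name in the config above).
fp8_config = Float8LinearConfig.from_recipe_name("tensorwise")
convert_to_float8_training(model, config=fp8_config)

# 2. Apply a TP plan using the float8-aware parallel styles so collectives
#    operate on fp8 tensors rather than bf16.
parallelize_module(
    model,
    tp_mesh,
    {"w1": Float8ColwiseParallel(), "w2": Float8RowwiseParallel()},
)

# One forward/backward step to exercise the fp8 + TP path.
x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()
```

The torchtune recipe would additionally need this to compose with its checkpointing, loss, and (per the third MVP item) torch.compile, which is where the gating in the current recipe comes in.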