Skip to content

Commit 1c22396

Browse files
committed
Revert "Save in torch format"
This reverts commit 613715c.
1 parent 613715c commit 1c22396

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

scripts/run_dataloader.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict
44
from tqdm import tqdm
55

6+
import safetensors.torch
67
import torch
78
import torch.distributed as dist
89
import torch.multiprocessing as mp
@@ -70,9 +71,9 @@ def main(cfg: TrainConfig, output_dir: Path) -> None:
7071
if batches_read >= batches_per_file:
7172
file_start = batch_number - batches_per_file + 1
7273
file_end = batch_number + 1
73-
for name, t in name_to_batches.items():
74-
filename = output_dir / f"{name}-{file_start}-{file_end}.pt"
75-
torch.save(t[:batches_read], filename)
74+
filename = output_dir / f"{file_start}-{file_end}.safetensors"
75+
truncated_tensors = {n: t[:batches_read] for n, t in name_to_batches.items()}
76+
safetensors.torch.save_file(truncated_tensors, filename)
7677
batches_read = 0
7778

7879

0 commit comments

Comments
 (0)