Skip to content

Commit 613715c

Browse files
committed
Save in torch format
1 parent 44f3c2d commit 613715c

File tree

1 file changed

+3
-4
lines changed

1 file changed

+3
-4
lines changed

scripts/run_dataloader.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from typing import Dict
44
from tqdm import tqdm
55

6-
import safetensors.torch
76
import torch
87
import torch.distributed as dist
98
import torch.multiprocessing as mp
@@ -71,9 +70,9 @@ def main(cfg: TrainConfig, output_dir: Path) -> None:
7170
if batches_read >= batches_per_file:
7271
file_start = batch_number - batches_per_file + 1
7372
file_end = batch_number + 1
74-
filename = output_dir / f"{file_start}-{file_end}.safetensors"
75-
truncated_tensors = {n: t[:batches_read] for n, t in name_to_batches.items()}
76-
safetensors.torch.save_file(truncated_tensors, filename)
73+
for name, t in name_to_batches.items():
74+
filename = output_dir / f"{name}-{file_start}-{file_end}.pt"
75+
torch.save(t[:batches_read], filename)
7776
batches_read = 0
7877

7978

0 commit comments

Comments
 (0)