Skip to content

Commit f534a7f

Browse files
authoredMay 2, 2024
[Feature]: add swanlab logger (#7)
1 parent cdce915 commit f534a7f

6 files changed

+41
-11
lines changed
 

‎.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@ datasets/
77
experiments_t2m/
88
experiments_t2m_test/
99
experiments_control/
10-
experiments_control_test/
10+
experiments_control_test/

‎README.md

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ Experimental results demonstrate the remarkable generation and controlling capab
4141

4242
## 📢 News
4343

44+
- **[2024/05/02]** We support the [SwanLab](https://github.com/SwanHubX/SwanLab) logger, please refer to this [PR](https://github.com/Dai-Wenxun/MotionLCM/pull/7) for details.
4445
- **[2024/05/01]** Upload paper and release code.
4546

4647
## 👨‍🏫 Quick Start

‎mld/config.py

+2
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def parse_args() -> DictConfig:
3535
parser.add_argument('--example', type=str, required=False, help="input text and lengths with txt format")
3636
parser.add_argument('--no-plot', action="store_true", required=False, help="whether plot the skeleton-based motion")
3737
parser.add_argument('--replication', type=int, default=1, help="the number of replication of sampling")
38+
parser.add_argument('--vis', type=str, default="tb", choices=['tb', 'swanlab'], help="the visualization method, tensorboard or swanlab")
3839
args = parser.parse_args()
3940

4041
cfg = OmegaConf.load(args.cfg)
@@ -44,4 +45,5 @@ def parse_args() -> DictConfig:
4445
cfg.example = args.example
4546
cfg.no_plot = args.no_plot
4647
cfg.replication = args.replication
48+
cfg.vis = args.vis
4749
return cfg

‎requirements.txt

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ torch==1.13.1
33
gdown
44
omegaconf
55
rich
6+
swanlab==0.3.1
67
torchmetrics==1.3.2
78
scipy==1.11.2
89
matplotlib==3.3.4
@@ -14,4 +15,4 @@ h5py==3.11.0
1415
smplx==0.1.28
1516
chumpy==0.70
1617
numpy==1.23.1
17-
natsort==8.4.0
18+
natsort==8.4.0

‎train_motion_control.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from omegaconf import OmegaConf
99

1010
import torch
11+
import swanlab
1112
import diffusers
1213
import transformers
1314
from torch.utils.tensorboard import SummaryWriter
@@ -30,8 +31,14 @@ def main():
3031
output_dir = osp.join(cfg.FOLDER, name_time_str)
3132
os.makedirs(output_dir, exist_ok=False)
3233
os.makedirs(f"{output_dir}/checkpoints", exist_ok=False)
33-
34-
writer = SummaryWriter(output_dir)
34+
35+
if cfg.vis == "tb":
36+
writer = SummaryWriter(output_dir)
37+
elif cfg.vis == "swanlab":
38+
run = swanlab.init(project="MotionLCM", experiment_name=os.path.normpath(output_dir).replace(os.path.sep, "-"),
39+
suffix=None, config=cfg, logdir=output_dir)
40+
else:
41+
raise ValueError(f"Invalid vis method: {cfg.vis}")
3542

3643
stream_handler = logging.StreamHandler(sys.stdout)
3744
file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
@@ -133,7 +140,10 @@ def validation():
133140
min_val_tj = metrics['Metrics/traj_fail_50cm']
134141
print_table(f'Metrics@Step-{global_step}', metrics)
135142
for k, v in metrics.items():
136-
writer.add_scalar(k, v, global_step=global_step)
143+
if cfg.vis == "tb":
144+
writer.add_scalar(k, v, global_step=global_step)
145+
elif cfg.vis == "swanlab":
146+
run.log({k: v}, step=global_step)
137147

138148
model.controlnet.train()
139149
model.traj_encoder.train()
@@ -189,7 +199,10 @@ def validation():
189199
"diff_loss": diff_loss.detach().item(), 'cond_loss': cond_loss.detach().item(), 'rot_loss': rot_loss.detach().item()}
190200
progress_bar.set_postfix(**logs)
191201
for k, v in logs.items():
192-
writer.add_scalar(k, v, global_step=global_step)
202+
if cfg.vis == "tb":
203+
writer.add_scalar(k, v, global_step=global_step)
204+
elif cfg.vis == "swanlab":
205+
run.log({k: v}, step=global_step)
193206

194207
if global_step >= cfg.TRAIN.max_train_steps:
195208
break

‎train_motionlcm.py

+18-5
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from omegaconf import OmegaConf
1111

1212
import torch
13+
import swanlab
1314
import diffusers
1415
import transformers
1516
import torch.nn.functional as F
@@ -127,7 +128,13 @@ def main():
127128
os.makedirs(output_dir, exist_ok=False)
128129
os.makedirs(f"{output_dir}/checkpoints", exist_ok=False)
129130

130-
writer = SummaryWriter(output_dir)
131+
if cfg.vis == "tb":
132+
writer = SummaryWriter(output_dir)
133+
elif cfg.vis == "swanlab":
134+
run = swanlab.init(project="MotionLCM", experiment_name=os.path.normpath(output_dir).replace(os.path.sep, "-"),
135+
suffix=None, config=cfg, logdir=output_dir)
136+
else:
137+
raise ValueError(f"Invalid vis method: {cfg.vis}")
131138

132139
stream_handler = logging.StreamHandler(sys.stdout)
133140
file_handler = logging.FileHandler(osp.join(output_dir, 'output.log'))
@@ -245,7 +252,10 @@ def validation():
245252
min_val_fid = metrics['Metrics/FID']
246253
print_table(f'Metrics@Step-{global_step}', metrics)
247254
for k, v in metrics.items():
248-
writer.add_scalar(k, v, global_step=global_step)
255+
if cfg.vis == "tb":
256+
writer.add_scalar(k, v, global_step=global_step)
257+
elif cfg.vis == "swanlab":
258+
run.log({k: v}, step=global_step)
249259
base_model.train()
250260
return max_val_rp1, min_val_fid
251261

@@ -411,9 +421,12 @@ def validation():
411421

412422
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
413423
progress_bar.set_postfix(**logs)
414-
writer.add_scalar('loss', logs['loss'], global_step=global_step)
415-
writer.add_scalar('lr', logs['lr'], global_step=global_step)
416-
424+
if cfg.vis == "tb":
425+
writer.add_scalar('loss', logs['loss'], global_step=global_step)
426+
writer.add_scalar('lr', logs['lr'], global_step=global_step)
427+
elif cfg.vis == "swanlab":
428+
run.log({'loss': logs['loss'], 'lr': logs['lr']}, step=global_step)
429+
417430
if global_step >= cfg.TRAIN.max_train_steps:
418431
break
419432

0 commit comments

Comments
 (0)
Please sign in to comment.