
Enable training with Tensorboard tracking #209

Merged
merged 13 commits into from
Jan 24, 2023

Conversation

marcobellagente93
Contributor

@marcobellagente93 marcobellagente93 commented Jan 22, 2023

Currently only wandb logging is supported, and using tensorboard results in a number of small blockers.

This small PR removes the following blockers:

  • a logging_dir is added to the training config, to be used by trackers that need a local folder (e.g. tensorboard)
  • an assertion is placed before calling init_trackers to check the specified tracker
  • a flattened config is used for Tensorboard logging
  • optional slight renaming (trackers -> tracker) and an interface change assuming a single tracker is used
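The changes listed above could be sketched roughly as follows. This is a hypothetical outline based on the PR description, not the exact trlx code: the config field names and the set of supported trackers are assumptions.

```python
# Hypothetical sketch of the PR's changes (field names are assumptions).
SUPPORTED_TRACKERS = ("wandb", "tensorboard", None)

def init_tracking(config):
    """Validate the single configured tracker and build its init kwargs."""
    tracker = config["train"]["tracker"]  # renamed from `trackers` (plural)

    # Assertion placed before calling init_trackers, as the PR describes
    assert tracker in SUPPORTED_TRACKERS, f"Unsupported tracker: {tracker}"

    init_kwargs = {}
    if tracker == "tensorboard":
        # Trackers that write locally need a folder (the new logging_dir)
        init_kwargs["logging_dir"] = config["train"]["logging_dir"]
    return tracker, init_kwargs
```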

Disclaimers

CC @daia99

@daia99
Contributor

daia99 commented Jan 22, 2023

We may want to remove the tensorboard files for the PR!

@marcobellagente93
Contributor Author

marcobellagente93 commented Jan 22, 2023

> We may want to remove the tensorboard files for the PR!

Whoops, got it.

@@ -78,19 +78,32 @@ def __init__(self, config, **kwargs):
dist_config = get_distributed_config(self.accelerator)
config_dict["distributed"] = dist_config
init_trackers_kwargs = {}
if "wandb" in config.train.trackers:
# HACK: Tensorboard doesn't like nested dict as hyperparams
config_dict_flat = {a:b for (k,v) in config_dict.items() for (a,b) in v.items() if not isinstance(b, dict)}
Collaborator


You can use `trlx.utils.modeling.flatten_dict` here

Contributor Author


Thanks for the suggestion! I replaced the dict comprehension with a call to flatten_dict(); since tensorboard also doesn't like lists, I added a couple of lines to split the optimizer betas.
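A minimal sketch of what that flattening does: the `flatten_dict` below mirrors the idea of `trlx.utils.modeling.flatten_dict` but is written here from scratch, and the `optimizer/betas` key name in `split_betas` is an assumption based on the comment above.

```python
def flatten_dict(d, parent_key="", sep="/"):
    """Recursively flatten a nested dict into a single level,
    joining keys with `sep` (Tensorboard rejects nested hparams)."""
    items = {}
    for k, v in d.items():
        key = f"{parent_key}{sep}{k}" if parent_key else k
        if isinstance(v, dict):
            items.update(flatten_dict(v, key, sep))
        else:
            items[key] = v
    return items

def split_betas(flat):
    """Tensorboard also rejects lists, so split the optimizer betas
    into scalar entries (the key name here is hypothetical)."""
    out = dict(flat)
    betas = out.pop("optimizer/betas", None)
    if betas is not None:
        for i, b in enumerate(betas):
            out[f"optimizer/beta{i}"] = b
    return out
```

For example, `flatten_dict({"train": {"lr": 0.1}})` yields `{"train/lr": 0.1}`, which Tensorboard's hparams logging can accept.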

config=config_dict,
init_kwargs=init_trackers_kwargs,
)
else:
Collaborator

@cat-state cat-state Jan 23, 2023


Thanks!
Could you add back a comment explaining what this branch is for and the flattening? Aside from that it LGTM

Contributor Author


Sure!

  • I had run into problems running trlx without wandb (I don't have an account as of now), and found an already open issue on precisely this. The branch has minor modifications (which don't change the previous interface for wandb users) to allow tensorboard tracking.
  • The only tricky part is that wandb is pretty fancy and takes nested dicts as logging params; this is not the case for tensorboard, so the experiment config is fully flattened, and the only list is simply split apart (for the same reason).
    Do let me know if anything is unclear or if I should add comments in the tensorboard-specific logging.

Collaborator


Oh sorry, by branch I meant the else branch — i.e. a short comment like
`else: # tracker == 'tensorboard'` and `# flatten config for tensorboard, flatten lists in hparams into flattened config`

Contributor Author


Got it! I was reading the comment in between my regular work and got fully confused. Thanks for the feedback!

@cat-state
Collaborator

Thanks for contributing! This LGTM now.

@cat-state cat-state merged commit 82435d8 into CarperAI:main Jan 24, 2023
alan-cooney added a commit to skyhookadventure/trlx that referenced this pull request Jan 25, 2023
This bug was introduced by CarperAI#209, which changed the `trackers` config property to `tracker` but didn't update this use case.
jon-tow pushed a commit that referenced this pull request Jan 25, 2023
This bug was introduced by #209, which changed the `trackers` config property to `tracker` but didn't update this use case.
@marcobellagente93 marcobellagente93 deleted the enable-tensorboard branch January 29, 2023 20:46