[feat] Add Accelerate SFT Trainer #280
Thanks @reciprocated , this mostly LGTM, could you update the other examples to also match the |
@cat-state, I've updated those, but perhaps I was overzealous with making |
  from typing import Callable, Dict, Iterable, List, Optional, Tuple

  from trlx.data.configs import TRLConfig
  from trlx.utils import set_seed
  from trlx.utils.loading import get_orchestrator, get_pipeline, get_trainer


- def train(
+ def train(  # noqa: C901
      config: Optional[TRLConfig] = None,
      model_path: Optional[str] = None,
Maybe `model_path` should still be the first arg? It's also what's still in the README.
yeah, reverted that bit
For the future config changes, we'll still need to keep the ModelConfig part around so keeping |
okay! This LGTM now
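The argument-order point raised in this thread can be illustrated with a minimal sketch. The two functions below are hypothetical stand-ins, not the actual `trlx.train` signature, and a plain `dict` stands in for `TRLConfig`. They only show why moving `config` ahead of `model_path` would change the meaning of existing calls that pass the model path positionally.

```python
from typing import Optional


def train_config_first(
    config: Optional[dict] = None,
    model_path: Optional[str] = None,
) -> dict:
    """Hypothetical signature with `config` moved to the front."""
    return {"config": config, "model_path": model_path}


def train_model_path_first(
    model_path: Optional[str] = None,
    config: Optional[dict] = None,
) -> dict:
    """Hypothetical signature that keeps `model_path` as the first argument."""
    return {"config": config, "model_path": model_path}


# A README-style call that passes the model path positionally.
broken = train_config_first("gpt2")
kept = train_model_path_first("gpt2")

print(broken)  # the string is bound to `config`, not `model_path`
print(kept)    # the string is bound to `model_path` as the caller intended
```

Callers that use keyword arguments (`model_path="gpt2"`) are unaffected by either ordering, so the risk is limited to positional calls.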
This PR gathers the supervised finetuning code from these two places [1][2] under the shared API:
[1] https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/sft/train_gptj_summarize.py
[2] https://github.com/Dahoas/reward-modeling/blob/main/reward-modeling/finetune_base.py
https://wandb.ai/sorry/trlx/reports/Add-Accelerate-SFT-Trainer-280--VmlldzozNTAyNzQ1