diff --git a/examples/notebooks/simple_api_fine_tuning.ipynb b/examples/notebooks/simple_api_fine_tuning.ipynb
index 114602636..3dfef5908 100644
--- a/examples/notebooks/simple_api_fine_tuning.ipynb
+++ b/examples/notebooks/simple_api_fine_tuning.ipynb
@@ -94,7 +94,7 @@
         "import os\n",
         "\n",
         "import jiant.utils.python.io as py_io\n",
-        "from jiant.proj.simple import runscript as run\n",
+        "import jiant.proj.simple.runscript as simple_run\n",
         "import jiant.scripts.download_data.runscript as downloader"
       ],
       "execution_count": null,
@@ -146,8 +146,8 @@
       },
       "source": [
         "RUN_NAME = f\"simple_{TASK_NAME}_{MODEL_TYPE}\"\n",
-        "DATA_DIR = \"/content/data\"\n",
         "EXP_DIR = \"/content/exp\"\n",
+        "DATA_DIR = \"/content/exp/tasks\"\n",
         "\n",
         "os.makedirs(DATA_DIR, exist_ok=True)\n",
         "os.makedirs(EXP_DIR, exist_ok=True)"
@@ -196,7 +196,7 @@
         "colab": {}
       },
       "source": [
-        "args = run.RunConfiguration(\n",
+        "args = simple_run.RunConfiguration(\n",
         "    run_name=RUN_NAME,\n",
         "    exp_dir=EXP_DIR,\n",
         "    data_dir=DATA_DIR,\n",
@@ -205,7 +205,7 @@
         "    train_batch_size=16,\n",
         "    num_train_epochs=1\n",
         ")\n",
-        "run.run_simple(args)"
+        "simple_run.run_simple(args)"
       ],
       "execution_count": null,
       "outputs": []
@@ -228,11 +228,11 @@
         "colab": {}
       },
       "source": [
-        "args = run.RunConfiguration.from_json_path(os.path.join(EXP_DIR, \"runs\", RUN_NAME, \"simple_run_config.json\"))\n",
-        "run.run_simple(args)"
+        "args = simple_run.RunConfiguration.from_json_path(os.path.join(EXP_DIR, \"runs\", RUN_NAME, \"simple_run_config.json\"))\n",
+        "simple_run.run_simple(args)"
       ],
       "execution_count": null,
       "outputs": []
     }
   ]
-}
\ No newline at end of file
+}
diff --git a/jiant/proj/main/scripts/configurator.py b/jiant/proj/main/scripts/configurator.py
index e517894ba..b4b6559a1 100644
--- a/jiant/proj/main/scripts/configurator.py
+++ b/jiant/proj/main/scripts/configurator.py
@@ -38,10 +38,174 @@ def cap_examples(num_examples, cap):
     return min(num_examples, cap)
 
 
+@Registry.register
+@zconf.run_config
+class SingleTaskConfigurator(zconf.RunConfig):
+    """Single-task Configurator
+
+    Required:
+        task_name
+        train_batch_size
+
+    (Task config) Need one of:
+        task_config_path
+        task_config_base_path
+
+    (Task cache) Need one of:
+        task_cache_path
+        task_cache_base_path
+
+    (Eval batch size) Need one of:
+        eval_batch_multiplier
+        eval_batch_size
+
+    (Computing max steps) Need one of:
+        epochs
+        max_steps
+        (Set to 0 if not training)
+
+    (Phases) Specify at least one of:
+        do_train
+        do_val
+        do_test
+
+    Optional:
+        gradient_accumulation_steps
+        eval_subset_num
+        num_gpus
+        warmup_steps_proportion
+    """
+
+    task_name = zconf.attr(type=str, default=None)
+    task_config_base_path = zconf.attr(type=str, default=None)
+    task_config_path = zconf.attr(type=str, default=None)
+    task_cache_base_path = zconf.attr(type=str, default=None)
+    task_cache_path = zconf.attr(type=str, default=None)
+    do_train = zconf.attr(type=bool, action="store_true")
+    do_val = zconf.attr(type=bool, action="store_true")
+    do_test = zconf.attr(type=bool, action="store_true")
+    train_batch_size = zconf.attr(type=int, required=True)
+    eval_batch_multiplier = zconf.attr(type=int, default=None)
+    eval_batch_size = zconf.attr(type=int, default=None)
+    gradient_accumulation_steps = zconf.attr(type=int, default=1)
+    eval_subset_num = zconf.attr(type=int, default=500)
+    epochs = zconf.attr(type=int, default=None)
+    max_steps = zconf.attr(type=int, default=None)
+    num_gpus = zconf.attr(type=int, default=None)
+    warmup_steps_proportion = zconf.attr(type=float, default=0.1)
+
+    def create_config(self):
+        # === Get task config === #
+        if self.task_config_path is not None:
+            assert self.task_config_base_path is None
+            task_config_path = self.task_config_path
+        elif self.task_config_base_path is not None:
+            assert self.task_config_path is None
+            task_config_path = os.path.join(
+                self.task_config_base_path, f"{self.task_name}_config.json",
+            )
+        else:
+            raise RuntimeError("Require either `task_config_path` or `task_config_base_path`")
+
+        # === Get cache === #
+        if self.task_cache_path is not None:
+            assert self.task_cache_base_path is None
+            task_cache_path = self.task_cache_path
+        elif self.task_cache_base_path is not None:
+            assert self.task_cache_path is None
+            task_cache_path = os.path.join(self.task_cache_base_path, self.task_name)
+        else:
+            raise RuntimeError("Require either `task_cache_path` or `task_cache_base_path`")
+        task_cache_config = {}
+        if self.do_train:
+            task_cache_config["train"] = os.path.join(task_cache_path, "train")
+            task_cache_config["train_val"] = os.path.join(task_cache_path, "train_val")
+        if self.do_val:
+            task_cache_config["val"] = os.path.join(task_cache_path, "val")
+        if self.do_test:
+            task_cache_config["test"] = os.path.join(task_cache_path, "test")
+        for v in task_cache_config.values():
+            assert os.path.exists(v)
+
+        # === Compute training steps === #
+        if not self.do_train:
+            assert self.epochs is None
+            assert self.max_steps is None
+            max_steps = 0
+        elif self.max_steps is not None:
+            max_steps = self.max_steps
+        elif self.epochs is not None:
+            assert self.max_steps is None
+            if self.num_gpus:
+                # We multiply by num_gpus because one step is done across (potentially) multiple GPUs
+                effective_batch_size = (
+                    self.train_batch_size * self.gradient_accumulation_steps * self.num_gpus
+                )
+            else:
+                effective_batch_size = self.train_batch_size * self.gradient_accumulation_steps
+            num_examples = get_num_examples_from_cache(
+                cache_path=os.path.expandvars(task_cache_config["train"]),
+            )
+            max_steps = self.epochs * math.ceil(num_examples / effective_batch_size)
+        else:
+            raise RuntimeError("Require either `epochs` or `max_steps`")
+
+        # === Compute eval_batch_size === #
+        if self.eval_batch_size is not None:
+            assert self.eval_batch_multiplier is None
+            eval_batch_size = self.eval_batch_size
+        elif self.eval_batch_multiplier is not None:
+            assert self.eval_batch_size is None
+            eval_batch_size = self.train_batch_size * self.eval_batch_multiplier
+        else:
+            raise RuntimeError("Require either `eval_batch_size` or `eval_batch_multiplier`")
+
+        # === Build configuration === #
+        config_dict = {
+            "task_config_path_dict": {self.task_name: task_config_path},
+            "task_cache_config_dict": {self.task_name: task_cache_config},
+            "sampler_config": {"sampler_type": "UniformMultiTaskSampler"},
+            "global_train_config": {
+                "max_steps": int(max_steps),
+                "warmup_steps": int(max_steps * self.warmup_steps_proportion),
+            },
+            "task_specific_configs_dict": {
+                self.task_name: {
+                    "train_batch_size": self.train_batch_size,
+                    "eval_batch_size": eval_batch_size,
+                    "gradient_accumulation_steps": self.gradient_accumulation_steps,
+                    "eval_subset_num": self.eval_subset_num,
+                }
+            },
+            "taskmodels_config": {
+                "task_to_taskmodel_map": {self.task_name: self.task_name},
+                "taskmodel_config_map": {self.task_name: None},
+            },
+            "task_run_config": {
+                "train_task_list": [self.task_name] if self.do_train else [],
+                "train_val_task_list": [self.task_name] if self.do_train else [],
+                "val_task_list": [self.task_name] if self.do_val else [],
+                "test_task_list": [self.task_name] if self.do_test else [],
+            },
+            "metric_aggregator_config": {"metric_aggregator_type": "EqualMetricAggregator"},
+        }
+        return config_dict
+
+
 @Registry.register
 @zconf.run_config
 class SimpleAPIMultiTaskConfigurator(zconf.RunConfig):
     """Multi-task Configurator designed for SimpleAPI
+
+    For simplicity, we assume that certain properties are constant across all tasks:
+    batch sizes and eval_subset_num.
+    For anything more complex, the user is better off writing the config entirely on their own.
+
+    Required:
+        train_batch_size
+
     (Task config) Need one of:
         task_config_base_path
         task_config_path_dict
@@ -57,6 +221,7 @@ class SimpleAPIMultiTaskConfigurator(zconf.RunConfig):
     (Computing max steps) Need one of:
         epochs
         max_steps
+        (Set to 0 if not training)
 
     (Task name list) Specify at least one of:
         train_task_name_list
@@ -64,9 +229,6 @@ class SimpleAPIMultiTaskConfigurator(zconf.RunConfig):
         val_task_name_list
         test_task_name_list
 
-    Required:
-        train_batch_size
-
     Optional:
         gradient_accumulation_steps
         eval_subset_num
diff --git a/jiant/utils/python/datastructures.py b/jiant/utils/python/datastructures.py
index cffcf734c..8707261ab 100644
--- a/jiant/utils/python/datastructures.py
+++ b/jiant/utils/python/datastructures.py
@@ -182,7 +182,7 @@ def check_keys(dict1: dict, key_list, mode="equal") -> bool:
     raise KeyError(mode)
 
 
-def get_unique_list_in_order(list_of_lists: Sequence[Sequence]):
+def get_unique_list_in_order(list_of_lists: Iterable[Sequence]):
     """Gets unique items from a list of lists, in the order of the appearance
 
     Args:
@@ -214,6 +214,7 @@ def reorder_keys(dict1: dict, key_list: list) -> dict:
     """
     dict_class = dict1.__class__
     assert check_keys(dict1, key_list)
+    # noinspection PyArgumentList
     return dict_class([(k, dict1[k]) for k in key_list])
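
Reviewer note (not part of the patch): a minimal usage sketch of the new SingleTaskConfigurator, mirroring the established pattern of calling create_config() on a configurator and writing the result to JSON with jiant.utils.python.io. The task name and all paths below are hypothetical placeholders, and the sketch assumes the task config file and cache directories already exist, since create_config() asserts that every cache path is present.

    import os

    import jiant.utils.python.io as py_io
    import jiant.proj.main.scripts.configurator as configurator

    # Hypothetical layout: /content/exp/tasks/configs/rte_config.json plus
    # /content/exp/cache/rte/{train,train_val,val} must already exist.
    jiant_run_config = configurator.SingleTaskConfigurator(
        task_name="rte",
        task_config_base_path="/content/exp/tasks/configs",
        task_cache_base_path="/content/exp/cache",
        train_batch_size=16,
        eval_batch_multiplier=2,  # eval_batch_size = 16 * 2 = 32
        epochs=3,  # max_steps = 3 * ceil(num_train_examples / effective_batch_size)
        do_train=True,
        do_val=True,
    ).create_config()

    os.makedirs("/content/exp/run_configs", exist_ok=True)
    py_io.write_json(jiant_run_config, "/content/exp/run_configs/rte_run_config.json")

Note that the option pairs are mutually exclusive by design: passing both task_config_path and task_config_base_path (likewise both cache paths, both eval batch options, or both epochs and max_steps) trips the corresponding assertion in create_config().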