
Merge master into new_MT_branch. #977

Status: Merged. 30 commits merged into `new_MT_branch` from `master` on Dec 18, 2019.

Commits
- `10fb192` Readme update for bert npi paper (#915) (HaokunLiu, Sep 19, 2019)
- `c36b74e` Fixing index problem & minor pytorch_transformers_interface cleanup (… (HaokunLiu, Sep 20, 2019)
- `706b652` Prepare for 1.2.1 release. (sleepinyourhat, Sep 23, 2019)
- `b19ca78` QA-SRL (#716) (zphang, Oct 1, 2019)
- `2553c2d` Implementing Data Parallel (#873) (Oct 1, 2019)
- `7508bea` replace correct_sent_indexing with non inplace version (#921) (HaokunLiu, Oct 2, 2019)
- `9d4baf3` Abductive NLI (aNLI) (#922) (zphang, Oct 7, 2019)
- `254dc37` SocialIQA (#924) (Oct 10, 2019)
- `f0ef3f7` Fixing bug with restoring checkpoint with two gpus + cleaning CUDA pa… (Oct 16, 2019)
- `8f46d4f` Updating CoLA inference script (#931) (zphang, Oct 17, 2019)
- `303a733` Adding Senteval Tasks (#926) (Oct 21, 2019)
- `787e78b` Speed up retokenization (#935) (Oct 22, 2019)
- `2ed6802` Scitail (#943) (phu-pmh, Oct 26, 2019)
- `fb1eec1` Add corrected data stastistics (#941) (pitrack, Oct 26, 2019)
- `1d40f23` CommonsenseQA+hellaswag (#942) (HaokunLiu, Oct 26, 2019)
- `3b07a5e` fix name (#945) (HaokunLiu, Oct 27, 2019)
- `347f743` CCG update (#948) (HaokunLiu, Nov 3, 2019)
- `41abe5f` Fixing senteval-probing preprocessing (#951) (Nov 5, 2019)
- `98b1dc8` Adding tokenizer alignment function (#953) (Nov 6, 2019)
- `d769338` Function words probing (#949) (HaokunLiu, Nov 8, 2019)
- `8af068d` CosmosQA (#952) (phu-pmh, Nov 9, 2019)
- `1ee0d95` qqp fix (#956) (zphang, Nov 10, 2019)
- `2a9230b` QAMR + QA-SRL Update (#932) (zphang, Nov 12, 2019)
- `39b234e` Set _unk_id in Roberta module (#959) (njjiang, Nov 14, 2019)
- `7dc9965` Fixing load_target_train_checkpoint with mixing setting (#960) (Nov 15, 2019)
- `daec5cf` update pytorch and numpy version requirements (#965) (pyeres, Nov 20, 2019)
- `8a059b8` CCG update (#955) (HaokunLiu, Nov 22, 2019)
- `c181273` add adversarial_nli tasks (#966) (pyeres, Nov 23, 2019)
- `42b389f` Update README.md (sleepinyourhat, Nov 27, 2019)
- `18ca100` Citation fix (sleepinyourhat, Dec 9, 2019)
6 changes: 4 additions & 2 deletions README.md

````diff
@@ -42,7 +42,7 @@ If you use `jiant` in academic work, please cite it directly:
 
 ```
 @misc{wang2019jiant,
-    author = {Alex Wang and Ian F. Tenney and Yada Pruksachatkun and Katherin Yu and Jan Hula and Patrick Xia and Raghu Pappagari and Shuning Jin and R. Thomas McCoy and Roma Patel and Yinghui Huang and Jason Phang and Edouard Grave and Haokun Liu and Najoung Kim and Phu Mon Htut and Thibault F'{e}vry and Berlin Chen and Nikita Nangia and Anhad Mohananey and Katharina Kann and Shikha Bordia and Nicolas Patry and David Benton and Ellie Pavlick and Samuel R. Bowman},
+    author = {Alex Wang and Ian F. Tenney and Yada Pruksachatkun and Katherin Yu and Jan Hula and Patrick Xia and Raghu Pappagari and Shuning Jin and R. Thomas McCoy and Roma Patel and Yinghui Huang and Jason Phang and Edouard Grave and Haokun Liu and Najoung Kim and Phu Mon Htut and Thibault F\'evry and Berlin Chen and Nikita Nangia and Anhad Mohananey and Katharina Kann and Shikha Bordia and Nicolas Patry and David Benton and Ellie Pavlick and Samuel R. Bowman},
     title = {\texttt{jiant} 1.2: A software toolkit for research on general-purpose text understanding models},
     howpublished = {\url{http://jiant.info/}},
     year = {2019}
@@ -57,13 +57,15 @@ If you use `jiant` in academic work, please cite it directly:
 - [What do you learn from context? Probing for sentence structure in contextualized word representations](https://openreview.net/forum?id=SJzSgnRcKX) ("edge probing")
 - [BERT Rediscovers the Classical NLP Pipeline](https://arxiv.org/abs/1905.05950) ("BERT layer paper")
 - [Probing What Different NLP Tasks Teach Machines about Function Word Comprehension](https://arxiv.org/abs/1904.11544) ("function word probing")
+- [Investigating BERT’s Knowledge of Language: Five Analysis Methods with NPIs](https://arxiv.org/abs/1909.02597) ("BERT NPI paper")
 
 To exactly reproduce experiments from [the ELMo's Friends paper](https://arxiv.org/abs/1812.10860) use the [`jsalt-experiments`](https://github.com/jsalt18-sentence-repl/jiant/tree/jsalt-experiments) branch. That will contain a snapshot of the code as of early August, potentially with updated documentation.
 
 For the [edge probing paper](https://openreview.net/forum?id=SJzSgnRcKX) and the [BERT layer paper](https://arxiv.org/abs/1905.05950), see the [probing/](probing/) directory.
 
 For the [function word probing paper](https://arxiv.org/abs/1904.11544), use [this branch](https://github.com/nyu-mll/jiant/tree/naacl_probingpaper) and refer to the instructions in the [scripts/fwords/](https://github.com/nyu-mll/jiant/tree/naacl_probingpaper/scripts/fwords) directory.
 
+For the [BERT NPI paper](https://arxiv.org/abs/1909.02597) follow the instructions in [scripts/bert_npi](https://github.com/nyu-mll/jiant/tree/blimp-and-npi/scripts/bert_npi) on the [`blimp-and-npi`](https://github.com/nyu-mll/jiant/tree/blimp-and-npi) branch.
 
 ## Getting Help
 
@@ -106,7 +108,7 @@ This package is released under the [MIT License](LICENSE.md). The material in th
 
 - Part of the development of `jiant` took at the 2018 Frederick Jelinek Memorial Summer Workshop on Speech and Language Technologies, and was supported by Johns Hopkins University with unrestricted gifts from Amazon, Facebook, Google, Microsoft and Mitsubishi Electric Research Laboratories.
 - This work was made possible in part by a donation to NYU from Eric and Wendy Schmidt made
-by recommendation of the Schmidt Futures program.
+by recommendation of the Schmidt Futures program, and by support from Intuit Inc.
 - We gratefully acknowledge the support of NVIDIA Corporation with the donation of a Titan V GPU used at NYU in this work.
 - Developer Alex Wang is supported by the National Science Foundation Graduate Research Fellowship Program under Grant
 No. DGE 1342536. Any opinions, findings, and conclusions or recommendations expressed in this
````
41 changes: 30 additions & 11 deletions cola_inference.py

```diff
@@ -46,11 +46,14 @@
 from tqdm import tqdm
 
 from jiant.models import build_model
-from jiant.preprocess import build_indexers, build_tasks
+from jiant.preprocess import build_indexers, build_tasks, ModelPreprocessingInterface
 from jiant.tasks.tasks import tokenize_and_truncate, sentence_to_text_field
 from jiant.utils import config
 from jiant.utils.data_loaders import load_tsv
-from jiant.utils.utils import check_arg_name, load_model_state, select_pool_type
+from jiant.utils.utils import load_model_state, select_pool_type
+from jiant.utils.options import parse_cuda_list_arg
+from jiant.utils.tokenizers import select_tokenizer
+from jiant.__main__ import check_arg_name
 
 log.basicConfig(format="%(asctime)s: %(message)s", datefmt="%m/%d %I:%M:%S %p", level=log.INFO)
 
@@ -140,7 +143,7 @@ def main(cl_arguments):
         args.cuda = -1
 
     if args.tokenizer == "auto":
-        args.tokenizer = tokenizers.select_tokenizer(args)
+        args.tokenizer = select_tokenizer(args)
     if args.pool_type == "auto":
         args.pool_type = select_pool_type(args)
 
@@ -149,7 +152,8 @@ def main(cl_arguments):
     tasks = sorted(set(target_tasks), key=lambda x: x.name)
 
     # Build or load model #
-    model = build_model(args, vocab, word_embs, tasks)
+    cuda_device = parse_cuda_list_arg(args.cuda)
+    model = build_model(args, vocab, word_embs, tasks, cuda_device)
     log.info("Loading existing model from %s...", cl_args.model_file_path)
     load_model_state(model, cl_args.model_file_path, args.cuda, [], strict=False)
 
@@ -158,16 +162,18 @@ def main(cl_arguments):
     vocab = Vocabulary.from_files(os.path.join(args.exp_dir, "vocab"))
     indexers = build_indexers(args)
     task = take_one(tasks)
+    model_preprocessing_interface = ModelPreprocessingInterface(args)
 
     # Run Inference #
     if cl_args.inference_mode == "repl":
        assert cl_args.input_path is None
        assert cl_args.output_path is None
        print("Running REPL for task: {}".format(task.name))
-       run_repl(model, vocab, indexers, task, args)
+       run_repl(model, model_preprocessing_interface, vocab, indexers, task, args)
     elif cl_args.inference_mode == "corpus":
        run_corpus_inference(
            model,
+           model_preprocessing_interface,
            vocab,
            indexers,
            task,
@@ -181,7 +187,7 @@ def main(cl_arguments):
         raise KeyError(cl_args.inference_mode)
 
 
-def run_repl(model, vocab, indexers, task, args):
+def run_repl(model, model_preprocessing_interface, vocab, indexers, task, args):
     """ Run REPL """
     print("Input CTRL-C or enter 'QUIT' to terminate.")
     while True:
@@ -195,7 +201,9 @@ def run_repl(model, vocab, indexers, task, args):
             tokenizer_name=task.tokenizer_name, sent=input_string, max_seq_len=args.max_seq_len
         )
         print("TOKENS:", " ".join("[{}]".format(tok) for tok in tokens))
-        field = sentence_to_text_field(tokens, indexers)
+        field = sentence_to_text_field(
+            model_preprocessing_interface.boundary_token_fn(tokens), indexers
+        )
         field.index(vocab)
         batch = Batch([Instance({"input1": field})]).as_tensor_dict()
         batch = move_to_device(batch, args.cuda)
@@ -217,13 +225,22 @@ def run_repl(model, vocab, indexers, task, args):
 
 
 def run_corpus_inference(
-    model, vocab, indexers, task, args, input_path, input_format, output_path, eval_output_path
+    model,
+    model_preprocessing_interface,
+    vocab,
+    indexers,
+    task,
+    args,
+    input_path,
+    input_format,
+    output_path,
+    eval_output_path,
 ):
     """ Run inference on corpus """
     tokens, labels = load_cola_data(input_path, task, input_format, args.max_seq_len)
     logit_batches = []
     for tokens_batch in tqdm(list(batchify(tokens, args.batch_size))):
-        batch, _ = prepare_batch(tokens_batch, vocab, indexers, args)
+        batch, _ = prepare_batch(model_preprocessing_interface, tokens_batch, vocab, indexers, args)
         with torch.no_grad():
             out = model.forward(task, batch, predict=True)
             logit_batches.append(out["logits"].cpu().numpy())
@@ -263,12 +280,14 @@ def batchify(ls, batch_size):
         i += batch_size
 
 
-def prepare_batch(tokens_batch, vocab, indexers, args):
+def prepare_batch(model_preprocessing_interface, tokens_batch, vocab, indexers, args):
     """ Do preprocessing for batch """
     instance_ls = []
     token_ls = []
     for tokens in tokens_batch:
-        field = sentence_to_text_field(tokens, indexers)
+        field = sentence_to_text_field(
+            model_preprocessing_interface.boundary_token_fn(tokens), indexers
+        )
         field.index(vocab)
         instance_ls.append(Instance({"input1": field}))
         token_ls.append(tokens)
```
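The pattern repeated through this file is that raw tokens now pass through `model_preprocessing_interface.boundary_token_fn` before being turned into a text field, so model-specific boundary tokens are attached in exactly one place. Below is a minimal sketch of the idea with a hypothetical stand-in class; the real `ModelPreprocessingInterface` in `jiant.preprocess` derives the function from the model selected in `args`, and the `[CLS]`/`[SEP]` choice here is purely illustrative.

```python
# Hedged sketch, not jiant's implementation: a stand-in for the
# boundary_token_fn that ModelPreprocessingInterface exposes. BERT-style
# boundary tokens are assumed here only for illustration.
from typing import List, Optional


class StubPreprocessingInterface:
    """Hypothetical stand-in for jiant.preprocess.ModelPreprocessingInterface."""

    def __init__(self, start: str = "[CLS]", end: str = "[SEP]") -> None:
        self.start = start
        self.end = end

    def boundary_token_fn(self, s1: List[str], s2: Optional[List[str]] = None) -> List[str]:
        # Single-sentence case, as used by the CoLA inference script above.
        tokens = [self.start] + list(s1) + [self.end]
        if s2 is not None:  # sentence-pair case, for completeness
            tokens += list(s2) + [self.end]
        return tokens


if __name__ == "__main__":
    iface = StubPreprocessingInterface()
    print(iface.boundary_token_fn("the cat sat".split()))
    # ['[CLS]', 'the', 'cat', 'sat', '[SEP]']
```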
4 changes: 2 additions & 2 deletions environment.yml

```diff
@@ -3,9 +3,9 @@ channels:
   - pytorch
 dependencies:
   - python=3.6
-  - pytorch=1.0.0
+  - pytorch=1.1.0
   - torchvision=0.2.1
-  - numpy=1.14.5
+  - numpy=1.15.0
   - scikit-learn=0.19.1
   - pandas=0.23.0
   # bokeh for plotting
```
41 changes: 26 additions & 15 deletions jiant/__main__.py

```diff
@@ -1,4 +1,4 @@
-"""Train a multi-task model using AllenNLP
+"""Main flow for jiant.
 
 To debug this, run with -m ipdb:
 
@@ -21,13 +21,15 @@
 import time
 import copy
 import torch
+import torch.nn as nn
 
 from jiant import evaluate
 from jiant.models import build_model
 from jiant.preprocess import build_tasks
 from jiant import tasks as task_modules
 from jiant.trainer import build_trainer
 from jiant.utils import config, tokenizers
+from jiant.utils.options import parse_cuda_list_arg
 from jiant.utils.utils import (
     assert_for_log,
     load_model_state,
@@ -38,6 +40,8 @@
     check_for_previous_checkpoints,
     select_pool_type,
     delete_all_checkpoints,
+    get_model_attribute,
+    uses_cuda,
 )
 
 
@@ -302,15 +306,15 @@ def get_best_checkpoint_path(args, phase, task_name=None):
     return None
 
 
-def evaluate_and_write(args, model, tasks, splits_to_write):
+def evaluate_and_write(args, model, tasks, splits_to_write, cuda_device):
     """ Evaluate a model on dev and/or test, then write predictions """
-    val_results, val_preds = evaluate.evaluate(model, tasks, args.batch_size, args.cuda, "val")
+    val_results, val_preds = evaluate.evaluate(model, tasks, args.batch_size, cuda_device, "val")
     if "val" in splits_to_write:
         evaluate.write_preds(
             tasks, val_preds, args.run_dir, "val", strict_glue_format=args.write_strict_glue_format
         )
     if "test" in splits_to_write:
-        _, te_preds = evaluate.evaluate(model, tasks, args.batch_size, args.cuda, "test")
+        _, te_preds = evaluate.evaluate(model, tasks, args.batch_size, cuda_device, "test")
         evaluate.write_preds(
             tasks, te_preds, args.run_dir, "test", strict_glue_format=args.write_strict_glue_format
         )
@@ -377,7 +381,8 @@ def initial_setup(args, cl_args):
     random.seed(seed)
     torch.manual_seed(seed)
     log.info("Using random seed %d", seed)
-    if args.cuda >= 0:
+    if isinstance(args.cuda, int) and args.cuda >= 0:
+        # If only running on one GPU.
         try:
             if not torch.cuda.is_available():
                 raise EnvironmentError("CUDA is not available, or not detected" " by PyTorch.")
@@ -443,7 +448,7 @@ def check_arg_name(args):
     )
 
 
-def load_model_for_target_train_run(args, ckpt_path, model, strict, task):
+def load_model_for_target_train_run(args, ckpt_path, model, strict, task, cuda_devices):
     """
     Function that reloads model if necessary and extracts trainable parts
     of the model in preparation for target_task training.
@@ -462,9 +467,8 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task):
         to_train: List of tuples of (name, weight) of trainable parameters
 
     """
-
+    load_model_state(model, ckpt_path, cuda_devices, skip_task_models=[task.name], strict=strict)
     if args.transfer_paradigm == "finetune":
-        load_model_state(model, ckpt_path, args.cuda, skip_task_models=[task.name], strict=strict)
         # Train both the task specific models as well as sentence encoder.
         to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
     else:  # args.transfer_paradigm == "frozen":
@@ -483,9 +487,13 @@ def load_model_for_target_train_run(args, ckpt_path, model, strict, task):
             "they should not be updated! Check sep_embs_for_skip flag or make an issue.",
         )
         # Only train task-specific module
-        pred_module = getattr(model, "%s_mdl" % task.name)
+
+        pred_module = get_model_attribute(model, "%s_mdl" % task.name, cuda_devices)
         to_train = [(n, p) for n, p in pred_module.named_parameters() if p.requires_grad]
         to_train += elmo_scalars
+    model = model.cuda() if uses_cuda(cuda_devices) else model
+    if isinstance(cuda_devices, list):
+        model = nn.DataParallel(model, device_ids=cuda_devices)
     return to_train
 
 
@@ -499,6 +507,7 @@ def main(cl_arguments):
     # Load tasks
     log.info("Loading tasks...")
     start_time = time.time()
+    cuda_device = parse_cuda_list_arg(args.cuda)
     pretrain_tasks, target_tasks, vocab, word_embs = build_tasks(args)
     tasks = sorted(set(pretrain_tasks + target_tasks), key=lambda x: x.name)
     log.info("\tFinished loading tasks in %.3fs", time.time() - start_time)
@@ -507,7 +516,7 @@ def main(cl_arguments):
 
     # Build model
     log.info("Building model...")
     start_time = time.time()
-    model = build_model(args, vocab, word_embs, tasks)
+    model = build_model(args, vocab, word_embs, tasks, cuda_device)
     log.info("Finished building model in %.3fs", time.time() - start_time)
 
@@ -524,7 +533,7 @@ def main(cl_arguments):
             pretrain_tasks[0].val_metric_decreases if len(pretrain_tasks) == 1 else False
         )
         trainer, _, opt_params, schd_params = build_trainer(
-            args, [], model, args.run_dir, should_decrease, phase="pretrain"
+            args, cuda_device, [], model, args.run_dir, should_decrease, phase="pretrain"
         )
         to_train = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
         _ = trainer.train(
@@ -565,10 +574,11 @@ def main(cl_arguments):
                 continue
 
             params_to_train = load_model_for_target_train_run(
-                args, pre_target_train_path, model, strict, task
+                args, pre_target_train_path, model, strict, task, cuda_device
            )
            trainer, _, opt_params, schd_params = build_trainer(
                args,
+               cuda_device,
                [task.name],
                model,
                args.run_dir,
@@ -596,11 +606,12 @@ def main(cl_arguments):
         # Evaluate on target_tasks.
         for task in target_tasks:
             # Find the task-specific best checkpoint to evaluate on.
-            task_to_use = model._get_task_params(task.name).get("use_classifier", task.name)
+            task_params = get_model_attribute(model, "_get_task_params", cuda_device)
+            task_to_use = task_params(task.name).get("use_classifier", task.name)
             ckpt_path = get_best_checkpoint_path(args, "eval", task_to_use)
             assert ckpt_path is not None
-            load_model_state(model, ckpt_path, args.cuda, skip_task_models=[], strict=strict)
-            evaluate_and_write(args, model, [task], splits_to_write)
+            load_model_state(model, ckpt_path, cuda_device, skip_task_models=[], strict=strict)
+            evaluate_and_write(args, model, [task], splits_to_write, cuda_device)
 
         if args.delete_checkpoints_when_done and not args.keep_all_checkpoints:
             log.info("Deleting all checkpoints.")
```
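The multi-GPU changes above follow one pattern: a parsed `cuda_device` value is threaded through `build_model`, `build_trainer`, `load_model_state`, and `evaluate`; the model is wrapped in `nn.DataParallel` when a list of device IDs is requested; and attribute lookups such as `"%s_mdl" % task.name` go through `get_model_attribute`, since `DataParallel` hides the wrapped module behind `.module`. Below is a minimal sketch of that wrapping and unwrapping with hypothetical stand-in names; jiant's real helpers live in `jiant.utils.utils` and may differ in detail.

```python
# Hedged sketch of the DataParallel pattern this PR relies on; the function
# names here are stand-ins, not jiant's actual helpers.
import torch.nn as nn


def wrap_for_devices_sketch(model: nn.Module, cuda_device) -> nn.Module:
    # cuda_device: -1 for CPU, a single GPU ID, or a list of GPU IDs.
    if isinstance(cuda_device, list):
        model = model.cuda()  # move parameters to GPU before wrapping
        model = nn.DataParallel(model, device_ids=cuda_device)
    elif isinstance(cuda_device, int) and cuda_device >= 0:
        model = model.cuda(cuda_device)
    return model


def get_model_attribute_sketch(model: nn.Module, name: str):
    # DataParallel proxies forward() but not arbitrary attributes, so
    # task-specific submodules must be fetched from the wrapped module.
    if isinstance(model, nn.DataParallel):
        return getattr(model.module, name)
    return getattr(model, name)
```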
3 changes: 2 additions & 1 deletion jiant/config/defaults.conf

```diff
@@ -22,7 +22,8 @@
 
 // Misc. Logistics //
 
-cuda = 0 // GPU ID. Set to -1 for CPU. On machines without GPUs, this is ignored.
+cuda = auto // GPU ID. Set to -1 for CPU, "auto" for all available GPUs on machine and
+            // a comma-delimited list of GPU IDs for a subset of GPUs.
 random_seed = 1234 // Global random seed, used in both Python and PyTorch random number generators.
 track_batch_utilization = 0 // Track % of each batch that is padding tokens (for tasks with field
     // 'input1').
```
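The three accepted values of `cuda` map naturally onto a small parser. The following is a hedged sketch of what `parse_cuda_list_arg` in `jiant.utils.options` presumably does; the real implementation may differ, for instance in how "auto" behaves on zero- or one-GPU machines.

```python
# Hedged sketch of parsing the `cuda` option: -1 or a GPU ID, "auto",
# or a comma-delimited list such as "0,2". Not jiant's actual code.
from typing import List, Union

import torch


def parse_cuda_arg_sketch(cuda: Union[int, str]) -> Union[int, List[int]]:
    if isinstance(cuda, int):
        return cuda  # single GPU ID, or -1 for CPU
    if cuda == "auto":
        ids = list(range(torch.cuda.device_count()))
        return ids if len(ids) > 1 else (ids[0] if ids else -1)
    return [int(i) for i in cuda.split(",")]  # e.g. "0,2" -> [0, 2]
```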
1 change: 0 additions & 1 deletion jiant/config/demo.conf

```diff
@@ -18,7 +18,6 @@ include "defaults.conf" // relative path to this file
 exp_name = jiant-demo
 run_name = mtl-sst-mrpc
 
-cuda = 0
 random_seed = 42
 
 load_model = 0
```