Hard negative mining for Retriever fine-tuning #523
base: main
New file: YAML configuration for hard-negative mining

@@ -0,0 +1,19 @@
# HARD-NEGATIVE MINING parameters

# base_url: "https://integrate.api.nvidia.com/v1"
model_name: "sentence-transformers/all-MiniLM-L6-v2"
model_type: "hf"
query_prefix: "query:"
Reviewer comment: Is there a reason why the defaults in the dataclass don't match the defaults in the yaml?
passage_prefix: "passage:"
api_key: "your api key here"
truncate: "END"
hard_negatives_to_mine: 4
hard_neg_mining_algorithm: "topk_percpos"
percpos: 0.95

# SEMANTIC CLUSTERING parameters
min_cluster_size: 100
max_number_clusters: 200
cluster_output_dir: "/workspace/hnm/clusters"
logger_output_dir: "/workspace/hnm/logs"
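For context on the reviewer's question above: this YAML is loaded by RetrieverHardNegativeMiningConfig.from_yaml in the scripts below. A minimal sketch of how such a dataclass could look; the field names mirror the YAML keys, but the defaults and the from_yaml loader shown here are illustrative assumptions, not the actual config/config.py contents:

from dataclasses import dataclass, fields
from typing import Optional

import yaml


@dataclass
class RetrieverHardNegativeMiningConfig:
    # Defaults below are illustrative assumptions only.
    model_name: str = "sentence-transformers/all-MiniLM-L6-v2"
    model_type: str = "hf"
    base_url: Optional[str] = None
    api_key: Optional[str] = None
    query_prefix: str = ""
    passage_prefix: str = ""
    truncate: str = "END"
    hard_negatives_to_mine: int = 4
    hard_neg_mining_algorithm: str = "topk_percpos"
    percpos: float = 0.95
    max_hardness_threshold: Optional[float] = None
    min_hardness_threshold: Optional[float] = None
    min_cluster_size: int = 100
    max_number_clusters: int = 200
    cluster_output_dir: Optional[str] = None
    logger_output_dir: Optional[str] = None

    @classmethod
    def from_yaml(cls, path: str) -> "RetrieverHardNegativeMiningConfig":
        with open(path) as f:
            raw = yaml.safe_load(f)
        # Ignore YAML keys that do not correspond to declared fields.
        known = {fld.name for fld in fields(cls)}
        return cls(**{k: v for k, v in raw.items() if k in known})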
New file: hard-negative mining script (Python)

@@ -0,0 +1,104 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import importlib
import os
import pdb
import shutil
import time
from pathlib import Path
from typing import Any, List

from retriever_hardnegative_miner import HardNegativeMiner
from tqdm.dask import TqdmCallback

from config.config import RetrieverHardNegativeMiningConfig
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import get_all_files_paths_under


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        type=str,
        default="",
        help="Input dir path containing annotated data files in jsonl format",
    )
    parser.add_argument(
        "--hard-negative-mining-config",
        type=str,
        default="",
        help="Configuration yaml file path containing config for hard negative mining",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Output file containing hard negatives",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="The API key to use for the synthetic data generation LLM client.",
    )
    parser.add_argument(
        "--api-timeout",
        type=int,
        default=120,
        help="The timeout value for API calls in seconds.",
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_dir):
        raise ValueError("Input directory not found")

    if os.path.exists(args.output_dir):
        raise ValueError("Output dir exists already, use a new file name!")

    if args.input_dir:
        input_files = get_all_files_paths_under(args.input_dir, keep_extensions="part")
        input_dataset = DocumentDataset.read_json(input_files)
    else:
        raise ValueError("provide input file path")

    if args.hard_negative_mining_config:
        cfg = RetrieverHardNegativeMiningConfig.from_yaml(
            args.hard_negative_mining_config
        )
    else:
        raise ValueError("provide config for hard negative mining")

    if args.api_key:
        cfg.api_key = args.api_key

    mine_hard_negatives = HardNegativeMiner(cfg)
    print("Mining hard negatives ...")
    st_time = time.time()
    mined_dataset = mine_hard_negatives(input_dataset)

    print("Time taken = {:.2f} s".format(time.time() - st_time))
    print("Saving data in jsonl format ...")
    mined_dataset.df.to_json(
        os.path.join(args.output_dir), lines=True, orient="records"
    )


if __name__ == "__main__":
    dask_client = get_client(cluster_type="gpu")
    main()
New file: semantic clustering script (Python)

@@ -0,0 +1,100 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import glob
import importlib
import os
import pdb
import shutil
import time
from typing import Any, List

from retriever_hardnegative_miner import HardNegativeMiner
from tqdm.dask import TqdmCallback

from config.config import RetrieverHardNegativeMiningConfig
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import get_all_files_paths_under


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        type=str,
        default="",
        help="Input dir path containing annotated data files in jsonl format",
    )
    parser.add_argument(
        "--hard-negative-mining-config",
        type=str,
        default="",
        help="Configuration yaml file path containing config for hard negative mining",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="",
        help="Output file containing clustered dataset",
    )
    parser.add_argument(
        "--api-key",
        type=str,
        default=None,
        help="The API key to use for the synthetic data generation LLM client.",
    )
    parser.add_argument(
        "--api-timeout",
        type=int,
        default=120,
        help="The timeout value for API calls in seconds.",
    )
    args = parser.parse_args()

    if not os.path.exists(args.input_dir):
        raise ValueError("Input directory not found")

    if os.path.exists(args.output_dir):
        raise ValueError("Output dir exists already, use a new file name!")

    if args.input_dir:
        input_files = get_all_files_paths_under(args.input_dir, keep_extensions="jsonl")
        input_dataset = DocumentDataset.read_json(input_files)
    else:
        raise ValueError("provide input file path")
    if args.hard_negative_mining_config:
        cfg = RetrieverHardNegativeMiningConfig.from_yaml(
            args.hard_negative_mining_config
        )
    else:
        raise ValueError("provide config for hard negative mining")
    if args.api_key:
        cfg.api_key = args.api_key

    st_time = time.time()
    miner = HardNegativeMiner(cfg)
    clustered_dataset = miner.repartition_semantic_similarity(input_dataset)
    clustered_dataset.persist()

    # saving clustered dataset
    print("saving clustered dataset")
    clustered_dataset.df.to_json(os.path.join(args.output_dir, "clustered_dataset"))
    print("Time taken to cluster data = {:.2f} s".format(time.time() - st_time))


if __name__ == "__main__":
    dask_client = get_client(cluster_type="gpu")
    main()
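Taken together, the two scripts form a two-stage workflow: first cluster the annotated data by semantic similarity, then mine hard negatives within the resulting partitions. A hypothetical single-process equivalent (file paths and the config filename are assumptions; the calls mirror the scripts above):

from retriever_hardnegative_miner import HardNegativeMiner

from config.config import RetrieverHardNegativeMiningConfig
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import get_client
from nemo_curator.utils.file_utils import get_all_files_paths_under

client = get_client(cluster_type="gpu")

# Paths below are illustrative.
cfg = RetrieverHardNegativeMiningConfig.from_yaml("config/hard_negative_mining.yaml")
files = get_all_files_paths_under("annotated_data/", keep_extensions="jsonl")
dataset = DocumentDataset.read_json(files)

miner = HardNegativeMiner(cfg)

# Stage 1: repartition so semantically similar rows share a partition.
clustered = miner.repartition_semantic_similarity(dataset)

# Stage 2: mine hard negatives within each partition and write jsonl.
mined = miner(clustered)
mined.df.to_json("mined_hard_negatives.jsonl", lines=True, orient="records")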
New file: retriever_hardnegative_miner.py

@@ -0,0 +1,307 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import importlib
import itertools
import pdb

import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask.base import normalize_token, tokenize
from dask.diagnostics import ProgressBar
from openai import OpenAI
from sentence_transformers import SentenceTransformer

from nemo_curator import ClusteringModel
from nemo_curator.datasets import DocumentDataset
from nemo_curator.utils.distributed_utils import NoWorkerError, load_object_on_worker

config = importlib.import_module(
    "tutorials.nemo-retriever-synthetic-data-generation.config.config"
)
RetrieverHardNegativeMiningConfig = config.RetrieverHardNegativeMiningConfig


def create_nim_client(base_url, api_key):
    openai_client = OpenAI(base_url=base_url, api_key=api_key)
    return openai_client


def create_hf_model(model_name_or_path):
    return SentenceTransformer(model_name_or_path, trust_remote_code=True)


class HardNegativeMiner:
Reviewer comment: Add docstring.

Reviewer comment: This seems like a general enough module that we could move it into …
    def __init__(
        self,
        cfg: RetrieverHardNegativeMiningConfig,
    ):
        self.model_name = cfg.model_name
        self.model_type = cfg.model_type
        self.base_url = cfg.base_url
        self.api_key = cfg.api_key
        self.truncate = cfg.truncate
        self.n_hard_negatives = cfg.hard_negatives_to_mine

        if cfg.passage_prefix:
            self.passage_prefix = cfg.passage_prefix
        if cfg.query_prefix:
            self.query_prefix = cfg.query_prefix
        if cfg.hard_neg_mining_algorithm:
            self.hard_neg_mining_algorithm = cfg.hard_neg_mining_algorithm
        else:
            print(
                "hard negative mining algorithm not mentioned in config, using default"
            )
            self.hard_neg_mining_algorithm = "topk_percpos"
        if self.hard_neg_mining_algorithm == "topk_percpos":
Reviewer comment: Can we abstract the algorithms and their arguments into a separate class? Follow a similar pattern to the common crawl extraction algorithms. We should have one base … If you feel like too much of the state of the …

(One possible shape for this abstraction is sketched after the __init__ body below.)
            if cfg.percpos:
                self.percpos = cfg.percpos
            else:
                self.percpos = 0.95
        elif self.hard_neg_mining_algorithm == "topk_abs":
            if cfg.max_hardness_threshold:
                self.max_neg_score_threshold = cfg.max_hardness_threshold
            else:
                raise ValueError("Hard negative threshold is required!")
            if cfg.min_hardness_threshold:
                self.min_neg_score_threshold = cfg.min_hardness_threshold
            else:
                self.min_neg_score_threshold = 0.0
        if cfg.min_cluster_size:
            self.min_cluster_size = cfg.min_cluster_size
        if cfg.max_number_clusters:
            self.max_number_clusters = cfg.max_number_clusters
        if cfg.cluster_output_dir:
            self.cluster_output_dir = cfg.cluster_output_dir
        if cfg.logger_output_dir:
            self.logger_output_dir = cfg.logger_output_dir
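A hypothetical sketch of the abstraction suggested in the review comment above; the class and method names are assumptions rather than an existing NeMo Curator API, and topk_abs's exclusion of positive documents is omitted for brevity:

from abc import ABC, abstractmethod

import numpy as np


class HardNegativeMiningAlgorithm(ABC):
    """Base class: each algorithm only decides the score ceiling for negatives."""

    @abstractmethod
    def max_negative_score(self, row) -> float:
        """Highest similarity a mined negative may have for this row."""

    def mine(self, query_embed, doc_embeds, docs, row, n_hard_negatives):
        # Score every candidate document against the query, keep those at or
        # below the ceiling, then return the highest-scoring (hardest) ones.
        scores = np.dot(np.array(query_embed), np.array(doc_embeds).T)
        ceiling = self.max_negative_score(row)
        seen = {}
        for doc, score in zip(docs, scores):
            if score <= ceiling and doc not in seen:
                seen[doc] = score
        ranked = sorted(seen.items(), key=lambda kv: kv[1], reverse=True)
        return ranked[:n_hard_negatives]


class TopKPercPos(HardNegativeMiningAlgorithm):
    def __init__(self, percpos=0.95):
        self.percpos = percpos

    def max_negative_score(self, row):
        return row["min_pos_score"] * self.percpos


class TopKAbs(HardNegativeMiningAlgorithm):
    def __init__(self, max_threshold):
        self.max_threshold = max_threshold

    def max_negative_score(self, row):
        return self.max_threshold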
    def __dask_tokenize__(self):
        return normalize_token(HardNegativeMiner)

    def assign_ids(self, partition):
        return partition.assign(doc_id=np.arange(len(partition)) + partition.index[0])

    def repartition_semantic_similarity(
        self, dataset: DocumentDataset
    ) -> DocumentDataset:
        df = dataset.df
        n_data = df.compute().shape[0]  # number of row items
Reviewer comment (with a suggested change): This will make sure we don't do double computation here and only pull in shape.
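        # (The suggested change itself is not preserved in this capture. Assuming
        # the intent is to get the row count without materializing the whole
        # frame, a likely equivalent is:)
        # n_data = df.shape[0].compute()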
        n_clusters = int(np.floor(n_data / self.min_cluster_size) + 1)
        n_clusters = min(n_clusters, self.max_number_clusters)
        print("Number of clusters used = {}".format(n_clusters))
        assert "doc_id" not in df.columns
        df["embeddings"] = ""  # refers to document embeddings
        df = df.explode("documents")
        df = df.map_partitions(self._get_doc_embeddings, meta=df)
        # df = dd.from_pandas(pdf)
Reviewer comment: Remove commented out code.
        df = df.map_partitions(self.assign_ids)
        embeddings_dataset = DocumentDataset(df)
        self.clustering_model = ClusteringModel(
            id_column="doc_id",
            max_iter=100,
            n_clusters=n_clusters,
            clustering_output_dir=self.cluster_output_dir,
            logger=self.logger_output_dir,
        )
        clustered_dataset = self.clustering_model(embeddings_dataset)
        df_c = clustered_dataset.df
        df_c = df_c[["documents", "question"]]

        return DocumentDataset(df_c)

    def _get_doc_embeddings(self, p_df: pd.DataFrame):
        if self.model_type == "nvidia":
            self.client = load_object_on_worker(
                attr="nim_embedding_model",
                load_object_function=create_nim_client,
                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
            )
            # p_df["embeddings"] = p_df["documents"].map(
            #     lambda pgs: [self._get_nim_embedding(t, "passage") for t in pgs]
            # )
            p_df["embeddings"] = p_df["documents"].map(
                lambda t: self._get_nim_embedding(t, "passage")
            )
        elif self.model_type == "hf":
            self.hf_model = load_object_on_worker(
                attr="hf_embedding_model",
                load_object_function=create_hf_model,
                load_object_kwargs={"model_name_or_path": self.model_name},
            )
            # p_df["embeddings"] = p_df["documents"].map(
            #     lambda pgs: [
            #         self._get_hf_embedding(t, self.passage_prefix) for t in pgs
            #     ]
            # )
            p_df["embeddings"] = p_df["documents"].map(
                lambda t: self._get_hf_embedding(t, self.passage_prefix)
            )
        return p_df

    def _groupby_question(self, pdf):
        pdf2 = pdf.groupby("question").agg({"documents": set})
        pdf2["documents"] = pdf2["documents"].map(lambda x: list(x))
        return pdf2

    def __call__(self, dataset: DocumentDataset) -> DocumentDataset:
        df = dataset.df
        df = df.to_backend("pandas")
        df = df[["question", "documents"]]
        df = df.map_partitions(self._groupby_question).reset_index()
Reviewer comment: Are you confident that all the questions will be in the same partition, or should you do a shuffle first?
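        # (A hypothetical guard for the concern above: shuffle on the question
        # column first, so all rows for a given question land in the same
        # partition before the per-partition groupby.)
        # df = df.shuffle(on="question")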
print("Number partitions in dataset = {}".format(df.npartitions)) | ||||||||
|
||||||||
df["neg_doc_scores"] = "" | ||||||||
df["neg_doc"] = "" | ||||||||
df["doc_embed"] = "" | ||||||||
df["query_embed"] = "" | ||||||||
df["min_pos_score"] = "" | ||||||||
|
||||||||
with ProgressBar(dt=1): | ||||||||
df = df.map_partitions(self._process_partition, meta=df).compute() | ||||||||
Reviewer comment: Do not call compute. It's up to the user when to call compute outside of the modules.
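        # (A hypothetical lazy variant per the review above: drop the .compute()
        # so the result stays a Dask graph and the caller decides when to
        # materialize it.)
        # df = df.map_partitions(self._process_partition, meta=df)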
        df = df.rename(columns={"documents": "pos_doc"})
        df = df[["question", "pos_doc", "neg_doc"]]

        return DocumentDataset(dd.from_pandas(df))
Reviewer comment: Why are you doing …
    def _process_partition(self, df_p: pd.DataFrame):

        if self.model_type == "nvidia":
Reviewer comment (with a suggested change): Can we change this model_type name to …
            self.client = load_object_on_worker(
                attr="nim_embedding_model",
                load_object_function=create_nim_client,
                load_object_kwargs={"base_url": self.base_url, "api_key": self.api_key},
            )
            df_p["doc_embed"] = df_p["documents"].map(
                lambda pgs: [self._get_nim_embedding(t, "passage") for t in pgs]
            )
            df_p["query_embed"] = df_p["question"].map(
                lambda x: self._get_nim_embedding(x, "query")
            )
        elif self.model_type == "hf":
            self.hf_model = load_object_on_worker(
Reviewer comment: @VibhuJawa should this be using crossfit?

Reply: Yes, it can be; crossfit has support for these kinds of embeddings. This will need some rework to make that happen. I am not sure about the time constraints at play here.
attr="hf_embedding_model", | ||||||||
load_object_function=create_hf_model, | ||||||||
load_object_kwargs={"model_name_or_path": self.model_name}, | ||||||||
) | ||||||||
df_p["doc_embed"] = df_p["documents"].map( | ||||||||
lambda pgs: [ | ||||||||
self._get_hf_embedding(t, self.passage_prefix) for t in pgs | ||||||||
] | ||||||||
) | ||||||||
df_p["query_embed"] = df_p["question"].map( | ||||||||
lambda x: self._get_hf_embedding(x, self.query_prefix) | ||||||||
) | ||||||||
|
||||||||
doc_embeds = list(itertools.chain(*df_p["doc_embed"].to_list())) | ||||||||
docs = list(itertools.chain(*df_p["documents"].to_list())) | ||||||||
|
||||||||
if self.hard_neg_mining_algorithm == "topk_abs": | ||||||||
df_p["neg_doc_scores"] = df_p[["query_embed", "documents"]].apply( | ||||||||
lambda row: self._get_scores_topk_abs( | ||||||||
row["query_embed"], doc_embeds, docs, row["documents"] | ||||||||
), | ||||||||
axis=1, | ||||||||
) | ||||||||
|
||||||||
elif self.hard_neg_mining_algorithm == "topk_percpos": | ||||||||
df_p["min_pos_score"] = df_p[["query_embed", "doc_embed"]].apply( | ||||||||
lambda row: self._get_min_pos_score(row), axis=1 | ||||||||
) | ||||||||
df_p["neg_doc_scores"] = df_p[["query_embed", "min_pos_score"]].apply( | ||||||||
lambda row: self._get_scores_topk_percpos(row, doc_embeds, docs), axis=1 | ||||||||
) | ||||||||
|
||||||||
df_p["neg_doc"] = df_p["neg_doc_scores"].map( | ||||||||
lambda x: [doc for doc, score in x] | ||||||||
) | ||||||||
return df_p | ||||||||
|
||||||||
def _get_min_pos_score(self, row): | ||||||||
x_ = np.array(row["query_embed"]) | ||||||||
y_ = np.array(row["doc_embed"]) # document embeddings | ||||||||
scores = np.dot(x_, y_.T) | ||||||||
return np.min(scores) | ||||||||
|
||||||||
def _get_scores_topk_percpos(self, row, docs_embed, docs): | ||||||||
x_ = np.array(row["query_embed"]) | ||||||||
y_ = np.array(docs_embed) | ||||||||
scores = np.dot(x_, y_.T) | ||||||||
neg_docs = [] | ||||||||
neg_docs_scores = [] | ||||||||
max_neg_score_threshold = row["min_pos_score"] * self.percpos | ||||||||
for idx, s in enumerate(scores): | ||||||||
if s <= max_neg_score_threshold: | ||||||||
if docs[idx] not in neg_docs: | ||||||||
neg_docs.append(docs[idx]) | ||||||||
neg_docs_scores.append((docs[idx], s)) | ||||||||
del neg_docs, scores | ||||||||
return sorted(neg_docs_scores, reverse=True, key=lambda x: x[1])[ | ||||||||
: self.n_hard_negatives | ||||||||
] | ||||||||
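To make the topk_percpos rule above concrete: if a query's weakest positive document scores 0.80 and percpos is 0.95, the negative-score ceiling is 0.80 × 0.95 = 0.76; every corpus document scoring at or below that ceiling is a candidate, and the highest-scoring candidates are kept. A self-contained illustration with made-up similarity scores:

percpos = 0.95
min_pos_score = 0.80                 # weakest positive for this query
ceiling = min_pos_score * percpos    # 0.76
docs = ["d1", "d2", "d3", "d4", "d5"]
scores = [0.90, 0.75, 0.70, 0.76, 0.20]   # query-document similarities

# Keep candidates at or below the ceiling, then take the hardest (highest-scoring).
candidates = [(d, s) for d, s in zip(docs, scores) if s <= ceiling]
hardest = sorted(candidates, key=lambda x: x[1], reverse=True)[:2]
print(hardest)  # [('d4', 0.76), ('d2', 0.75)] -- near-misses, not the positive d1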
    def _get_scores_topk_abs(self, x, docs_embed, docs, pos_docs):
        x_ = np.array(x)
        y_ = np.array(docs_embed)
        scores = np.dot(x_, y_.T)
        neg_docs = []
        neg_docs_scores = []
        for idx, s in enumerate(scores):
            if s <= self.max_neg_score_threshold:
                if docs[idx] not in pos_docs:
                    if docs[idx] not in neg_docs:
                        neg_docs.append(docs[idx])
                        neg_docs_scores.append((docs[idx], s))
        del neg_docs, scores
        return sorted(neg_docs_scores, reverse=True, key=lambda x: x[1])[
            : self.n_hard_negatives
        ]

    def _get_hf_embedding(self, text, prefix="query"):
        embeddings = self.hf_model.encode(prefix + text)
        return embeddings

    def _get_nim_embedding(self, text, input_type):
        # Obtain embeddings from nim model
        if isinstance(text, list):
            input_ = text
        elif isinstance(text, str):
            input_ = [text]

        try:
            response = self.client.embeddings.create(
                input=input_,
                model=self.model_name,
                encoding_format="float",
                extra_body={"input_type": input_type, "truncate": self.truncate},
            )
        except Exception as e:
            print(f"Error: {e}")
            response = None

        if response:
            if isinstance(text, list):
                embeddings = [r.embedding for r in response.data]
            elif isinstance(text, str):
                embeddings = response.data[0].embedding
            return embeddings
        else:
            return []
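One property the dot-product scoring above implicitly relies on: dot products rank candidates by cosine similarity only when the embeddings are unit-normalized. A quick sanity check for the default model from the YAML (normalize_embeddings is a standard sentence-transformers encode argument):

import numpy as np
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
emb = model.encode(["query: what is hard-negative mining?"], normalize_embeddings=True)
print(np.linalg.norm(emb[0]))  # ~1.0, so dot product equals cosine similarity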
Large diffs are not rendered by default.

Reviewer comment: Why was this change made?