Create Cache class for exact, fuzzy, and semantic deduplication #384

Open · wants to merge 34 commits into base: main from global_cache_dir
Changes from all commits (34 commits):
769e2ea
add global cache variable and use it for exact dedup
sarahyurick Nov 19, 2024
b77139c
global cache for semdedup
sarahyurick Nov 19, 2024
337cec8
run black and modify pytest
sarahyurick Nov 19, 2024
6d55d8c
update image notebook
sarahyurick Nov 20, 2024
622912b
Merge branch 'main' into global_cache_dir
sarahyurick Nov 20, 2024
4cb26d5
save fuzzy dedup progress
sarahyurick Nov 20, 2024
b001622
save progress
sarahyurick Nov 20, 2024
0c14626
update remaining docs
sarahyurick Nov 20, 2024
7486459
run black
sarahyurick Nov 20, 2024
053f312
Merge branch 'main' into global_cache_dir
sarahyurick Dec 6, 2024
1b1ba30
Merge branch 'main' into global_cache_dir
sarahyurick Dec 11, 2024
4b12651
Merge branch 'main' into global_cache_dir
sarahyurick Dec 17, 2024
4160471
Merge branch 'main' into global_cache_dir
sarahyurick Dec 20, 2024
8a22ace
Merge branch 'main' into global_cache_dir
sarahyurick Dec 23, 2024
5e9bef1
Merge branch 'main' into global_cache_dir
sarahyurick Jan 3, 2025
d823a0b
Merge remote-tracking branch 'upstream/main' into global_cache_dir
sarahyurick Jan 21, 2025
0890fb0
re-add get_cache_directory changes
sarahyurick Jan 21, 2025
8fd79fb
create Cache singleton class
sarahyurick Jan 21, 2025
0d7b969
update exact_dedup
sarahyurick Jan 22, 2025
2c1a435
add semdedup functionality with Cache
sarahyurick Jan 22, 2025
f0ff2ce
add semdedup_example script
sarahyurick Jan 22, 2025
a379893
Cache singleton option for fuzzy dedup
sarahyurick Jan 23, 2025
67f609c
run black
sarahyurick Jan 23, 2025
8693177
fix tutorials
sarahyurick Jan 23, 2025
c296cc7
Merge branch 'main' into global_cache_dir
sarahyurick Jan 29, 2025
510347c
Merge branch 'main' into global_cache_dir
sarahyurick Feb 18, 2025
0635ebf
run black
sarahyurick Feb 18, 2025
a229857
import assert_eq
sarahyurick Feb 18, 2025
30ec409
fix semdedup test
sarahyurick Feb 19, 2025
1a63468
Merge branch 'main' into global_cache_dir
sarahyurick Feb 20, 2025
2075588
Merge branch 'main' into global_cache_dir
sarahyurick Feb 25, 2025
a6c5de3
remove repeating param
sarahyurick Feb 25, 2025
b805ce9
Merge remote-tracking branch 'upstream/main' into global_cache_dir
sarahyurick Feb 28, 2025
2ee3547
fix semdedup test
sarahyurick Feb 28, 2025
1 change: 1 addition & 0 deletions config/fuzzy_dedup_config.yaml
@@ -1,4 +1,5 @@
cache_dir: "./fuzzy_dedup_cache"

# Optional Params below with default values
# profile_dir: null
# id_field: "id"
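For illustration, a config file like the one above can be loaded through the inherited BaseConfig.from_yaml classmethod; a minimal sketch, assuming the repository-relative path shown below:

from nemo_curator.modules.config import FuzzyDuplicatesConfig

# Load the YAML above into the dataclass. With this PR, cache_dir can also be
# omitted from the file and supplied via the Cache singleton instead.
config = FuzzyDuplicatesConfig.from_yaml("config/fuzzy_dedup_config.yaml")
print(config.cache_dir)  # "./fuzzy_dedup_cache"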
11 changes: 7 additions & 4 deletions docs/user-guide/semdedup.rst
@@ -172,7 +172,8 @@ Use Individual Components
embedding_creator = EmbeddingCreator(
embedding_model_name_or_path="path/to/pretrained/model",
embedding_batch_size=128,
embedding_output_dir="path/to/output/embeddings",
cache_dir="path/to/output",
embeddings_save_loc="embeddings",
input_column="text",
logger="path/to/log/dir",
)
@@ -190,7 +191,8 @@ Use Individual Components
id_column="doc_id",
max_iter=100,
n_clusters=50000,
clustering_output_dir="path/to/output/clusters",
cache_dir="path/to/output",
clustering_save_loc="clustering_results",
logger="path/to/log/dir"
)
clustered_dataset = clustering_model(embeddings_dataset)
@@ -204,12 +206,13 @@ Use Individual Components
# Step 3: Semantic Deduplication
semantic_dedup = SemanticClusterLevelDedup(
n_clusters=50000,
emb_by_clust_dir="path/to/embeddings/by/cluster",
sorted_clusters_dir="path/to/sorted/clusters",
id_column="doc_id",
id_column_type="str",
which_to_keep="hard",
output_dir="path/to/output/deduped",
# cache_dir and clustering_save_loc should match ClusteringModel
cache_dir="path/to/output",
clustering_save_loc="clustering_results",
logger="path/to/log/dir"
)
semantic_dedup.compute_semantic_match_dfs()
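For illustration, the new cache_dir plus *_save_loc parameters imply that all intermediates land under one cache directory; a minimal sketch of the assumed layout (the exact joins are an assumption based on the parameter names):

import os

cache_dir = "path/to/output"

# Assumed layout implied by the new parameters (hypothetical join):
embeddings_dir = os.path.join(cache_dir, "embeddings")          # EmbeddingCreator output
clustering_dir = os.path.join(cache_dir, "clustering_results")  # shared by ClusteringModel
                                                                # and SemanticClusterLevelDedup

This is why the comment above notes that cache_dir and clustering_save_loc should match between ClusteringModel and SemanticClusterLevelDedup.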
6 changes: 3 additions & 3 deletions examples/fuzzy_deduplication.py
@@ -38,7 +38,7 @@ def main(args):

filetype = "parquet"

# Fuzzy dup calculation only supports the cuDF/GPU backend
# Fuzzy deduplication only supports the cuDF/GPU backend
backend = "cudf"
assert args.device == "gpu"

@@ -89,12 +89,12 @@ def main(args):

if duplicates is None:
print("No duplicates found")
print(f"Time taken:{time.time() - t0}s")
print(f"Time taken: {time.time() - t0}s")
return

result = fuzzy_dup.remove(input_dataset, duplicates)
write_to_disk(result, output_dir, output_type=filetype)
print(f"Time taken:{time.time() - t0}s")
print(f"Time taken: {time.time() - t0}s")


def attach_args(
7 changes: 7 additions & 0 deletions examples/semdedup_example.py
@@ -49,23 +49,30 @@ def main(args):
log_level=logging.INFO,
stdout=True,
)

st = time.time()

input_files = get_all_files_paths_under(
root=args.input_data_dir,
)

if semdedup_config.num_files > 0:
input_files = input_files[: semdedup_config.num_files]

logger.info(f"Processing {len(input_files)} files")

ddf = read_data(
input_files=input_files,
file_type=args.input_file_type,
add_filename=False,
backend="cudf",
)
dataset = DocumentDataset(ddf)

semdup = SemDedup(semdedup_config, logger=logger)
dedup_ids = semdup(dataset)
print(dedup_ids.df.head())

logger.info(f"Time taken: {time.time() - st}")
client.cancel(client.futures, force=True)
client.close()
48 changes: 48 additions & 0 deletions nemo_curator/cache.py
@@ -0,0 +1,48 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo_curator.utils.file_utils import expand_outdir_and_mkdir


class Cache:
_instance = None
_cache_dir = None

def __new__(cls, cache_dir=None):
if cls._instance is None:
cls._instance = super(Cache, cls).__new__(cls)
if cache_dir is not None:
cls._cache_dir = expand_outdir_and_mkdir(cache_dir)
else:
cls._cache_dir = None
elif cache_dir is not None and cls._cache_dir is None:
cls._cache_dir = expand_outdir_and_mkdir(cache_dir)
return cls._instance

@classmethod
def get_cache_directory(cls) -> str:
"""
Retrieve the cache directory.
"""
return cls._cache_dir

@classmethod
def delete_cache_instance(cls):
"""
Reset the Cache singleton.
"""
if cls._cache_dir is not None:
cls._cache_dir = None

cls._instance = None
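For illustration, a minimal usage sketch of the new singleton (the directory path is hypothetical):

from nemo_curator.cache import Cache

# The first call that passes cache_dir creates the directory and pins it
# for the lifetime of the process; later calls cannot override it.
Cache(cache_dir="/tmp/curator_cache")

# Any later module (or a config's __post_init__) can read it back.
print(Cache.get_cache_directory())  # e.g. /tmp/curator_cache

# Reset the singleton, e.g. between tests.
Cache.delete_cache_instance()
assert Cache.get_cache_directory() is None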
126 changes: 81 additions & 45 deletions nemo_curator/modules/config.py
@@ -18,6 +18,8 @@

import yaml

from nemo_curator.cache import Cache


@dataclass
class BaseConfig:
@@ -31,44 +33,49 @@ def from_yaml(cls, file_path: str):
@dataclass
class FuzzyDuplicatesConfig(BaseConfig):
"""
Configuration for MinHash based fuzzy duplicates detection.
Configuration for MinHash-based fuzzy duplicates detection.

Parameters
----------
seed: Seed for minhash permutations
char_ngrams: Size of Char ngram shingles used in minhash computation
num_buckets: Number of Bands or buckets to use during Locality Sensitive Hashing
hashes_per_bucket: Number of hashes per bucket/band.
cache_dir: If specified, directory to store deduplication intermediates, such as
minhashes, buckets, etc. If None, we check if a cache_dir has been initialized
with Cache().get_cache_directory(). Default is None.
profile_dir: If specified, directory to write Dask profile. Default is None.
id_field: Column in the dataset denoting document ID. Default is "id".
text_field: Column in the dataset denoting document content. Default is "text".
perform_removal: Boolean value to specify whether calling the module should remove
the duplicates from the original dataset, or return the list of IDs denoting
duplicates. Default is False.
seed: Seed for minhash permutations. Default is 42.
char_ngrams: Size of character n-gram shingles used in minhash computation.
Default is 5.
num_buckets: Number of bands or buckets to use during Locality Sensitive Hashing.
Default is 20.
hashes_per_bucket: Number of hashes per bucket/band. Default is 13.
use_64_bit_hash: Whether to use a 32bit or 64bit hash function for minhashing.
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently.
Larger values process larger batches by processing multiple bands
but might lead to memory pressures and related errors.
id_field: Column in the Dataset denoting document ID.
text_field: Column in the Dataset denoting document content.
perform_removal: Boolean value to specify whether calling the module should remove the duplicates from
the original dataset, or return the list of IDs denoting duplicates.
profile_dir: str, Default None
If specified directory to write dask profile
cache_dir: str, Default None
Location to store deduplcation intermediates such as minhashes/buckets etc.
false_positive_check: bool,
Whether to run a check to look for false positives within buckets.
Note: This is a computationally expensive step.
num_anchors: int
Number of documents per bucket to use as reference for computing jaccard
pairs within that bucket to identify false positives.
jaccard_threshold: float
The Jaccard similariy threshold to consider a document a near duplicate
during false positive evaluations.
Default is False.
buckets_per_shuffle: Number of bands/buckets to shuffle concurrently. Larger values
process larger batches by processing multiple bands but might lead to memory
pressure and related errors. Default is 1.
false_positive_check: Whether to run a check to look for false positives within
buckets. Note: This is a computationally expensive step. Default is False.
num_anchors: Number of documents per bucket to use as reference for computing
Jaccard pairs within that bucket to identify false positives. Default is 2.
jaccard_threshold: The Jaccard similarity threshold to consider a document a near
duplicate during false positive evaluations. Default is 0.8.
bucket_mapping_blocksize: Default is 256.
parts_per_worker: Default is 1.
bucket_parts_per_worker: Default is 8.
"""

# General config
cache_dir: str
cache_dir: Optional[str] = None
profile_dir: Optional[str] = None
id_field: str = "id"
text_field: str = "text"
perform_removal: bool = False

# Minhash + LSH Config
# Minhash + LSH config
seed: int = 42
char_ngrams: int = 24
num_buckets: int = 20
Expand All @@ -86,53 +93,72 @@ class FuzzyDuplicatesConfig(BaseConfig):

def __post_init__(self):
self.num_hashes = self.num_buckets * self.hashes_per_bucket

false_positive_defaults = {
"num_anchors": 2,
"jaccard_threshold": 0.8,
"bucket_mapping_blocksize": 256,
"parts_per_worker": 1,
"bucket_parts_per_worker": 8,
}

if self.false_positive_check:
warnings.warn(
"Identifying false positives during the Minhash deduplication is computationally expensive."
" For improved performance consider setting this to False"
"Identifying false positives during Minhash deduplication is "
"computationally expensive. For improved performance consider setting "
"this to False."
)

for arg, default in false_positive_defaults.items():
if getattr(self, arg) is None:
setattr(self, arg, default)

if self.num_anchors <= 0:
raise ValueError("Number of anchors must be greater than 0")
raise ValueError("Number of anchors must be greater than 0.")

if self.num_anchors > 2:
warnings.warn(
"Using a higher number of anchor docs might lead to higher memory footprint and might impact performance",
"Using a higher number of anchor documents might lead to higher memory "
"footprint and might impact performance.",
category=UserWarning,
)

if not 0 <= self.jaccard_threshold <= 1:
raise ValueError("Jaccard Threshold must be between [0,1]")
raise ValueError("Jaccard threshold must be between [0, 1].")

else:
if self.char_ngrams < 20:
warnings.warn(
"Using a small char_ngrams value might lead to a large number (~5%) of false positives during deduplication."
" Using a value of at least 20 for char_ngrams is recommended."
)

unused_false_positive_args = [
arg
for arg in false_positive_defaults.keys()
if getattr(self, arg) is not None
]

if unused_false_positive_args:
warnings.warn(
f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored",
f"False positive check is disabled. Unused arguments {unused_false_positive_args} will be ignored.",
category=UserWarning,
)

if self.cache_dir is None:
raise ValueError(
"Finding fuzzy duplicates requires a cache directory accessible via all workers to store intermediates"
)
if not 1 <= self.buckets_per_shuffle <= self.num_buckets:
raise ValueError("Buckets per shuffle must be between [1, num_buckets]")
raise ValueError("Buckets per shuffle must be between [1, num_buckets].")

if self.cache_dir is None:
cache_dir = Cache().get_cache_directory()
if cache_dir is None:
raise ValueError(
"Finding fuzzy duplicates requires a cache directory accessible via "
"all workers to store intermediates. Please use "
"Cache(cache_dir=...) or FuzzyDuplicatesConfig(cache_dir=...) to "
"set the cache directory."
)
else:
self.cache_dir = cache_dir

if not self.perform_removal:
warnings.warn(
@@ -146,7 +172,9 @@ class SemDedupConfig(BaseConfig):
Configuration for Semantic Deduplication.

Attributes:
cache_dir (str): Directory to store cache.
cache_dir (Optional[str]): If specified, directory to store cache.
If None, we check if a cache_dir has been initialized with Cache().get_cache_directory().
Default is None.
profile_dir (Optional[str]): If specified, directory to write Dask profile.
Default is None.
num_files (int): Number of files. Default is -1, meaning all files.
@@ -190,7 +218,7 @@ class SemDedupConfig(BaseConfig):
Default is 0.01.
"""

cache_dir: str
cache_dir: Optional[str] = None
profile_dir: Optional[str] = None
num_files: int = -1

Expand All @@ -216,17 +244,25 @@ class SemDedupConfig(BaseConfig):
kmeans_with_cos_dist: bool = False
clustering_input_partition_size: str = "2gb"

# Extract dedup config
# SemDedup
eps_thresholds: List[float] = field(default_factory=lambda: [0.01, 0.001])
eps_to_extract: float = 0.01

def __post_init__(self):
if self.cache_dir is None:
raise ValueError(
"Finding sem-dedup requires a cache directory accessible via all workers to store intermediates"
)
cache_dir = Cache().get_cache_directory()
if cache_dir is None:
raise ValueError(
"Finding semantic duplicates requires a cache directory accessible "
"via all workers to store intermediates. Please use "
"Cache(cache_dir=...) or SemDedupConfig(cache_dir=...) to "
"set the cache directory."
)
else:
self.cache_dir = cache_dir

if self.eps_to_extract not in self.eps_thresholds:
raise ValueError(
f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds {self.eps_thresholds}"
f"Epsilon to extract {self.eps_to_extract} must be in eps_thresholds "
f"{self.eps_thresholds}."
)
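Taken together, a minimal sketch of the new cache_dir resolution order, assuming the directory names below: an explicit cache_dir on the config takes precedence, otherwise the config falls back to the Cache singleton, and a ValueError is raised if neither is set.

from nemo_curator.cache import Cache
from nemo_curator.modules.config import FuzzyDuplicatesConfig, SemDedupConfig

# Option 1: pass cache_dir explicitly on the config, as before this PR.
fuzzy_config = FuzzyDuplicatesConfig(cache_dir="./fuzzy_dedup_cache")

# Option 2: set the directory once on the Cache singleton and omit it here;
# __post_init__ picks it up via Cache().get_cache_directory().
Cache(cache_dir="./dedup_cache")
sem_config = SemDedupConfig()
fuzzy_config_from_cache = FuzzyDuplicatesConfig()

# With neither an explicit cache_dir nor an initialized Cache, construction
# raises a ValueError.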