Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor modules to be in a single class and dataframe backend checking #379

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Prev Previous commit
Next Next commit
Refactor modifier
Signed-off-by: Ryan Wolf <rywolf@nvidia.com>
ryantwolf committed Oct 23, 2024
commit e365397676d9aaeef5436f6c745eed9cd32c2a57
54 changes: 54 additions & 0 deletions nemo_curator/modifiers/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import List

from nemo_curator.datasets import DocumentDataset
from nemo_curator.modules.base import Module
from nemo_curator.utils.module_utils import is_batched


class DocumentModifier(Module, ABC):
    """Abstract base for modules that rewrite text fields of a dataset.

    Subclasses implement :meth:`modify_document`; :meth:`call` applies it to
    the configured text field(s) of a ``DocumentDataset`` and writes the
    result back into the same field(s).
    """

    def __init__(
        self,
        text_fields: List[str] = ["text"],
        meta=(None, str),
        input_backend: str = "pandas",
    ):
        """Store the configuration used by :meth:`call`.

        Args:
            text_fields: Names of the dataset columns to modify. A single
                column name given as a plain string is also accepted.
            meta: Forwarded unchanged as the ``meta=`` keyword of the
                dataframe ``map_partitions`` / ``apply`` call in
                :meth:`call`.
            input_backend: Forwarded to ``Module.__init__``.
        """
        super().__init__(input_backend=input_backend)
        # A bare string is treated as one column name rather than being
        # split into characters by the copy below.
        if isinstance(text_fields, str):
            text_fields = [text_fields]
        # Copy so the shared default list (and any caller-owned list) can
        # never be mutated through this instance.
        self.text_fields = list(text_fields)
        self.meta = meta

    @abstractmethod
    def modify_document(self, text):
        """Return the modified version of ``text``.

        Called by :meth:`call` either once per partition (when
        ``is_batched`` reports this method as batched) or through the
        dataframe's ``apply`` otherwise.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "modify_document method must be implemented by subclasses"
        )

    def call(self, dataset: DocumentDataset) -> DocumentDataset:
        """Apply :meth:`modify_document` to the configured text field(s).

        The dataset's dataframe is updated in place and the same dataset
        object is returned.

        Args:
            dataset: Dataset whose ``df`` holds the configured text fields.

        Returns:
            The ``dataset`` that was passed in, with the text field(s)
            replaced by the modified values.
        """
        # With exactly one field, index by the bare column name so a single
        # column is selected instead of a one-column frame.
        text_fields = (
            self.text_fields if len(self.text_fields) > 1 else self.text_fields[0]
        )

        if is_batched(self.modify_document):
            # Batched modifiers receive a whole partition at a time.
            dataset.df[text_fields] = dataset.df[text_fields].map_partitions(
                self.modify_document, meta=self.meta
            )
        else:
            dataset.df[text_fields] = dataset.df[text_fields].apply(
                self.modify_document, meta=self.meta
            )

        return dataset