Skip to content

Commit 83458db

Browse files
committed
remove _class_register altogether
Signed-off-by: dafnapension <dafnashein@yahoo.com>
1 parent 9f19ed2 commit 83458db

File tree

7 files changed

+29
-101
lines changed

7 files changed

+29
-101
lines changed

.github/workflows/library_tests.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ jobs:
3737
- run: curl -LsSf https://astral.sh/uv/install.sh | sh
3838
- run: uv pip install --upgrade --system torch --index-url https://download.pytorch.org/whl/cpu
3939
- run: uv pip install --system -c constraints.txt -e ".[tests]"
40+
- run: uv pip install --system protobuf
4041
- run: |
4142
pip install --only-binary :all: spacy
4243

docs/conf.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ def autodoc_skip_member(app, what, name, obj, would_skip, options):
116116
class_name = obj.__qualname__.split(".")[0]
117117
if (
118118
class_name
119-
and Artifact.is_registered_class_name(class_name)
120119
and class_name != name
121120
):
122121
return True

prepare/metrics/custom_f1.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,4 +433,7 @@ class NERWithoutClassReporting(NER):
433433
global_target=global_target,
434434
)
435435

436-
add_to_catalog(metric, "metrics.ner", overwrite=True)
436+
if __name__ == "__main__" or __name__ == "custom_f1":
437+
# because a class is defined in this module, need to not add_to_catalog just for importing that module in order to retrieve the defined class
438+
# and need to prepare for case when this module is run directly from python (__main__) or, for example, from test_preparation (custom_f1)
439+
add_to_catalog(metric, "metrics.ner", overwrite=True)

src/unitxt/artifact.py

Lines changed: 13 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import difflib
21
import inspect
32
import json
43
import os
@@ -23,7 +22,6 @@
2322
separate_inside_and_outside_square_brackets,
2423
)
2524
from .settings_utils import get_constants, get_settings
26-
from .text_utils import is_camel_case
2725
from .type_utils import isoftype, issubtype
2826
from .utils import (
2927
artifacts_json_cache,
@@ -134,21 +132,11 @@ def maybe_recover_artifacts_structure(obj):
134132
return obj
135133

136134

137-
def get_closest_artifact_type(type):
138-
artifact_type_options = list(Artifact._class_register.keys())
139-
matches = difflib.get_close_matches(type, artifact_type_options)
140-
if matches:
141-
return matches[0] # Return the closest match
142-
return None
143-
144135

145136
class UnrecognizedArtifactTypeError(ValueError):
146137
def __init__(self, type) -> None:
147138
maybe_class = type.split(".")[-1]
148139
message = f"'{type}' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{maybe_class}' or similar) is defined and/or imported anywhere in the code executed."
149-
closest_artifact_type = get_closest_artifact_type(type)
150-
if closest_artifact_type is not None:
151-
message += f"\n\nDid you mean '{closest_artifact_type}'?"
152140
super().__init__(message)
153141

154142

@@ -161,7 +149,7 @@ def __init__(self, dic) -> None:
161149

162150

163151
class Artifact(Dataclass):
164-
_class_register = {}
152+
# _class_register = {}
165153

166154
__type__: str = Field(default=None, final=True, init=False)
167155
__title__: str = NonPositionalField(
@@ -220,7 +208,7 @@ def fix_module_name_if_not_in_path(module):
220208
if file_components[0] == "":
221209
file_components = file_components[1:]
222210
file_components[-1] = file_components[-1].split(".")[0] #omit the .py
223-
if len(module.__package__) == 0:
211+
if not getattr(module, "__package__", None) or len(module.__package__) == 0:
224212
return file_components[-1]
225213
package_components = module.__package__.split(".")
226214
assert all(p_c in file_components for p_c in package_components)
@@ -252,29 +240,9 @@ def get_module_class(cls, artifact_type:str):
252240
return artifact_type.rsplit(".", 1)
253241

254242

255-
@classmethod
256-
def register_class(cls, artifact_class):
257-
assert issubclass(
258-
artifact_class, Artifact
259-
), f"Artifact class must be a subclass of Artifact, got '{artifact_class}'"
260-
assert is_camel_case(
261-
artifact_class.__name__
262-
), f"Artifact class name must be legal camel case, got '{artifact_class.__name__}'"
263-
264-
if cls.is_registered_type(cls.get_artifact_type()):
265-
assert (
266-
str(cls._class_register[cls.get_artifact_type()]) == cls.get_artifact_type()
267-
), f"Artifact class name must be unique, '{cls.get_artifact_type()}' is already registered as {cls._class_register[cls.get_artifact_type()]}. Cannot be overridden by {artifact_class}."
268-
269-
return cls.get_artifact_type()
270-
271-
cls._class_register[cls.get_artifact_type()] = cls.get_artifact_type() # for now, still maintain the registry from qualified to qualified
272-
273-
return cls.get_artifact_type()
274243

275244
def __init_subclass__(cls, **kwargs):
276245
super().__init_subclass__(**kwargs)
277-
cls.register_class(cls)
278246

279247
@classmethod
280248
def is_artifact_file(cls, path):
@@ -284,18 +252,6 @@ def is_artifact_file(cls, path):
284252
d = json.load(f)
285253
return cls.is_artifact_dict(d)
286254

287-
@classmethod
288-
def is_registered_type(cls, type: str):
289-
return type in cls._class_register
290-
291-
@classmethod
292-
def is_registered_class_name(cls, class_name: str):
293-
for k in cls._class_register:
294-
_, artifact_class_name = cls.get_module_class(k)
295-
if artifact_class_name == class_name:
296-
return True
297-
return False
298-
299255
@classmethod
300256
def get_class_from_artifact_type(cls, type:str):
301257
module_path, class_name = cls.get_module_class(type)
@@ -309,27 +265,20 @@ def get_class_from_artifact_type(cls, type:str):
309265
return klass
310266

311267

312-
313268
@classmethod
314269
def _recursive_load(cls, obj):
315270
if isinstance(obj, dict):
316-
new_d = {}
317-
for key, value in obj.items():
318-
new_d[key] = cls._recursive_load(value)
319-
obj = new_d
271+
obj = {key: cls._recursive_load(value) for key, value in obj.items()}
272+
if cls.is_artifact_dict(obj):
273+
try:
274+
artifact_type = obj.pop("__type__")
275+
artifact_class = cls.get_class_from_artifact_type(artifact_type)
276+
obj = artifact_class.process_data_after_load(obj)
277+
return artifact_class(**obj)
278+
except (ImportError, AttributeError) as e:
279+
raise UnrecognizedArtifactTypeError(artifact_type) from e
320280
elif isinstance(obj, list):
321-
obj = [cls._recursive_load(value) for value in obj]
322-
else:
323-
pass
324-
if cls.is_artifact_dict(obj):
325-
cls.verify_artifact_dict(obj)
326-
try:
327-
artifact_type = obj.pop("__type__")
328-
artifact_class = cls.get_class_from_artifact_type(artifact_type)
329-
obj = artifact_class.process_data_after_load(obj)
330-
return artifact_class(**obj)
331-
except (ImportError, AttributeError) as e:
332-
raise UnrecognizedArtifactTypeError(artifact_type) from e
281+
return [cls._recursive_load(value) for value in obj]
333282

334283
return obj
335284

@@ -389,7 +338,7 @@ def verify_data_classification_policy(self):
389338

390339
@final
391340
def __post_init__(self):
392-
self.__type__ = self.register_class(self.__class__)
341+
self.__type__ = self.__class__.get_artifact_type()
393342

394343
for field in fields(self):
395344
if issubtype(

src/unitxt/catalog/cards/safety/airbench2024.json

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
{
2-
"__type__": "task_card",
2+
"__type__": "unitxt.card.TaskCard",
33
"loader": {
4-
"__type__": "multiple_source_loader",
4+
"__type__": "unitxt.loaders.MultipleSourceLoader",
55
"sources": [
66
{
7-
"__type__": "load_hf",
7+
"__type__": "unitxt.loaders.LoadHF",
88
"path": "stanford-crfm/air-bench-2024",
99
"name": "default"
1010
},
1111
{
12-
"__type__": "load_hf",
12+
"__type__": "unitxt.loaders.LoadHF",
1313
"path": "stanford-crfm/air-bench-2024",
1414
"data_files": {
1515
"judge_prompts": "judge_prompt_final.csv"
@@ -22,7 +22,7 @@
2222
},
2323
"preprocess_steps": [
2424
{
25-
"__type__": "select_fields",
25+
"__type__": "unitxt.operators.SelectFields",
2626
"fields": [
2727
"cate-idx",
2828
"judge_prompt"
@@ -32,7 +32,7 @@
3232
]
3333
},
3434
{
35-
"__type__": "join_streams",
35+
"__type__": "unitxt.stream_operators.JoinStreams",
3636
"left_stream": "test",
3737
"right_stream": "judge_prompts",
3838
"how": "inner",
@@ -43,7 +43,7 @@
4343
}
4444
],
4545
"task": {
46-
"__type__": "task",
46+
"__type__": "unitxt.task.Task",
4747
"input_fields": {
4848
"cate-idx": "str",
4949
"l2-name": "str",
@@ -59,10 +59,10 @@
5959
]
6060
},
6161
"templates": {
62-
"__type__": "templates_dict",
62+
"__type__": "unitxt.templates.TemplatesDict",
6363
"items": {
6464
"default": {
65-
"__type__": "input_output_template",
65+
"__type__": "unitxt.templates.InputOutputTemplate",
6666
"input_format": "{prompt}\n",
6767
"output_format": ""
6868
}
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
{
2-
"__type__": "hf_system_format",
2+
"__type__": "unitxt.formats.HFSystemFormat",
33
"model_name": "ibm-granite/granite-3.1-2b-instruct"
44
}

src/unitxt/register.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import importlib
2-
import inspect
31
import os
42
from pathlib import Path
53

6-
from .artifact import Artifact, Catalogs
4+
from .artifact import Catalogs
75
from .catalog import EnvironmentLocalCatalog, GithubCatalog, LocalCatalog
86
from .error_utils import Documentation, UnitxtError, UnitxtWarning
97
from .settings_utils import get_constants, get_settings
@@ -89,27 +87,6 @@ def _reset_env_local_catalogs():
8987
_register_catalog(EnvironmentLocalCatalog(location=path))
9088

9189

92-
def _register_all_artifacts():
93-
dir = os.path.dirname(__file__)
94-
file_name = os.path.basename(__file__)
95-
96-
for file in os.listdir(dir):
97-
if (
98-
file.endswith(".py")
99-
and file not in constants.non_registered_files
100-
and file != file_name
101-
):
102-
module_name = file.replace(".py", "")
103-
104-
module = importlib.import_module("." + module_name, __package__)
105-
106-
for _name, obj in inspect.getmembers(module):
107-
# Make sure the object is a class
108-
if inspect.isclass(obj):
109-
# Make sure the class is a subclass of Artifact (but not Artifact itself)
110-
if issubclass(obj, Artifact) and obj is not Artifact:
111-
Artifact.register_class(obj)
112-
11390

11491
class ProjectArtifactRegisterer(metaclass=Singleton):
11592
def __init__(self):
@@ -118,7 +95,6 @@ def __init__(self):
11895

11996
if not self._registered:
12097
_register_all_catalogs()
121-
_register_all_artifacts()
12298
self._registered = True
12399

124100

0 commit comments

Comments
 (0)