1
- import difflib
2
1
import inspect
3
2
import json
4
3
import os
23
22
separate_inside_and_outside_square_brackets ,
24
23
)
25
24
from .settings_utils import get_constants , get_settings
26
- from .text_utils import is_camel_case
27
25
from .type_utils import isoftype , issubtype
28
26
from .utils import (
29
27
artifacts_json_cache ,
@@ -134,21 +132,11 @@ def maybe_recover_artifacts_structure(obj):
134
132
return obj
135
133
136
134
137
- def get_closest_artifact_type (type ):
138
- artifact_type_options = list (Artifact ._class_register .keys ())
139
- matches = difflib .get_close_matches (type , artifact_type_options )
140
- if matches :
141
- return matches [0 ] # Return the closest match
142
- return None
143
-
144
135
145
136
class UnrecognizedArtifactTypeError (ValueError ):
146
137
def __init__ (self , type ) -> None :
147
138
maybe_class = type .split ("." )[- 1 ]
148
139
message = f"'{ type } ' is not a recognized artifact 'type'. Make sure a the class defined this type (Probably called '{ maybe_class } ' or similar) is defined and/or imported anywhere in the code executed."
149
- closest_artifact_type = get_closest_artifact_type (type )
150
- if closest_artifact_type is not None :
151
- message += f"\n \n Did you mean '{ closest_artifact_type } '?"
152
140
super ().__init__ (message )
153
141
154
142
@@ -161,7 +149,7 @@ def __init__(self, dic) -> None:
161
149
162
150
163
151
class Artifact (Dataclass ):
164
- _class_register = {}
152
+ # _class_register = {}
165
153
166
154
__type__ : str = Field (default = None , final = True , init = False )
167
155
__title__ : str = NonPositionalField (
@@ -220,7 +208,7 @@ def fix_module_name_if_not_in_path(module):
220
208
if file_components [0 ] == "" :
221
209
file_components = file_components [1 :]
222
210
file_components [- 1 ] = file_components [- 1 ].split ("." )[0 ] #omit the .py
223
- if len (module .__package__ ) == 0 :
211
+ if not getattr ( module , "__package__" , None ) or len (module .__package__ ) == 0 :
224
212
return file_components [- 1 ]
225
213
package_components = module .__package__ .split ("." )
226
214
assert all (p_c in file_components for p_c in package_components )
@@ -252,29 +240,9 @@ def get_module_class(cls, artifact_type:str):
252
240
return artifact_type .rsplit ("." , 1 )
253
241
254
242
255
- @classmethod
256
- def register_class (cls , artifact_class ):
257
- assert issubclass (
258
- artifact_class , Artifact
259
- ), f"Artifact class must be a subclass of Artifact, got '{ artifact_class } '"
260
- assert is_camel_case (
261
- artifact_class .__name__
262
- ), f"Artifact class name must be legal camel case, got '{ artifact_class .__name__ } '"
263
-
264
- if cls .is_registered_type (cls .get_artifact_type ()):
265
- assert (
266
- str (cls ._class_register [cls .get_artifact_type ()]) == cls .get_artifact_type ()
267
- ), f"Artifact class name must be unique, '{ cls .get_artifact_type ()} ' is already registered as { cls ._class_register [cls .get_artifact_type ()]} . Cannot be overridden by { artifact_class } ."
268
-
269
- return cls .get_artifact_type ()
270
-
271
- cls ._class_register [cls .get_artifact_type ()] = cls .get_artifact_type () # for now, still maintain the registry from qualified to qualified
272
-
273
- return cls .get_artifact_type ()
274
243
275
244
def __init_subclass__ (cls , ** kwargs ):
276
245
super ().__init_subclass__ (** kwargs )
277
- cls .register_class (cls )
278
246
279
247
@classmethod
280
248
def is_artifact_file (cls , path ):
@@ -284,18 +252,6 @@ def is_artifact_file(cls, path):
284
252
d = json .load (f )
285
253
return cls .is_artifact_dict (d )
286
254
287
- @classmethod
288
- def is_registered_type (cls , type : str ):
289
- return type in cls ._class_register
290
-
291
- @classmethod
292
- def is_registered_class_name (cls , class_name : str ):
293
- for k in cls ._class_register :
294
- _ , artifact_class_name = cls .get_module_class (k )
295
- if artifact_class_name == class_name :
296
- return True
297
- return False
298
-
299
255
@classmethod
300
256
def get_class_from_artifact_type (cls , type :str ):
301
257
module_path , class_name = cls .get_module_class (type )
@@ -309,27 +265,20 @@ def get_class_from_artifact_type(cls, type:str):
309
265
return klass
310
266
311
267
312
-
313
268
@classmethod
314
269
def _recursive_load (cls , obj ):
315
270
if isinstance (obj , dict ):
316
- new_d = {}
317
- for key , value in obj .items ():
318
- new_d [key ] = cls ._recursive_load (value )
319
- obj = new_d
271
+ obj = {key : cls ._recursive_load (value ) for key , value in obj .items ()}
272
+ if cls .is_artifact_dict (obj ):
273
+ try :
274
+ artifact_type = obj .pop ("__type__" )
275
+ artifact_class = cls .get_class_from_artifact_type (artifact_type )
276
+ obj = artifact_class .process_data_after_load (obj )
277
+ return artifact_class (** obj )
278
+ except (ImportError , AttributeError ) as e :
279
+ raise UnrecognizedArtifactTypeError (artifact_type ) from e
320
280
elif isinstance (obj , list ):
321
- obj = [cls ._recursive_load (value ) for value in obj ]
322
- else :
323
- pass
324
- if cls .is_artifact_dict (obj ):
325
- cls .verify_artifact_dict (obj )
326
- try :
327
- artifact_type = obj .pop ("__type__" )
328
- artifact_class = cls .get_class_from_artifact_type (artifact_type )
329
- obj = artifact_class .process_data_after_load (obj )
330
- return artifact_class (** obj )
331
- except (ImportError , AttributeError ) as e :
332
- raise UnrecognizedArtifactTypeError (artifact_type ) from e
281
+ return [cls ._recursive_load (value ) for value in obj ]
333
282
334
283
return obj
335
284
@@ -389,7 +338,7 @@ def verify_data_classification_policy(self):
389
338
390
339
@final
391
340
def __post_init__ (self ):
392
- self .__type__ = self .register_class ( self . __class__ )
341
+ self .__type__ = self .__class__ . get_artifact_type ( )
393
342
394
343
for field in fields (self ):
395
344
if issubtype (
0 commit comments