diff --git a/Examples/fitting/README.rst b/Examples/fitting/README.rst deleted file mode 100644 index e0c24c4d..00000000 --- a/Examples/fitting/README.rst +++ /dev/null @@ -1,6 +0,0 @@ -.. _fitting_examples: - -Fitting Examples ------------------------- - -This section gathers examples which correspond to fitting data. diff --git a/Examples/fitting/plot_constraints.py b/Examples/fitting/plot_constraints.py deleted file mode 100644 index b150bc82..00000000 --- a/Examples/fitting/plot_constraints.py +++ /dev/null @@ -1,22 +0,0 @@ -""" -Constraints example -=================== -This example shows the usages of the different constraints. -""" - -from easyscience import Constraints -from easyscience.Objects.ObjectClasses import Parameter - -p1 = Parameter('p1', 1) -constraint = Constraints.NumericConstraint(p1, '<', 5) -p1.user_constraints['c1'] = constraint - -for value in range(4, 7): - p1.value = value - print(f'Set Value: {value}, Parameter Value: {p1}') - -# %% -# To include embedded rST, use a line of >= 20 ``#``'s or ``#%%`` between your -# rST and your code. This separates your example -# into distinct text and code blocks. You can continue writing code below the -# embedded rST text block: diff --git a/docs/src/fitting/constraints.rst b/docs/src/fitting/constraints.rst deleted file mode 100644 index d92c87c2..00000000 --- a/docs/src/fitting/constraints.rst +++ /dev/null @@ -1,75 +0,0 @@ -====================== -Constraints -====================== - -Constraints are a fundamental component in non-trivial fitting operations. They can also be used to affirm the minimum/maximum of a parameter or tie parameters together in a model. - -Anatomy of a constraint ------------------------ - -A constraint is a rule which is applied to a **dependent** variable. This rule can consist of a logical operation, relation to one or more **independent** variables or an arbitrary function. - - -Constraints on Parameters -^^^^^^^^^^^^^^^^^^^^^^^^^ - -:class:`easyscience.Objects.Base.Parameter` has the properties `builtin_constraints` and `user_constraints`. These are dictionaries which correspond to constraints which are intrinsic and extrinsic to the Parameter. This means that on the value change of the Parameter firstly the `builtin_constraints` are evaluated, followed by the `user_constraints`. - - -Constraints on Fitting -^^^^^^^^^^^^^^^^^^^^^^ - -:class:`easyscience.fitting.Fitter` has the ability to evaluate user supplied constraints which effect the value of both fixed and non-fixed parameters. A good example of one such use case would be the ratio between two parameters, where you would create a :class:`easyscience.fitting.Constraints.ObjConstraint`. - -Using constraints ------------------ - -A constraint can be used in one of three ways; Assignment to a parameter, assignment to fitting or on demand. The first two are covered and on demand is shown below. - -.. code-block:: python - - from easyscience.fitting.Constraints import NumericConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < 1` constraint - a = Parameter('a', 0.5) - constraint = NumericConstraint(a, '<=', 1) - # Evaluate the constraint on demand - a.value = 5.0 - constraint() - # A will now equal 1 - -Constraint Reference --------------------- - -.. minigallery:: easyscience.fitting.Constraints.NumericConstraint - :add-heading: Examples using `Constraints` - -Built-in constraints -^^^^^^^^^^^^^^^^^^^^ - -These are the built in constraints which you can use - -.. autoclass:: easyscience.fitting.Constraints.SelfConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.NumericConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.ObjConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.FunctionalConstraint - :members: +enabled - -.. autoclass:: easyscience.fitting.Constraints.MultiObjConstraint - :members: +enabled - -User created constraints -^^^^^^^^^^^^^^^^^^^^^^^^ - -You can also make your own constraints by subclassing the :class:`easyscience.fitting.Constraints.ConstraintBase` class. For this at a minimum the abstract methods ``_parse_operator`` and ``__repr__`` need to be written. - -.. autoclass:: easyscience.fitting.Constraints.ConstraintBase - :members: - :private-members: - :special-members: __repr__ \ No newline at end of file diff --git a/docs/src/index.rst b/docs/src/index.rst index 3683a186..ca99dda5 100644 --- a/docs/src/index.rst +++ b/docs/src/index.rst @@ -56,7 +56,6 @@ Documentation :maxdepth: 3 fitting/introduction - fitting/constraints .. toctree:: :maxdepth: 2 diff --git a/src/easyscience/Constraints.py b/src/easyscience/Constraints.py deleted file mode 100644 index 9628db9e..00000000 --- a/src/easyscience/Constraints.py +++ /dev/null @@ -1,498 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project bool: - """ - Is the current constraint enabled. - - :return: Logical answer to if the constraint is enabled. - """ - return self._enabled - - @enabled.setter - def enabled(self, enabled_value: bool): - """ - Set the enabled state of the constraint. If the new value is the same as the current value only the state is - changed. - - ... note:: If the new value is ``True`` the constraint is also applied after enabling. - - :param enabled_value: New state of the constraint. - :return: None - """ - - if self._enabled == enabled_value: - return - elif enabled_value: - self.get_obj(self.dependent_obj_ids).enabled = False - self() - else: - self.get_obj(self.dependent_obj_ids).enabled = True - self._enabled = enabled_value - - def __call__(self, *args, no_set: bool = False, **kwargs): - """ - Method which applies the constraint - - :return: None if `no_set` is False, float otherwise. - """ - if not self.enabled: - if no_set: - return None - return - independent_objs = None - if isinstance(self.dependent_obj_ids, str): - dependent_obj = self.get_obj(self.dependent_obj_ids) - else: - raise AttributeError - if isinstance(self.independent_obj_ids, str): - independent_objs = self.get_obj(self.independent_obj_ids) - elif isinstance(self.independent_obj_ids, list): - independent_objs = [self.get_obj(obj_id) for obj_id in self.independent_obj_ids] - if independent_objs is not None: - value = self._parse_operator(independent_objs, *args, **kwargs) - else: - value = self._parse_operator(dependent_obj, *args, **kwargs) - - if not no_set: - toggle = False - if not dependent_obj.enabled: - dependent_obj.enabled = True - toggle = True - dependent_obj.value = value - if toggle: - dependent_obj.enabled = False - return value - - @abstractmethod - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - """ - Abstract method which contains the constraint logic - - :param obj: The object/objects which the constraint will use - :return: A numeric result of the constraint logic - """ - - @abstractmethod - def __repr__(self): - pass - - def get_obj(self, key: int) -> V: - """ - Get an EasyScience object from its unique key - - :param key: an EasyScience objects unique key - :return: EasyScience object - """ - return self._global_object.map.get_item_by_key(key) - - -C = TypeVar('C', bound=ConstraintBase) - - -class NumericConstraint(ConstraintBase): - """ - A `NumericConstraint` is a constraint whereby a dependent parameters value is something of an independent parameters - value. I.e. a < 1, a > 5 - """ - - def __init__(self, dependent_obj: V, operator: str, value: Number): - """ - A `NumericConstraint` is a constraint whereby a dependent parameters value is something of an independent - parameters value. I.e. a < 1, a > 5 - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the parameter and the values. e.g. ``=``, ``<``, ``>`` - :param value: What the parameters value should be compared against. - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import NumericConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < 1` constraint - a = Parameter('a', 0.2) - constraint = NumericConstraint(a, '<=', 1) - a.user_constraints['LEQ_1'] = constraint - # This works - a.value = 0.85 - # This triggers the constraint - a.value = 2.0 - # `a` is set to the maximum of the constraint (`a = 1`) - """ - super(NumericConstraint, self).__init__(dependent_obj, operator=operator, value=value) - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - ## TODO Probably needs to be updated when DescriptorArray is implemented - - value = obj.value_no_call_back - - if isinstance(value, list): - value = np.array(value) - self.aeval.symtable['value1'] = value - self.aeval.symtable['value2'] = self.value - try: - self.aeval.eval(f'value3 = value1 {self.operator} value2') - logic = self.aeval.symtable['value3'] - if isinstance(logic, np.ndarray): - value[not logic] = self.aeval.symtable['value2'] - else: - if not logic: - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `value` {self.operator} {self.value}' - - -class SelfConstraint(ConstraintBase): - """ - A `SelfConstraint` is a constraint which tests a logical constraint on a property of itself, similar to a - `NumericConstraint`. i.e. a > a.min. These constraints are usually used in the internal EasyScience logic. - """ - - def __init__(self, dependent_obj: V, operator: str, value: str): - """ - A `SelfConstraint` is a constraint which tests a logical constraint on a property of itself, similar to - a `NumericConstraint`. i.e. a > a.min. - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the parameter and the values. e.g. ``=``, ``<``, ``>`` - :param value: Name of attribute to be compared against - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import SelfConstraint - from easyscience.Objects.Base import Parameter - # Create an `a < a.max` constraint - a = Parameter('a', 0.2, max=1) - constraint = SelfConstraint(a, '<=', 'max') - a.user_constraints['MAX'] = constraint - # This works - a.value = 0.85 - # This triggers the constraint - a.value = 2.0 - # `a` is set to the maximum of the constraint (`a = 1`) - """ - super(SelfConstraint, self).__init__(dependent_obj, operator=operator, value=value) - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - value = obj.value_no_call_back - - self.aeval.symtable['value1'] = value - self.aeval.symtable['value2'] = getattr(obj, self.value) - try: - self.aeval.eval(f'value3 = value1 {self.operator} value2') - logic = self.aeval.symtable['value3'] - if isinstance(logic, np.ndarray): - value[not logic] = self.aeval.symtable['value2'] - else: - if not logic: - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `value` {self.operator} obj.{self.value}' - - -class ObjConstraint(ConstraintBase): - """ - A `ObjConstraint` is a constraint whereby a dependent parameter is something of an independent parameter - value. E.g. a (Dependent Parameter) = 2* b (Independent Parameter) - """ - - def __init__(self, dependent_obj: V, operator: str, independent_obj: V): - """ - A `ObjConstraint` is a constraint whereby a dependent parameter is something of an independent parameter - value. E.g. a (Dependent Parameter) < b (Independent Parameter) - - :param dependent_obj: Dependent Parameter - :param operator: Relation to between the independent parameter and dependent parameter. e.g. ``2 *``, ``1 +`` - :param independent_obj: Independent Parameter - - :example: - - .. code-block:: python - - from easyscience.fitting.Constraints import ObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a = 2 * b` constraint - a = Parameter('a', 0.2) - b = Parameter('b', 1) - - constraint = ObjConstraint(a, '2*', b) - b.user_constraints['SET_A'] = constraint - b.value = 1 - # This triggers the constraint - a.value # Should equal 2 - - """ - super(ObjConstraint, self).__init__(dependent_obj, independent_obj=independent_obj, operator=operator) - self.external = True - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - value = obj.value_no_call_back - - self.aeval.symtable['value1'] = value - try: - self.aeval.eval(f'value2 = {self.operator} value1') - value = self.aeval.symtable['value2'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__} with `dependent_obj` = {self.operator} `independent_obj`' - - -class MultiObjConstraint(ConstraintBase): - """ - A `MultiObjConstraint` is similar to :class:`EasyScience.fitting.Constraints.ObjConstraint` except that it relates to - multiple independent objects. - """ - - def __init__( - self, - independent_objs: List[V], - operator: List[str], - dependent_obj: V, - value: Number, - ): - """ - A `MultiObjConstraint` is similar to :class:`EasyScience.fitting.Constraints.ObjConstraint` except that it relates - to one or more independent objects. - - E.g. - * a (Dependent Parameter) + b (Independent Parameter) = 1 - * a (Dependent Parameter) + b (Independent Parameter) - 2*c (Independent Parameter) = 0 - - :param independent_objs: List of Independent Parameters - :param operator: List of operators operating on the Independent Parameters - :param dependent_obj: Dependent Parameter - :param value: Value of the expression - - :example: - - **a + b = 1** - - .. code-block:: python - - from easyscience.fitting.Constraints import MultiObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a + b = 1` constraint - a = Parameter('a', 0.2) - b = Parameter('b', 0.3) - - constraint = MultiObjConstraint([b], ['+'], a, 1) - b.user_constraints['SET_A'] = constraint - b.value = 0.4 - # This triggers the constraint - a.value # Should equal 0.6 - - **a + b - 2c = 0** - - .. code-block:: python - - from easyscience.fitting.Constraints import MultiObjConstraint - from easyscience.Objects.Base import Parameter - # Create an `a + b - 2c = 0` constraint - a = Parameter('a', 0.5) - b = Parameter('b', 0.3) - c = Parameter('c', 0.1) - - constraint = MultiObjConstraint([b, c], ['+', '-2*'], a, 0) - b.user_constraints['SET_A'] = constraint - c.user_constraints['SET_A'] = constraint - b.value = 0.4 - # This triggers the constraint. Or it could be triggered by changing the value of c - a.value # Should equal 0.2 - - .. note:: This constraint is evaluated as ``dependent`` = ``value`` - SUM(``operator_i`` ``independent_i``) - """ - super(MultiObjConstraint, self).__init__( - dependent_obj, - independent_obj=independent_objs, - operator=operator, - value=value, - ) - self.external = True - - def _parse_operator(self, independent_objs: List[V], *args, **kwargs) -> Number: - - in_str = '' - value = None - for idx, obj in enumerate(independent_objs): - self.aeval.symtable['p' + str(self.independent_obj_ids[idx])] = obj.value_no_call_back - - in_str += ' p' + str(self.independent_obj_ids[idx]) - if idx < len(self.operator): - in_str += ' ' + self.operator[idx] - try: - self.aeval.eval(f'final_value = {self.value} - ({in_str})') - value = self.aeval.symtable['final_value'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__}' - - -class FunctionalConstraint(ConstraintBase): - """ - Functional constraints do not depend on other parameters and as such can be more complex. - """ - - def __init__( - self, - dependent_obj: V, - func: Callable, - independent_objs: Optional[List[V]] = None, - ): - """ - Functional constraints do not depend on other parameters and as such can be more complex. - - :param dependent_obj: Dependent Parameter - :param func: Function to be evaluated in the form ``f(value, *args, **kwargs)`` - - :example: - - .. code-block:: python - - import numpy as np - from easyscience.fitting.Constraints import FunctionalConstraint - from easyscience.Objects.Base import Parameter - - a = Parameter('a', 0.2, max=1) - constraint = FunctionalConstraint(a, np.abs) - - a.user_constraints['abs'] = constraint - - # This triggers the constraint - a.value = 0.85 # `a` is set to 0.85 - # This triggers the constraint - a.value = -0.5 # `a` is set to 0.5 - """ - super(FunctionalConstraint, self).__init__(dependent_obj, independent_obj=independent_objs) - self.function = func - if independent_objs is not None: - self.external = True - - def _parse_operator(self, obj: V, *args, **kwargs) -> Number: - - self.aeval.symtable[f'f{id(self.function)}'] = self.function - value_str = f'r_value = f{id(self.function)}(' - if isinstance(obj, list): - for o in obj: - value_str += f'{o.value_no_call_back},' - - value_str = value_str[:-1] - else: - value_str += f'{obj.value_no_call_back}' - - value_str += ')' - try: - self.aeval.eval(value_str) - value = self.aeval.symtable['r_value'] - except Exception as e: - raise e - finally: - self.aeval = Interpreter() - return value - - def __repr__(self) -> str: - return f'{self.__class__.__name__}' - - -def cleanup_constraint(obj_id: str, enabled: bool): - try: - obj = global_object.map.get_item_by_key(obj_id) - obj.enabled = enabled - except ValueError: - if global_object.debug: - print(f'Object with ID {obj_id} has already been deleted') diff --git a/src/easyscience/Objects/ObjectClasses.py b/src/easyscience/Objects/ObjectClasses.py index e6a159ad..376162c5 100644 --- a/src/easyscience/Objects/ObjectClasses.py +++ b/src/easyscience/Objects/ObjectClasses.py @@ -1,16 +1,11 @@ from __future__ import annotations -__author__ = 'github.com/wardsimon' -__version__ = '0.1.0' - # SPDX-FileCopyrightText: 2023 EasyScience contributors # SPDX-License-Identifier: BSD-3-Clause # © 2021-2023 Contributors to the EasyScience project List[C]: - pars = self.get_parameters() - constraints = [] - for par in pars: - con: Dict[str, C] = par.user_constraints - for key in con.keys(): - constraints.append(con[key]) - return constraints - def get_parameters(self) -> List[Parameter]: """ Get all parameter objects as a list. @@ -196,7 +180,7 @@ def get_fit_parameters(self) -> List[Parameter]: if hasattr(item, 'get_fit_parameters'): fit_list = [*fit_list, *item.get_fit_parameters()] elif isinstance(item, Parameter): - if item.enabled and not item.fixed: + if item.independent and not item.fixed: fit_list.append(item) return fit_list diff --git a/src/easyscience/Objects/variable/descriptor_any_type.py b/src/easyscience/Objects/variable/descriptor_any_type.py index 0d117ce2..93745d97 100644 --- a/src/easyscience/Objects/variable/descriptor_any_type.py +++ b/src/easyscience/Objects/variable/descriptor_any_type.py @@ -9,7 +9,7 @@ import numpy as np -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase @@ -62,7 +62,7 @@ def value(self) -> numbers.Number: return self._value @value.setter - @property_stack_deco + @property_stack def value(self, value: Union[list, np.ndarray]) -> None: """ Set the value of self. diff --git a/src/easyscience/Objects/variable/descriptor_array.py b/src/easyscience/Objects/variable/descriptor_array.py index c9b154e5..c7a1d8ca 100644 --- a/src/easyscience/Objects/variable/descriptor_array.py +++ b/src/easyscience/Objects/variable/descriptor_array.py @@ -16,7 +16,7 @@ from scipp import Variable from easyscience.global_object.undo_redo import PropertyStack -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase from .descriptor_number import DescriptorNumber @@ -150,7 +150,7 @@ def value(self) -> numbers.Number: return self._array.values @value.setter - @property_stack_deco + @property_stack def value(self, value: Union[list, np.ndarray]) -> None: """ Set the value of self. Ensures the input is an array and matches the shape of the existing array. @@ -225,7 +225,7 @@ def variance(self) -> np.ndarray: return self._array.variances @variance.setter - @property_stack_deco + @property_stack def variance(self, variance: Union[list, np.ndarray]) -> None: """ Set the variance of self. Ensures the input is an array and matches the shape of the existing values. @@ -259,7 +259,7 @@ def error(self) -> Optional[np.ndarray]: return np.sqrt(self._array.variances) @error.setter - @property_stack_deco + @property_stack def error(self, error: Union[list, np.ndarray]) -> None: """ Set the standard deviation for the parameter, which updates the variances. diff --git a/src/easyscience/Objects/variable/descriptor_base.py b/src/easyscience/Objects/variable/descriptor_base.py index b525d4f1..b80065a2 100644 --- a/src/easyscience/Objects/variable/descriptor_base.py +++ b/src/easyscience/Objects/variable/descriptor_base.py @@ -9,7 +9,7 @@ from typing import Optional from easyscience import global_object -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from easyscience.Objects.core import ComponentSerializer @@ -94,7 +94,7 @@ def name(self) -> str: return self._name @name.setter - @property_stack_deco + @property_stack def name(self, new_name: str) -> None: """ Set the name. @@ -118,7 +118,7 @@ def display_name(self) -> str: return display_name @display_name.setter - @property_stack_deco + @property_stack def display_name(self, name: str) -> None: """ Set the pretty display name. diff --git a/src/easyscience/Objects/variable/descriptor_bool.py b/src/easyscience/Objects/variable/descriptor_bool.py index 768b35b1..23869172 100644 --- a/src/easyscience/Objects/variable/descriptor_bool.py +++ b/src/easyscience/Objects/variable/descriptor_bool.py @@ -3,7 +3,7 @@ from typing import Any from typing import Optional -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase @@ -46,7 +46,7 @@ def value(self) -> bool: return self._bool_value @value.setter - @property_stack_deco + @property_stack def value(self, value: bool) -> None: """ Set the value of self. diff --git a/src/easyscience/Objects/variable/descriptor_number.py b/src/easyscience/Objects/variable/descriptor_number.py index cfba4a44..98631542 100644 --- a/src/easyscience/Objects/variable/descriptor_number.py +++ b/src/easyscience/Objects/variable/descriptor_number.py @@ -13,11 +13,27 @@ from scipp import Variable from easyscience.global_object.undo_redo import PropertyStack -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase +# Why is this a decorator? Because otherwise we would need a flag on the convert_unit method to avoid +# infinite recursion. This is a bit cleaner as it avoids the need for a internal only flag on a user method. +def notify_observers(func): + """ + Decorator to notify observers of a change in the descriptor. + + :param func: Function to be decorated + :return: Decorated function + """ + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + self._notify_observers() + return result + + return wrapper + class DescriptorNumber(DescriptorBase): """ A `Descriptor` for Number values with units. The internal representation is a scipp scalar. @@ -47,6 +63,8 @@ def __init__( param parent: Parent of the descriptor .. note:: Undo/Redo functionality is implemented for the attributes `variance`, `error`, `unit` and `value`. """ + self._observers: List[DescriptorNumber] = [] + if not isinstance(value, numbers.Number) or isinstance(value, bool): raise TypeError(f'{value=} must be a number') if variance is not None: @@ -72,7 +90,8 @@ def __init__( # Call convert_unit during initialization to ensure that the unit has no numbers in it, and to ensure unit consistency. if self.unit is not None: - self.convert_unit(self._base_unit()) + self._convert_unit(self._base_unit()) + @classmethod def from_scipp(cls, name: str, full_value: Variable, **kwargs) -> DescriptorNumber: @@ -90,6 +109,26 @@ def from_scipp(cls, name: str, full_value: Variable, **kwargs) -> DescriptorNumb raise TypeError(f'{full_value=} must be a scipp scalar') return cls(name=name, value=full_value.value, unit=full_value.unit, variance=full_value.variance, **kwargs) + def _attach_observer(self, observer: DescriptorNumber) -> None: + """Attach an observer to the descriptor.""" + self._observers.append(observer) + + def _detach_observer(self, observer: DescriptorNumber) -> None: + """Detach an observer from the descriptor.""" + self._observers.remove(observer) + + def _notify_observers(self, update_id=None) -> None: + """Notify all observers of a change. + + :param update_id: Optional update ID to pass to observers. Used to avoid cyclic depenencies. + + """ + if update_id is None: + self._global_object.update_id_iterator += 1 + update_id = self._global_object.update_id_iterator + for observer in self._observers: + observer._update(update_id=update_id, updating_object=self.unique_name) + @property def full_value(self) -> Variable: """ @@ -115,7 +154,8 @@ def value(self) -> numbers.Number: return self._scalar.value @value.setter - @property_stack_deco + @notify_observers + @property_stack def value(self, value: numbers.Number) -> None: """ Set the value of self. This should be usable for most cases. The full value can be obtained from `obj.full_value`. @@ -154,7 +194,8 @@ def variance(self) -> float: return self._scalar.variance @variance.setter - @property_stack_deco + @notify_observers + @property_stack def variance(self, variance_float: float) -> None: """ Set the variance. @@ -181,7 +222,8 @@ def error(self) -> float: return float(np.sqrt(self._scalar.variance)) @error.setter - @property_stack_deco + @notify_observers + @property_stack def error(self, value: float) -> None: """ Set the standard deviation for the parameter. @@ -198,7 +240,9 @@ def error(self, value: float) -> None: else: self._scalar.variance = None - def convert_unit(self, unit_str: str) -> None: + # When we convert units internally, we dont want to notify observers as this can cause infinite recursion. + # Therefore the convert_unit method is split into two methods, a private internal method and a public method. + def _convert_unit(self, unit_str: str) -> None: """ Convert the value from one unit system to another. @@ -229,6 +273,15 @@ def set_scalar(obj, scalar): # Update the scalar self._scalar = new_scalar + # When the user calls convert_unit, we want to notify observers of the change to propagate the change. + @notify_observers + def convert_unit(self, unit_str: str) -> None: + """ + Convert the value from one unit system to another. + + :param unit_str: New unit in string form + """ + self._convert_unit(unit_str) # Just to get return type right def __copy__(self) -> DescriptorNumber: @@ -267,11 +320,11 @@ def __add__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN elif type(other) is DescriptorNumber: original_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be added') from None new_value = self.full_value + other.full_value - other.convert_unit(original_unit) + other._convert_unit(original_unit) else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) @@ -297,11 +350,11 @@ def __sub__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN elif type(other) is DescriptorNumber: original_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be subtracted') from None new_value = self.full_value - other.full_value - other.convert_unit(original_unit) + other._convert_unit(original_unit) else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) @@ -327,7 +380,7 @@ def __mul__(self, other: Union[DescriptorNumber, numbers.Number]) -> DescriptorN else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) - descriptor_number.convert_unit(descriptor_number._base_unit()) + descriptor_number._convert_unit(descriptor_number._base_unit()) descriptor_number.name = descriptor_number.unique_name return descriptor_number @@ -355,7 +408,7 @@ def __truediv__(self, other: Union[DescriptorNumber, numbers.Number]) -> Descrip else: return NotImplemented descriptor_number = DescriptorNumber.from_scipp(name=self.name, full_value=new_value) - descriptor_number.convert_unit(descriptor_number._base_unit()) + descriptor_number._convert_unit(descriptor_number._base_unit()) descriptor_number.name = descriptor_number.unique_name return descriptor_number @@ -415,6 +468,9 @@ def __abs__(self) -> DescriptorNumber: return descriptor_number def _base_unit(self) -> str: + """ + Extract the base unit from the unit string by removing numeric components and scientific notation. + """ string = str(self._scalar.unit) for i, letter in enumerate(string): if letter == 'e': @@ -422,4 +478,4 @@ def _base_unit(self) -> str: return string[i:] elif letter not in ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '+', '-']: return string[i:] - return '' + return '' \ No newline at end of file diff --git a/src/easyscience/Objects/variable/descriptor_str.py b/src/easyscience/Objects/variable/descriptor_str.py index 1abe4e4e..17110166 100644 --- a/src/easyscience/Objects/variable/descriptor_str.py +++ b/src/easyscience/Objects/variable/descriptor_str.py @@ -3,7 +3,7 @@ from typing import Any from typing import Optional -from easyscience.global_object.undo_redo import property_stack_deco +from easyscience.global_object.undo_redo import property_stack from .descriptor_base import DescriptorBase @@ -45,7 +45,7 @@ def value(self) -> str: return self._string @value.setter - @property_stack_deco + @property_stack def value(self, value: str) -> None: """ Set the value of self. diff --git a/src/easyscience/Objects/variable/parameter.py b/src/easyscience/Objects/variable/parameter.py index 03b3c230..d18414e5 100644 --- a/src/easyscience/Objects/variable/parameter.py +++ b/src/easyscience/Objects/variable/parameter.py @@ -6,29 +6,26 @@ import copy import numbers +import re +import warnings import weakref -from collections import namedtuple -from types import MappingProxyType from typing import Any from typing import Dict +from typing import List from typing import Optional -from typing import Tuple from typing import Union import numpy as np import scipp as sc +from asteval import Interpreter from scipp import UnitError from scipp import Variable from easyscience import global_object -from easyscience.Constraints import ConstraintBase -from easyscience.Constraints import SelfConstraint -from easyscience.global_object.undo_redo import property_stack_deco -from easyscience.Utils.Exceptions import CoreSetException +from easyscience.global_object.undo_redo import property_stack from .descriptor_number import DescriptorNumber - -Constraints = namedtuple('Constraints', ['user', 'builtin', 'virtual']) +from .descriptor_number import notify_observers class Parameter(DescriptorNumber): @@ -54,7 +51,6 @@ def __init__( url: Optional[str] = None, display_name: Optional[str] = None, callback: property = property(), - enabled: Optional[bool] = True, parent: Optional[Any] = None, ): """ @@ -68,11 +64,10 @@ def __init__( :param variance: The variance of the value :param min: The minimum value for fitting :param max: The maximum value for fitting - :param fixed: Can the parameter vary while fitting? + :param fixed: If the parameter is free to vary during fitting :param description: A brief summary of what this object is :param url: Lookup url for documentation/information :param display_name: The name of the object as it should be displayed - :param enabled: Can the objects value be set :param parent: The object which is the parent to this one .. note:: @@ -81,19 +76,19 @@ def __init__( if not isinstance(min, numbers.Number): raise TypeError('`min` must be a number') if not isinstance(max, numbers.Number): - raise TypeError('`max` must be a number') + raise TypeError('`max` must be a number') if not isinstance(value, numbers.Number): raise TypeError('`value` must be a number') if value < min: raise ValueError(f'{value=} can not be less than {min=}') if value > max: raise ValueError(f'{value=} can not be greater than {max=}') - if np.isclose(min, max, rtol=1e-9, atol=0.0): raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') if not isinstance(fixed, bool): raise TypeError('`fixed` must be either True or False') - + self._independent = True + self._observers: List[DescriptorNumber] = [] self._fixed = fixed # For fitting, but must be initialized before super().__init__ self._min = sc.scalar(float(min), unit=unit) self._max = sc.scalar(float(max), unit=unit) @@ -115,14 +110,177 @@ def __init__( weakref.finalize(self, self._callback.fdel) # Create additional fitting elements - self._enabled = enabled self._initial_scalar = copy.deepcopy(self._scalar) - builtin_constraint = { - # Last argument in constructor is the name of the property holding the value of the constraint - 'min': SelfConstraint(self, '>=', 'min'), - 'max': SelfConstraint(self, '<=', 'max'), - } - self._constraints = Constraints(builtin=builtin_constraint, user={}, virtual={}) + + @classmethod + def from_dependency(cls, name: str, dependency_expression: str, dependency_map: Optional[dict] = None, **kwargs) -> Parameter: # noqa: E501 + """ + Create a dependent Parameter directly from a dependency expression. + + :param name: The name of the parameter + :param dependency_expression: The dependency expression to evaluate. This should be a string which can be evaluated by the ASTEval interpreter. + :param dependency_map: A dictionary of dependency expression symbol name and dependency object pairs. This is inserted into the asteval interpreter to resolve dependencies. + :param kwargs: Additional keyword arguments to pass to the Parameter constructor. + :return: A new dependent Parameter object. + """ # noqa: E501 + parameter = cls(name=name, value=0.0, unit='', variance=0.0, min=-np.inf, max=np.inf, **kwargs) + parameter.make_dependent_on(dependency_expression=dependency_expression, dependency_map=dependency_map) + return parameter + + + def _update(self, update_id: int, updating_object: str) -> None: + """ + Update the parameter. This is called by the DescriptorNumbers/Parameters who have this Parameter as a dependency. + + :param update_id: The id of the update. This is used to avoid cyclic dependencies. + :param updating_object: The unique_name of the object which is updating this parameter. + + """ + if not self._independent: + # Check if this parameter has already been updated by the updating object with this update id + if updating_object not in self._dependency_updates: + self._dependency_updates[updating_object] = 0 + if self._dependency_updates[updating_object] == update_id: + warnings.warn('Warning: Cyclic dependency detected!\n' + + f'This parameter, {self.unique_name}, has already been updated by {updating_object} during this update.\n' + # noqa: E501 + 'This update will be ignored. Please check your dependencies.') + else: + # Update the value of the parameter using the dependency interpreter + temporary_parameter = self._dependency_interpreter(self._clean_dependency_string) + self._scalar.value = temporary_parameter.value + self._scalar.unit = temporary_parameter.unit + self._scalar.variance = temporary_parameter.variance + self._min.value = temporary_parameter.min + self._max.value = temporary_parameter.max + self._notify_observers(update_id=update_id) + else: + warnings.warn('This parameter is not dependent. It cannot be updated.') + + def make_dependent_on(self, dependency_expression: str, dependency_map: Optional[dict] = None) -> None: + """ + Make this parameter dependent on another parameter. This will overwrite the current value, unit, variance, min and max. + + How to use the dependency map: + If a parameter c has a dependency expression of 'a + b', where a and b are parameters belonging to the model class, + then the dependency map needs to have the form {'a': model.a, 'b': model.b}, where model is the model class. + I.e. the values are the actual objects, whereas the keys are how they are represented in the dependency expression. + + The dependency map is not needed if the dependency expression uses the unique names of the parameters. + Unique names in dependency expressions are defined by quotes, e.g. 'Parameter_0' or "Parameter_0" depending on the quotes used for the expression. + + :param dependency_expression: The dependency expression to evaluate. This should be a string which can be evaluated by a python interpreter. + :param dependency_map: A dictionary of dependency expression symbol name and dependency object pairs. This is inserted into the asteval interpreter to resolve dependencies. + """ # noqa: E501 + if not isinstance(dependency_expression, str): + raise TypeError('`dependency_expression` must be a string representing a valid dependency expression.') + if not (isinstance(dependency_map, dict) or dependency_map is None): + raise TypeError('`dependency_map` must be a dictionary of dependencies and their corresponding names in the dependecy expression.') # noqa: E501 + for key, value in dependency_map.items(): + if not isinstance(key, str): + raise TypeError('`dependency_map` keys must be strings representing the names of the dependencies in the dependency expression.') # noqa: E501 + if not isinstance(value, DescriptorNumber): + raise TypeError(f'`dependency_map` values must be DescriptorNumbers or Parameters. Got {type(value)} for {key}.') # noqa: E501 + + # If we're overwriting the dependency + if not self._independent: + for old_dependency in self._dependency_map.values(): + old_dependency._detach_observer(self) + + self._dependency_string = dependency_expression + self._dependency_map = dependency_map if dependency_map is not None else {} + self._dependency_interpreter = Interpreter(minimal=True) + self._dependency_interpreter.config['if'] = True # allows logical statements in the dependency expression + self._dependency_updates = {} # Used to track update ids to avoid cyclic dependencies + + self._process_dependency_unique_names(self._dependency_string) + for key, value in self._dependency_map.items(): + self._dependency_interpreter.symtable[key] = value + self._dependency_interpreter.readonly_symbols.add(key) # Dont allow overwriting of the dependencies in the dependency expression # noqa: E501 + value._attach_observer(self) + try: + dependency_result = self._dependency_interpreter.eval(self._clean_dependency_string, raise_errors=True) + except NameError as message: + raise NameError('\nUnknown name encountered in dependecy expression:'+ + '\n'+'\n'.join(str(message).split("\n")[1:])+ + '\nPlease check your expression or add the name to the `dependency_map`') from None + except Exception as message: + raise Exception('\nError encountered in dependecy expression:'+ + '\n'+'\n'.join(str(message).split("\n")[1:])+ + '\nPlease check your expression') from None + if not isinstance(dependency_result, DescriptorNumber): + raise TypeError(f'The dependency expression: "{self._clean_dependency_string}" returned a {type(dependency_result)}, it should return a Parameter or DescriptorNumber.') # noqa: E501 + self._scalar.value = dependency_result.value + self._scalar.unit = dependency_result.unit + self._scalar.variance = dependency_result.variance + self._min.value = dependency_result.min if isinstance(dependency_result, Parameter) else dependency_result.value + self._max.value = dependency_result.max if isinstance(dependency_result, Parameter) else dependency_result.value + self._independent = False + self._fixed = False + self._notify_observers() + + def make_independent(self) -> None: + """ + Make this parameter independent. + This will remove the dependency expression, the dependency map and the dependency interpreter. + + :return: None + """ + if not self._independent: + for dependency in self._dependency_map.values(): + dependency._detach_observer(self) + self._independent = True + del self._dependency_map + del self._dependency_updates + del self._dependency_interpreter + del self._dependency_string + del self._clean_dependency_string + else: + raise AttributeError('This parameter is already independent.') + + @property + def independent(self) -> bool: + """ + Is the parameter independent? + + :return: True = independent, False = dependent + """ + return self._independent + + @independent.setter + def independent(self, value: bool) -> None: + raise AttributeError('This property is read-only. Use `make_independent` and `make_dependent_on` to change the state of the parameter.') # noqa: E501 + + @property + def dependency_expression(self) -> str: + """ + Get the dependency expression of this parameter. + + :return: The dependency expression of this parameter. + """ + if not self._independent: + return self._dependency_string + else: + raise AttributeError('This parameter is independent. It has no dependency expression.') + + @dependency_expression.setter + def depedency_expression(self, new_expression: str) -> None: + raise AttributeError('Dependency expression is read-only. Use `make_dependent_on` to change the dependency expression.') + + @property + def dependency_map(self) -> Dict[str, DescriptorNumber]: + """ + Get the dependency map of this parameter. + + :return: The dependency map of this parameter. + """ + if not self._independent: + return self._dependency_map + else: + raise AttributeError('This parameter is independent. It has no dependency map.') + + @dependency_map.setter + def dependency_map(self, new_map: Dict[str, DescriptorNumber]) -> None: + raise AttributeError('Dependency map is read-only. Use `make_dependent_on` to change the dependency map.') @property def value_no_call_back(self) -> numbers.Number: @@ -167,57 +325,79 @@ def value(self) -> numbers.Number: return self._scalar.value @value.setter - @property_stack_deco + @property_stack def value(self, value: numbers.Number) -> None: """ Set the value of self. This only updates the value of the scipp scalar. :param value: New value of self """ - if not self.enabled: - if global_object.debug: - raise CoreSetException(f'{str(self)} is not enabled.') - return + if self._independent: + if not isinstance(value, numbers.Number): + raise TypeError(f'{value=} must be a number') + + value = float(value) + if value < self._min.value: + value = self._min.value + if value > self._max.value: + value = self._max.value - if not isinstance(value, numbers.Number) or isinstance(value, bool): - raise TypeError(f'{value=} must be a number') + self._scalar.value = value - # Need to set the value for constraints to be functional - self._scalar.value = float(value) - # if self._callback.fset is not None: - # self._callback.fset(self._scalar.value) + if self._callback.fset is not None: + self._callback.fset(self._scalar.value) - # Deals with min/max - value = self._constraint_runner(self.builtin_constraints, self._scalar.value) + # Notify observers of the change + self._notify_observers() + else: + raise AttributeError("This is a dependent parameter, its value cannot be set directly.") - # Deals with user constraints - # Changes should not be registrered in the undo/redo stack - stack_state = global_object.stack.enabled - if stack_state: - global_object.stack.force_state(False) - try: - value = self._constraint_runner(self.user_constraints, value) - finally: - global_object.stack.force_state(stack_state) + @DescriptorNumber.variance.setter + def variance(self, variance_float: float) -> None: + """ + Set the variance. - value = self._constraint_runner(self._constraints.virtual, value) + :param variance_float: Variance as a float + """ + if self._independent: + DescriptorNumber.variance.fset(self, variance_float) + else: + raise AttributeError("This is a dependent parameter, its variance cannot be set directly.") - self._scalar.value = float(value) - if self._callback.fset is not None: - self._callback.fset(self._scalar.value) + @DescriptorNumber.error.setter + def error(self, value: float) -> None: + """ + Set the standard deviation for the parameter. - def convert_unit(self, unit_str: str) -> None: + :param value: New error value + """ + if self._independent: + DescriptorNumber.error.fset(self, value) + else: + raise AttributeError("This is a dependent parameter, its error cannot be set directly.") + + def _convert_unit(self, unit_str: str) -> None: """ Perform unit conversion. The value, max and min can change on unit change. :param new_unit: new unit :return: None """ - super().convert_unit(unit_str) + super()._convert_unit(unit_str=unit_str) new_unit = sc.Unit(unit_str) # unit_str is tested in super method self._min = self._min.to(unit=new_unit) self._max = self._max.to(unit=new_unit) + @notify_observers + def convert_unit(self, unit_str: str) -> None: + """ + Perform unit conversion. The value, max and min can change on unit change. + + :param new_unit: new unit + :return: None + """ + self._convert_unit(unit_str) + @property def min(self) -> numbers.Number: """ @@ -228,7 +408,7 @@ def min(self) -> numbers.Number: return self._min.value @min.setter - @property_stack_deco + @property_stack def min(self, min_value: numbers.Number) -> None: """ Set the minimum value for fitting. @@ -237,14 +417,18 @@ def min(self, min_value: numbers.Number) -> None: :param min_value: new minimum value :return: None """ - if not isinstance(min_value, numbers.Number): - raise TypeError('`min` must be a number') - if np.isclose(min_value, self._max.value, rtol=1e-9, atol=0.0): - raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') - if min_value <= self.value: - self._min.value = min_value + if self._independent: + if not isinstance(min_value, numbers.Number): + raise TypeError('`min` must be a number') + if np.isclose(min_value, self._max.value, rtol=1e-9, atol=0.0): + raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') + if min_value <= self.value: + self._min.value = min_value + else: + raise ValueError(f'The current value ({self.value}) is smaller than the desired min value ({min_value}).') + self._notify_observers() else: - raise ValueError(f'The current value ({self.value}) is smaller than the desired min value ({min_value}).') + raise AttributeError("This is a dependent parameter, its minimum value cannot be set directly.") @property def max(self) -> numbers.Number: @@ -256,7 +440,7 @@ def max(self) -> numbers.Number: return self._max.value @max.setter - @property_stack_deco + @property_stack def max(self, max_value: numbers.Number) -> None: """ Get the maximum value for fitting. @@ -265,14 +449,18 @@ def max(self, max_value: numbers.Number) -> None: :param max_value: new maximum value :return: None """ - if not isinstance(max_value, numbers.Number): - raise TypeError('`max` must be a number') - if np.isclose(max_value, self._min.value, rtol=1e-9, atol=0.0): - raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') - if max_value >= self.value: - self._max.value = max_value + if self._independent: + if not isinstance(max_value, numbers.Number): + raise TypeError('`max` must be a number') + if np.isclose(max_value, self._min.value, rtol=1e-9, atol=0.0): + raise ValueError('The min and max bounds cannot be identical. Please use fixed=True instead to fix the value.') + if max_value >= self.value: + self._max.value = max_value + else: + raise ValueError(f'The current value ({self.value}) is greater than the desired max value ({max_value}).') + self._notify_observers() else: - raise ValueError(f'The current value ({self.value}) is greater than the desired max value ({max_value}).') + raise AttributeError("This is a dependent parameter, its maximum value cannot be set directly.") @property def fixed(self) -> bool: @@ -284,7 +472,7 @@ def fixed(self) -> bool: return self._fixed @fixed.setter - @property_stack_deco + @property_stack def fixed(self, fixed: bool) -> None: """ Change the parameter vary while fitting state. @@ -292,17 +480,17 @@ def fixed(self, fixed: bool) -> None: :param fixed: True = fixed, False = can vary """ - if not self.enabled: - if global_object.stack.enabled: - # Remove the recorded change from the stack - global_object.stack.pop() - if global_object.debug: - raise CoreSetException(f'{str(self)} is not enabled.') - return if not isinstance(fixed, bool): raise ValueError(f'{fixed=} must be a boolean. Got {type(fixed)}') - self._fixed = fixed + if self._independent: + self._fixed = fixed + else: + if self._global_object.stack.enabled: + # Remove the recorded change from the stack + global_object.stack.pop() + raise AttributeError("This is a dependent parameter, dependent parameters cannot be fixed.") + # Is this alias really needed? @property def free(self) -> bool: return not self.fixed @@ -311,112 +499,30 @@ def free(self) -> bool: def free(self, value: bool) -> None: self.fixed = not value - @property - def bounds(self) -> Tuple[numbers.Number, numbers.Number]: - """ - Get the bounds of the parameter. - - :return: Tuple of the parameters minimum and maximum values + def _process_dependency_unique_names(self, dependency_expression: str): """ - return self.min, self.max - @bounds.setter - def bounds(self, new_bound: Tuple[numbers.Number, numbers.Number]) -> None: - """ - Set the bounds of the parameter. *This will also enable the parameter*. + Add the unique names of the parameters to the ASTEval interpreter. This is used to evaluate the dependency expression. - :param new_bound: New bounds. This should be a tuple of (min, max). + :param dependency_expression: The dependency expression to be evaluated """ - old_min = self.min - old_max = self.max - new_min, new_max = new_bound - - # Begin macro operation for undo/redo - close_macro = False - if self._global_object.stack.enabled: - self._global_object.stack.beginMacro('Setting bounds') - close_macro = True + # Get the unique_names from the expression string regardless of the quotes used + inputted_unique_names = re.findall("(\'.+?\')", dependency_expression) + inputted_unique_names += re.findall('(\".+?\")', dependency_expression) - try: - # Update bounds - self.min = new_min - self.max = new_max - except ValueError: - # Rollback on failure - self.min = old_min - self.max = old_max - if close_macro: - self._global_object.stack.endMacro() - raise ValueError(f'Current parameter value: {self._scalar.value} must be within {new_bound=}') - - # Enable the parameter if needed - if not self.enabled: - self.enabled = True - # Free parameter if needed - if self.fixed: - self.fixed = False - - # End macro operation - if close_macro: - self._global_object.stack.endMacro() - - @property - def builtin_constraints(self) -> Dict[str, SelfConstraint]: - """ - Get the built in constrains of the object. Typically these are the min/max - - :return: Dictionary of constraints which are built into the system - """ - return MappingProxyType(self._constraints.builtin) - - @property - def user_constraints(self) -> Dict[str, ConstraintBase]: - """ - Get the user specified constrains of the object. - - :return: Dictionary of constraints which are user supplied - """ - return self._constraints.user - - @user_constraints.setter - def user_constraints(self, constraints_dict: Dict[str, ConstraintBase]) -> None: - self._constraints.user = constraints_dict - - def _constraint_runner( - self, - this_constraint_type, - value: numbers.Number, - ) -> float: - for constraint in this_constraint_type.values(): - if constraint.external: - constraint() - continue - - constained_value = constraint(no_set=True) - if constained_value != value: - if global_object.debug: - print(f'Constraint `{constraint}` has been applied') - self._scalar.value = constained_value - value = constained_value - return value - - @property - def enabled(self) -> bool: - """ - Logical property to see if the objects value can be directly set. - - :return: Can the objects value be set - """ - return self._enabled - - @enabled.setter - @property_stack_deco - def enabled(self, value: bool) -> None: - """ - Enable and disable the direct setting of an objects value field. - - :param value: True - objects value can be set, False - the opposite - """ - self._enabled = value + clean_dependency_string = dependency_expression + existing_unique_names = self._global_object.map.vertices() + # Add the unique names of the parameters to the ASTEVAL interpreter + for name in inputted_unique_names: + stripped_name = name.strip("'\"") + if stripped_name not in existing_unique_names: + raise ValueError(f'A Parameter with unique_name {stripped_name} does not exist. Please check your dependency expression.') # noqa: E501 + dependent_parameter = self._global_object.map.get_item_by_key(stripped_name) + if isinstance(dependent_parameter, DescriptorNumber): + self._dependency_map['__'+stripped_name+'__'] = dependent_parameter + clean_dependency_string = clean_dependency_string.replace(name, '__'+stripped_name+'__') + else: + raise ValueError(f'The object with unique_name {stripped_name} is not a Parameter or DescriptorNumber. Please check your dependency expression.') # noqa: E501 + self._clean_dependency_string = clean_dependency_string def __copy__(self) -> Parameter: new_obj = super().__copy__() @@ -450,13 +556,13 @@ def __add__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here other_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be added') from None new_full_value = self.full_value + other.full_value min_value = self.min + other.min if isinstance(other, Parameter) else self.min + other.value max_value = self.max + other.max if isinstance(other, Parameter) else self.max + other.value - other.convert_unit(other_unit) + other._convert_unit(other_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -473,13 +579,13 @@ def __radd__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here original_unit = self.unit try: - self.convert_unit(other.unit) + self._convert_unit(other.unit) except UnitError: raise UnitError(f'Values with units {other.unit} and {self.unit} cannot be added') from None new_full_value = self.full_value + other.full_value min_value = self.min + other.value max_value = self.max + other.value - self.convert_unit(original_unit) + self._convert_unit(original_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -496,7 +602,7 @@ def __sub__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here other_unit = other.unit try: - other.convert_unit(self.unit) + other._convert_unit(self.unit) except UnitError: raise UnitError(f'Values with units {self.unit} and {other.unit} cannot be subtracted') from None new_full_value = self.full_value - other.full_value @@ -506,7 +612,7 @@ def __sub__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> else: min_value = self.min - other.value max_value = self.max - other.value - other.convert_unit(other_unit) + other._convert_unit(other_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -523,13 +629,13 @@ def __rsub__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: elif isinstance(other, DescriptorNumber): # Parameter inherits from DescriptorNumber and is also handled here original_unit = self.unit try: - self.convert_unit(other.unit) + self._convert_unit(other.unit) except UnitError: raise UnitError(f'Values with units {other.unit} and {self.unit} cannot be subtracted') from None new_full_value = other.full_value - self.full_value min_value = other.value - self.max max_value = other.value - self.min - self.convert_unit(original_unit) + self._convert_unit(original_unit) else: return NotImplemented parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) @@ -573,7 +679,7 @@ def __mul__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) -> min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -597,7 +703,7 @@ def __rmul__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parameter: min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -639,7 +745,7 @@ def __truediv__(self, other: Union[DescriptorNumber, Parameter, numbers.Number]) min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name return parameter @@ -680,7 +786,7 @@ def __rtruediv__(self, other: Union[DescriptorNumber, numbers.Number]) -> Parame min_value = min(combinations) max_value = max(combinations) parameter = Parameter.from_scipp(name=self.name, full_value=new_full_value, min=min_value, max=max_value) - parameter.convert_unit(parameter._base_unit()) + parameter._convert_unit(parameter._base_unit()) parameter.name = parameter.unique_name self.value = original_self return parameter diff --git a/src/easyscience/fitting/fitter.py b/src/easyscience/fitting/fitter.py index daea7782..53007879 100644 --- a/src/easyscience/fitting/fitter.py +++ b/src/easyscience/fitting/fitter.py @@ -34,15 +34,6 @@ def __init__(self, fit_object, fit_function: Callable): self._enum_current_minimizer: AvailableMinimizers = None # set in _update_minimizer self._update_minimizer(DEFAULT_MINIMIZER) - def fit_constraints(self) -> list: - return self._minimizer.fit_constraints() - - def add_fit_constraint(self, constraint) -> None: - self._minimizer.add_fit_constraint(constraint) - - def remove_fit_constraint(self, index: int) -> None: - self._minimizer.remove_fit_constraint(index) - def make_model(self, pars=None) -> Callable: return self._minimizer.make_model(pars) @@ -84,9 +75,7 @@ def switch_minimizer(self, minimizer_enum: Union[AvailableMinimizers, str]) -> N print(f'minimizer should be set with enum {minimizer_enum}') minimizer_enum = from_string_to_enum(minimizer_enum) - constraints = self._minimizer.fit_constraints() self._update_minimizer(minimizer_enum) - self._minimizer.set_fit_constraint(constraints) def _update_minimizer(self, minimizer_enum: AvailableMinimizers) -> None: self._minimizer = factory(minimizer_enum=minimizer_enum, fit_object=self._fit_object, fit_function=self.fit_function) @@ -235,11 +224,7 @@ def inner_fit_callable( # Fit fit_fun_org = self._fit_function fit_fun_wrap = self._fit_function_wrapper(x_new, flatten=True) # This should be wrapped. - - # We change the fit function, so have to reset constraints - constraints = self._minimizer.fit_constraints() self.fit_function = fit_fun_wrap - self._minimizer.set_fit_constraint(constraints) f_res = self._minimizer.fit( x_fit, y_new, @@ -251,9 +236,8 @@ def inner_fit_callable( # Postcompute fit_result = self._post_compute_reshaping(f_res, x, y) - # Reset the function and constrains + # Reset the function self.fit_function = fit_fun_org - self._minimizer.set_fit_constraint(constraints) return fit_result return inner_fit_callable diff --git a/src/easyscience/fitting/minimizers/minimizer_base.py b/src/easyscience/fitting/minimizers/minimizer_base.py index 02130a6e..511057f5 100644 --- a/src/easyscience/fitting/minimizers/minimizer_base.py +++ b/src/easyscience/fitting/minimizers/minimizer_base.py @@ -16,8 +16,6 @@ import numpy as np -from easyscience.Constraints import ObjConstraint - # causes circular import when Parameter is imported # from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.variable import Parameter @@ -52,11 +50,6 @@ def __init__( self._cached_pars_vals: Dict[str, Tuple[float]] = {} self._cached_model = None self._fit_function = None - self._constraints = [] - - @property - def all_constraints(self) -> List[ObjConstraint]: - return [*self._constraints, *self._object._constraints] @property def enum(self) -> AvailableMinimizers: @@ -66,18 +59,6 @@ def enum(self) -> AvailableMinimizers: def name(self) -> str: return self._minimizer_enum.name - def fit_constraints(self) -> List[ObjConstraint]: - return self._constraints - - def set_fit_constraint(self, constraints: List[ObjConstraint]): - self._constraints = constraints - - def add_fit_constraint(self, constraint: ObjConstraint): - self._constraints.append(constraint) - - def remove_fit_constraint(self, index: int) -> None: - del self._constraints[index] - @abstractmethod def fit( self, @@ -237,8 +218,6 @@ def _fit_function(x: np.ndarray, **kwargs): # Since we are calling the parameter fset will be called. # TODO Pre processing here - for constraint in self.fit_constraints(): - constraint() return_data = func(x) # TODO Loading or manipulating data here return return_data diff --git a/src/easyscience/global_object/global_object.py b/src/easyscience/global_object/global_object.py index c78db4ef..dd188db2 100644 --- a/src/easyscience/global_object/global_object.py +++ b/src/easyscience/global_object/global_object.py @@ -36,6 +36,8 @@ def __init__(self): # Map. This is the conduit database between all global object species self.map: Map = self.__map + self.update_id_iterator = 0 + def instantiate_stack(self): """ The undo/redo stack references the collective. Hence it has to be imported diff --git a/src/easyscience/global_object/undo_redo.py b/src/easyscience/global_object/undo_redo.py index 02b30020..e421bcf1 100644 --- a/src/easyscience/global_object/undo_redo.py +++ b/src/easyscience/global_object/undo_redo.py @@ -428,18 +428,18 @@ def redo(self) -> NoReturn: self._parent.data = self._new_value -def property_stack_deco(arg: Union[str, Callable], begin_macro: bool = False) -> Callable: +def property_stack(arg: Union[str, Callable], begin_macro: bool = False) -> Callable: """ Decorate a `property` setter with undo/redo functionality This decorator can be used as: - @property_stack_deco + @property_stack def func() .... or - @property_stack_deco("This is the undo/redo text) + @property_stack("This is the undo/redo text) def func() .... diff --git a/tests/integration_tests/Fitting/test_fitter.py b/tests/integration_tests/Fitting/test_fitter.py index 19e0f876..0706d3bc 100644 --- a/tests/integration_tests/Fitting/test_fitter.py +++ b/tests/integration_tests/Fitting/test_fitter.py @@ -2,13 +2,9 @@ # SPDX-License-Identifier: BSD-3-Clause # © 2021-2023 Contributors to the EasyScience project None: # When minimizer._original_fit_function = MagicMock(return_value='fit_function_result') - mock_fit_constraint = MagicMock() - minimizer.fit_constraints = MagicMock(return_value=[mock_fit_constraint]) - minimizer._object = MagicMock() mock_parm_1 = MagicMock(Parameter) mock_parm_1.unique_name = 'mock_parm_1' @@ -148,7 +144,6 @@ def test_generate_fit_function(self, minimizer: MinimizerBase) -> None: # Expect assert 'fit_function_result' == fit_function_result - mock_fit_constraint.assert_called_once_with() minimizer._original_fit_function.assert_called_once_with([10.0]) assert minimizer._cached_pars['mock_parm_1'] == mock_parm_1 assert minimizer._cached_pars['mock_parm_2'] == mock_parm_2 diff --git a/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py b/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py index 8c39b8a5..1cd14cd5 100644 --- a/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py +++ b/tests/unit_tests/Fitting/minimizers/test_minimizer_dfo.py @@ -72,9 +72,6 @@ def test_generate_fit_function(self, minimizer: DFO) -> None: # When minimizer._original_fit_function = MagicMock(return_value='fit_function_result') - mock_fit_constraint = MagicMock() - minimizer.fit_constraints = MagicMock(return_value=[mock_fit_constraint]) - minimizer._object = MagicMock() mock_parm_1 = MagicMock() mock_parm_1.unique_name = 'mock_parm_1' @@ -92,7 +89,6 @@ def test_generate_fit_function(self, minimizer: DFO) -> None: # Expect assert 'fit_function_result' == fit_function_result - mock_fit_constraint.assert_called_once_with() minimizer._original_fit_function.assert_called_once_with([10.0]) assert minimizer._cached_pars['mock_parm_1'] == mock_parm_1 assert minimizer._cached_pars['mock_parm_2'] == mock_parm_2 diff --git a/tests/unit_tests/Fitting/test_constraints.py b/tests/unit_tests/Fitting/test_constraints.py deleted file mode 100644 index 47a5792b..00000000 --- a/tests/unit_tests/Fitting/test_constraints.py +++ /dev/null @@ -1,134 +0,0 @@ -__author__ = "github.com/wardsimon" -__version__ = "0.1.0" - -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project Tuple[List[Parameter], List[int]]: - mock_callback = MagicMock() - mock_callback.fget = MagicMock(return_value=-10) - return [Parameter("a", 1, callback=mock_callback), Parameter("b", 2, callback=mock_callback)], [1, 2] - - -@pytest.fixture -def threePars(twoPars) -> Tuple[List[Parameter], List[int]]: - ps, vs = twoPars - ps.append(Parameter("c", 3)) - vs.append(3) - return ps, vs - - -def test_NumericConstraints_Equals(twoPars): - - value = 1 - - # Should skip - c = NumericConstraint(twoPars[0][0], "==", value) - c() - assert twoPars[0][0].value_no_call_back == twoPars[1][0] - - # Should update to new value - c = NumericConstraint(twoPars[0][1], "==", value) - c() - assert twoPars[0][1].value_no_call_back == value - - -def test_NumericConstraints_Greater(twoPars): - value = 1.5 - - # Should update to new value - c = NumericConstraint(twoPars[0][0], ">", value) - c() - assert twoPars[0][0].value_no_call_back == value - - # Should skip - c = NumericConstraint(twoPars[0][1], ">", value) - c() - assert twoPars[0][1].value_no_call_back == twoPars[1][1] - - -def test_NumericConstraints_Less(twoPars): - value = 1.5 - - # Should skip - c = NumericConstraint(twoPars[0][0], "<", value) - c() - assert twoPars[0][0].value_no_call_back == twoPars[1][0] - - # Should update to new value - c = NumericConstraint(twoPars[0][1], "<", value) - c() - assert twoPars[0][1].value_no_call_back == value - - -@pytest.mark.parametrize("multiplication_factor", [None, 1, 2, 3, 4.5]) -def test_ObjConstraintMultiply(twoPars, multiplication_factor): - if multiplication_factor is None: - multiplication_factor = 1 - operator_str = "" - else: - operator_str = f"{multiplication_factor}*" - c = ObjConstraint(twoPars[0][0], operator_str, twoPars[0][1]) - c() - assert twoPars[0][0].value_no_call_back == multiplication_factor * twoPars[1][1] - - -@pytest.mark.parametrize("division_factor", [1, 2, 3, 4.5]) -def test_ObjConstraintDivide(twoPars, division_factor): - operator_str = f"{division_factor}/" - c = ObjConstraint(twoPars[0][0], operator_str, twoPars[0][1]) - c() - assert twoPars[0][0].value_no_call_back == division_factor / twoPars[1][1] - - -def test_ObjConstraint_Multiple(threePars): - - p0 = threePars[0][0] - p1 = threePars[0][1] - p2 = threePars[0][2] - - value = 1.5 - - p0.user_constraints["num_1"] = ObjConstraint(p1, "", p0) - p0.user_constraints["num_2"] = ObjConstraint(p2, "", p0) - - p0.value = value - assert p0.value_no_call_back == value - assert p1.value_no_call_back == value - assert p2.value_no_call_back == value - - -def test_ConstraintEnable_Disable(twoPars): - - assert twoPars[0][0].enabled - assert twoPars[0][1].enabled - - c = ObjConstraint(twoPars[0][0], "", twoPars[0][1]) - twoPars[0][0].user_constraints["num_1"] = c - - assert c.enabled - assert twoPars[0][1].enabled - assert not twoPars[0][0].enabled - - c.enabled = False - assert not c.enabled - assert twoPars[0][1].enabled - assert twoPars[0][0].enabled - - c.enabled = True - assert c.enabled - assert twoPars[0][1].enabled - assert not twoPars[0][0].enabled diff --git a/tests/unit_tests/Fitting/test_fitter.py b/tests/unit_tests/Fitting/test_fitter.py index 63783c17..992225ce 100644 --- a/tests/unit_tests/Fitting/test_fitter.py +++ b/tests/unit_tests/Fitting/test_fitter.py @@ -24,42 +24,6 @@ def test_constructor(self, fitter: Fitter): assert fitter._minimizer is None fitter._update_minimizer.assert_called_once_with(AvailableMinimizers.LMFit_leastsq) - def test_fit_constraints(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.fit_constraints = MagicMock(return_value='constraints') - fitter._minimizer = mock_minimizer - - # Then - constraints = fitter.fit_constraints() - - # Expect - assert constraints == 'constraints' - - def test_add_fit_constraint(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.add_fit_constraint = MagicMock() - fitter._minimizer = mock_minimizer - - # Then - fitter.add_fit_constraint('constraints') - - # Expect - mock_minimizer.add_fit_constraint.assert_called_once_with('constraints') - - def test_remove_fit_constraint(self, fitter: Fitter): - # When - mock_minimizer = MagicMock() - mock_minimizer.remove_fit_constraint = MagicMock() - fitter._minimizer = mock_minimizer - - # Then - fitter.remove_fit_constraint(10) - - # Expect - mock_minimizer.remove_fit_constraint.assert_called_once_with(10) - def test_make_model(self, fitter: Fitter): # When mock_minimizer = MagicMock() @@ -128,8 +92,6 @@ def test_create(self, fitter: Fitter, monkeypatch): def test_switch_minimizer(self, fitter: Fitter, monkeypatch): # When mock_minimizer = MagicMock() - mock_minimizer.fit_constraints = MagicMock(return_value='constraints') - mock_minimizer.set_fit_constraint = MagicMock() fitter._minimizer = mock_minimizer mock_string_to_enum = MagicMock(return_value=10) monkeypatch.setattr(easyscience.fitting.fitter, 'from_string_to_enum', mock_string_to_enum) @@ -139,8 +101,6 @@ def test_switch_minimizer(self, fitter: Fitter, monkeypatch): # Expect fitter._update_minimizer.count(2) - mock_minimizer.set_fit_constraint.assert_called_once_with('constraints') - mock_minimizer.fit_constraints.assert_called_once() mock_string_to_enum.assert_called_once_with('great-minimizer') def test_update_minimizer(self, monkeypatch): diff --git a/tests/unit_tests/Objects/test_BaseObj.py b/tests/unit_tests/Objects/test_BaseObj.py index d95d9b0a..67b91e82 100644 --- a/tests/unit_tests/Objects/test_BaseObj.py +++ b/tests/unit_tests/Objects/test_BaseObj.py @@ -148,7 +148,7 @@ def test_baseobj_fit_objects(setup_pars: dict): pass -def test_baseobj_as_dict(setup_pars: dict): +def test_baseobj_as_dict(clear, setup_pars: dict): name = setup_pars["name"] del setup_pars["name"] obj = BaseObj(name, **setup_pars) @@ -159,6 +159,7 @@ def test_baseobj_as_dict(setup_pars: dict): "@class": "BaseObj", "@version": easyscience.__version__, "name": "test", + "unique_name": "BaseObj_0", "par1": { "@module": Parameter.__module__, "@class": Parameter.__name__, @@ -172,6 +173,8 @@ def test_baseobj_as_dict(setup_pars: dict): "unit": "dimensionless", }, "des1": { + "@module": DescriptorNumber.__module__, + "@class": DescriptorNumber.__name__, "@module": DescriptorNumber.__module__, "@class": DescriptorNumber.__name__, "@version": easyscience.__version__, @@ -195,6 +198,8 @@ def test_baseobj_as_dict(setup_pars: dict): "unit": "dimensionless", }, "des2": { + "@module": DescriptorNumber.__module__, + "@class": DescriptorNumber.__name__, "@module": DescriptorNumber.__module__, "@class": DescriptorNumber.__name__, "@version": easyscience.__version__, @@ -261,7 +266,6 @@ def test_baseobj_dir(setup_pars): "encode", "decode", "as_dict", - "constraints", "des1", "des2", "from_dict", diff --git a/tests/unit_tests/Objects/test_Groups.py b/tests/unit_tests/Objects/test_Groups.py index 3850b5f3..b546373f 100644 --- a/tests/unit_tests/Objects/test_Groups.py +++ b/tests/unit_tests/Objects/test_Groups.py @@ -394,22 +394,6 @@ def test_baseCollection_from_dict(cls): assert item1.value == item2.value -@pytest.mark.parametrize("cls", class_constructors) -def test_baseCollection_constraints(cls): - name = "test" - p1 = Parameter("p1", 1) - p2 = Parameter("p2", 2) - - from easyscience.Constraints import ObjConstraint - - p2.user_constraints["testing"] = ObjConstraint(p2, "2*", p1) - - obj = cls(name, p1, p2) - - cons: List[ObjConstraint] = obj.constraints - assert len(cons) == 1 - - @pytest.mark.parametrize("cls", class_constructors) def test_baseCollection_repr(cls): name = "test" diff --git a/tests/unit_tests/Objects/variable/test_parameter.py b/tests/unit_tests/Objects/variable/test_parameter.py index e00350b6..476827ce 100644 --- a/tests/unit_tests/Objects/variable/test_parameter.py +++ b/tests/unit_tests/Objects/variable/test_parameter.py @@ -24,7 +24,6 @@ def parameter(self) -> Parameter: url="url", display_name="display_name", callback=self.mock_callback, - enabled="enabled", parent=None, ) return parameter @@ -40,7 +39,7 @@ def test_init(self, parameter: Parameter): assert parameter._max.value == 10 assert parameter._max.unit == "m" assert parameter._callback == self.mock_callback - assert parameter._enabled == "enabled" + assert parameter._independent == True # From super assert parameter._scalar.value == 1 @@ -69,7 +68,6 @@ def test_init_value_min_exception(self): url="url", display_name="display_name", callback=mock_callback, - enabled="enabled", parent=None, ) @@ -91,7 +89,6 @@ def test_init_value_max_exception(self): url="url", display_name="display_name", callback=mock_callback, - enabled="enabled", parent=None, ) @@ -185,68 +182,12 @@ def test_repr_fixed(self, parameter: Parameter): # Then Expect assert repr(parameter) == "" - def test_bounds(self, parameter: Parameter): - # When Then Expect - assert parameter.bounds == (0, 10) - - def test_set_bounds(self, parameter: Parameter): - # When - self.mock_callback.fget.return_value = 1.0 # Ensure fget returns a scalar value - parameter._enabled = False - parameter._fixed = True - - # Then - parameter.bounds = (-10, 5) - - # Expect - assert parameter.min == -10 - assert parameter.max == 5 - assert parameter._enabled == True - assert parameter._fixed == False - - def test_set_bounds_exception_min(self, parameter: Parameter): - # When - parameter._enabled = False - parameter._fixed = True - - # Then - with pytest.raises(ValueError): - parameter.bounds = (2, 10) - - # Expect - assert parameter.min == 0 - assert parameter.max == 10 - assert parameter._enabled == False - assert parameter._fixed == True - - def test_set_bounds_exception_max(self, parameter: Parameter): - # When - parameter._enabled = False - parameter._fixed = True - - # Then - with pytest.raises(ValueError): - parameter.bounds = (0, 0.1) - - # Expect - assert parameter.min == 0 - assert parameter.max == 10 - assert parameter._enabled == False - assert parameter._fixed == True - - def test_enabled(self, parameter: Parameter): + def test_independent(self, parameter: Parameter): # When - parameter._enabled = True + parameter._independent = True # Then Expect - assert parameter.enabled is True - - def test_set_enabled(self, parameter: Parameter): - # When - parameter.enabled = False - - # Then Expect - assert parameter._enabled is False + assert parameter.independent is True def test_value_match_callback(self, parameter: Parameter): # When @@ -317,7 +258,7 @@ def test_copy(self, parameter: Parameter): assert parameter_copy._description == parameter._description assert parameter_copy._url == parameter._url assert parameter_copy._display_name == parameter._display_name - assert parameter_copy._enabled == parameter._enabled + assert parameter_copy._independent == parameter._independent def test_as_data_dict(self, clear, parameter: Parameter): # When Then @@ -336,7 +277,6 @@ def test_as_data_dict(self, clear, parameter: Parameter): "description": "description", "url": "url", "display_name": "display_name", - "enabled": "enabled", "unique_name": "Parameter_0", } diff --git a/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py b/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py deleted file mode 100644 index f4dcd2ea..00000000 --- a/tests/unit_tests/Objects/variable/test_parameter_from_legacy.py +++ /dev/null @@ -1,424 +0,0 @@ -# SPDX-FileCopyrightText: 2023 EasyScience contributors -# SPDX-License-Identifier: BSD-3-Clause -# © 2021-2023 Contributors to the EasyScience project " - d = Parameter("test", 1, unit="cm") - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 cm, bounds=[-inf:inf]>" - d = Parameter("test", 1, variance=0.1) - assert repr(d) == f"<{d.__class__.__name__} 'test': 1.0000 ± 0.3162, bounds=[-inf:inf]>" - - d = Parameter("test", 1, fixed=True) - assert ( - repr(d) - == f"<{d.__class__.__name__} 'test': 1.0000 (fixed), bounds=[-inf:inf]>" - ) - d = Parameter("test", 1, unit="cm", variance=0.1, fixed=True) - assert ( - repr(d) - == f"<{d.__class__.__name__} 'test': 1.0000 ± 0.3162 cm (fixed), bounds=[-inf:inf]>" - ) - - -def test_parameter_as_dict(): - d = Parameter("test", 1) - result = d.as_dict() - expected = { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "unit": "dimensionless", - } - for key in expected.keys(): - assert result[key] == expected[key] - - # Check that additional arguments work - d = Parameter("test", 1, unit="km", url="https://www.boo.com") - result = d.as_dict() - expected = { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - } - for key in expected.keys(): - assert result[key] == expected[key] - - -def test_item_from_dict(): - reference = { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - } - constructor = Parameter - d = constructor.from_dict(reference) - for key, item in reference.items(): - if key == "callback" or key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize( - "construct", - ( - { - "@module": Parameter.__module__, - "@class": Parameter.__name__, - "@version": easyscience.__version__, - "name": "test", - "unit": "km", - "value": 1.0, - "variance": 0.0, - "min": -np.inf, - "max": np.inf, - "fixed": False, - "url": "https://www.boo.com", - }, - ), - ids=["Parameter"], -) -def test_item_from_Decoder(construct): - - from easyscience.Utils.io.dict import DictSerializer - - d = DictSerializer().decode(construct) - assert d.__class__.__name__ == construct["@class"] - for key, item in construct.items(): - if key == "callback" or key.startswith("@"): - continue - obtained = getattr(d, key) - assert obtained == item - - -@pytest.mark.parametrize("value", (-np.inf, 0, 1.0, 2147483648, np.inf)) -def test_parameter_min(value): - d = Parameter("test", -0.1) - if d.value < value: - with pytest.raises(ValueError): - d.min = value - else: - d.min = value - assert d.min == value - - -@pytest.mark.parametrize("value", [-np.inf, 0, 1.1, 2147483648, np.inf]) -def test_parameter_max(value): - d = Parameter("test", 2147483649) - if d.value > value: - with pytest.raises(ValueError): - d.max = value - else: - d.max = value - assert d.max == value - - -@pytest.mark.parametrize("value", [True, False, 5]) -def test_parameter_fixed(value): - d = Parameter("test", -np.inf) - if isinstance(value, bool): - d.fixed = value - assert d.fixed == value - else: - with pytest.raises(ValueError): - d.fixed = value - - -@pytest.mark.parametrize("value", (-np.inf, -0.1, 0, 1.0, 2147483648, np.inf)) -def test_parameter_error(value): - d = Parameter("test", 1) - if value >= 0: - d.error = value - assert d.error == value - else: - with pytest.raises(ValueError): - d.error = value - - -def _generate_advanced_inputs(): - temp = _generate_inputs() - # These will be the optional parameters - advanced = {"variance": 1.0, "min": -0.1, "max": 2147483648, "fixed": False} - advanced_result = { - "variance": {"name": "variance", "value": advanced["variance"]}, - "min": {"name": "min", "value": advanced["min"]}, - "max": {"name": "max", "value": advanced["max"]}, - "fixed": {"name": "fixed", "value": advanced["fixed"]}, - } - - def create_entry(base, key, value, ref, ref_key=None): - this_temp = deepcopy(base) - for item in base: - test, res = item - new_opt = deepcopy(test[1]) - new_res = deepcopy(res) - if ref_key is None: - ref_key = key - new_res[ref_key] = ref - new_opt[key] = value - this_temp.append(([test[0], new_opt], new_res)) - return this_temp - - for add_opt in advanced.keys(): - if isinstance(advanced[add_opt], list): - for idx, item in enumerate(advanced[add_opt]): - temp = create_entry( - temp, - add_opt, - item, - advanced_result[add_opt]["value"][idx], - ref_key=advanced_result[add_opt]["name"], - ) - else: - temp = create_entry( - temp, - add_opt, - advanced[add_opt], - advanced_result[add_opt]["value"], - ref_key=advanced_result[add_opt]["name"], - ) - return temp - - -@pytest.mark.parametrize("element, expected", _generate_advanced_inputs()) -def test_parameter_advanced_creation(element, expected): - if len(element[0]) > 0: - value = element[0][1] - else: - value = element[1]["value"] - if "min" in element[1].keys(): - if element[1]["min"] > value: - with pytest.raises(ValueError): - d = Parameter(*element[0], **element[1]) - elif "max" in element[1].keys(): - if element[1]["max"] < value: - with pytest.raises(ValueError): - d = Parameter(*element[0], **element[1]) - else: - d = Parameter(*element[0], **element[1]) - for field in expected.keys(): - ref = expected[field] - obtained = getattr(d, field) - assert obtained == ref - - -@pytest.mark.parametrize("value", ("This is ", "a fun ", "test")) -def test_parameter_display_name(value): - p = Parameter("test", 1, display_name=value) - assert p.display_name == value - - -@pytest.mark.parametrize("value", (True, False)) -def test_parameter_bounds(value): - for fixed in (True, False): - p = Parameter("test", 1, enabled=value, fixed=fixed) - assert p.min == -np.inf - assert p.max == np.inf - assert p.fixed == fixed - assert p.bounds == (-np.inf, np.inf) - - p.bounds = (0, 2) - assert p.min == 0 - assert p.max == 2 - assert p.bounds == (0, 2) - assert p.enabled is True - assert p.fixed is False \ No newline at end of file diff --git a/tests/unit_tests/global_object/test_undo_redo.py b/tests/unit_tests/global_object/test_undo_redo.py index 1bcc8841..53c75934 100644 --- a/tests/unit_tests/global_object/test_undo_redo.py +++ b/tests/unit_tests/global_object/test_undo_redo.py @@ -118,7 +118,6 @@ def test_DescriptorStrUndoRedo(): ("error", 5), ("unit", "km/s"), ("display_name", "boom"), - ("enabled", False), ("fixed", False), ("max", 505), ("min", -1), @@ -134,27 +133,23 @@ def test_ParameterUndoRedo(test): e = doUndoRedo(obj, attr, value) assert not e -@pytest.mark.parametrize("value", (True, False)) -def test_Parameter_Bounds_UndoRedo(value): +def test_Parameter_Bounds_UndoRedo(): from easyscience import global_object global_object.stack.enabled = True - p = Parameter("test", 1, enabled=value) - assert p.min == -np.inf - assert p.max == np.inf - assert p.bounds == (-np.inf, np.inf) + parameter = Parameter("test", 1) + assert parameter.min == -np.inf + assert parameter.max == np.inf - p.bounds = (0, 2) - assert p.min == 0 - assert p.max == 2 - assert p.bounds == (0, 2) - assert p.enabled is True + parameter.min = 0 + parameter.max = 2 + assert parameter.min == 0 + assert parameter.max == 2 global_object.stack.undo() - assert p.min == -np.inf - assert p.max == np.inf - assert p.bounds == (-np.inf, np.inf) - assert p.enabled is value + global_object.stack.undo() + assert parameter.min == -np.inf + assert parameter.max == np.inf def test_BaseObjUndoRedo(): diff --git a/tests/unit_tests/utils/io_tests/test_core.py b/tests/unit_tests/utils/io_tests/test_core.py index 3e87d539..2083ac3c 100644 --- a/tests/unit_tests/utils/io_tests/test_core.py +++ b/tests/unit_tests/utils/io_tests/test_core.py @@ -8,7 +8,6 @@ import pytest import easyscience -from easyscience.Objects.ObjectClasses import BaseObj from easyscience.Objects.variable import DescriptorNumber from easyscience.Objects.variable import Parameter @@ -45,7 +44,6 @@ "url": "https://www.boo.com", "description": "", "display_name": "test", - "enabled": True, }, Parameter, ], @@ -123,62 +121,3 @@ def test_variable_as_data_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorN assert len(dif) == 0 check_dict(data_dict, enc_d) - - -class A(BaseObj): - def __init__(self, name: str = "A", **kwargs): - super().__init__(name=name, **kwargs) - - -class B(BaseObj): - def __init__(self, a, b, unique_name): - super(B, self).__init__("B", a=a, unique_name=unique_name) - self.b = b - - -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_as_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: dp_kwargs, - } - - obj = A(**a_kw) - - enc = obj.as_dict() - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_as_data_dict_methods(dp_kwargs: dict, dp_cls: Type[DescriptorNumber]): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - obj = A(**a_kw) - - enc = obj.as_data_dict() - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) diff --git a/tests/unit_tests/utils/io_tests/test_dict.py b/tests/unit_tests/utils/io_tests/test_dict.py index 884f86b6..a9b8ccd4 100644 --- a/tests/unit_tests/utils/io_tests/test_dict.py +++ b/tests/unit_tests/utils/io_tests/test_dict.py @@ -4,17 +4,13 @@ from copy import deepcopy from typing import Type -import numpy as np import pytest -from importlib import metadata from easyscience.Utils.io.dict import DataDictSerializer from easyscience.Utils.io.dict import DictSerializer from easyscience.Objects.variable import DescriptorNumber from easyscience.Objects.ObjectClasses import BaseObj -from .test_core import A -from .test_core import B from .test_core import check_dict from .test_core import dp_param_dict from .test_core import skip_dict @@ -120,144 +116,6 @@ def test_variable_encode_data(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], s check_dict(data_dict, enc_d) -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=DictSerializer) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DataDictSerializer( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=DataDictSerializer) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -@pytest.mark.parametrize( - "encoder", [None, DataDictSerializer], ids=["Default", "DataDictSerializer"] -) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_encode_data(dp_kwargs: dict, dp_cls: Type[DescriptorNumber], encoder): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - obj = A(**a_kw) - - enc = obj.encode_data(encoder=encoder) - expected_keys = set(full_d.keys()) - obtained_keys = set(enc.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, enc) - - -def test_custom_class_full_encode_with_numpy(): - class B(BaseObj): - def __init__(self, a, b, unique_name): - super(B, self).__init__("B", a=a, unique_name=unique_name) - self.b = b - # Same as in __init__.py for easyscience - try: - version = metadata.version('easyscience') # 'easyscience' is the name of the package in 'setup.py - except metadata.PackageNotFoundError: - version = '0.0.0' - - obj = B(DescriptorNumber("a", 1.0, unique_name="a"), np.array([1.0, 2.0, 3.0]), unique_name="B_0") - full_enc = obj.encode(encoder=DictSerializer, full_encode=True) - expected = { - "@module": "tests.unit_tests.utils.io_tests.test_dict", - "@class": "B", - "@version": None, - "unique_name": "B_0", - "b": { - "@module": "numpy", - "@class": "array", - "dtype": "float64", - "data": [1.0, 2.0, 3.0], - }, - "a": { - "@module": "easyscience.Objects.variable.descriptor_number", - "@class": "DescriptorNumber", - "@version": version, - "description": "", - "unit": "dimensionless", - "display_name": "a", - "name": "a", - "value": 1.0, - "variance": None, - "unique_name": "a", - "url": "", - }, - } - check_dict(full_enc, expected) - - -def test_custom_class_full_decode_with_numpy(): - global_object.map._clear() - obj = B(DescriptorNumber("a", 1.0), np.array([1.0, 2.0, 3.0]), unique_name="B_0") - full_enc = obj.encode(encoder=DictSerializer, full_encode=True) - global_object.map._clear() - obj2 = B.decode(full_enc, decoder=DictSerializer) - assert obj.name == obj2.name - assert obj.unique_name == obj2.unique_name - assert obj.a.value == obj2.a.value - assert np.all(obj.b == obj2.b) - - ######################################################################################################################## # TESTING DECODING ######################################################################################################################## @@ -325,95 +183,4 @@ def test_group_encode2(): b = BaseObj("outer", b=BaseCollection("test", d0, d1)) d = b.as_dict() - assert isinstance(d["b"], dict) - - -#TODO: do we need/want this test? -# -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_DictSerializer_decode(dp_kwargs: dict, dp_cls: Type[Descriptor]): -# -# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# obj = A(**a_kw) -# -# enc = obj.encode(encoder=DictSerializer) -# -# stripped_encode = {k: v for k, v in enc.items() if k[0] != '@'} -# stripped_encode[data_dict['name']] = data_dict -# -# dec = obj.decode(enc, decoder=DictSerializer) -# -# def test_objs(reference_obj, test_obj, in_dict): -# if 'value' in in_dict.keys(): -# in_dict['value'] = in_dict.pop('value') -# if 'units' in in_dict.keys(): -# del in_dict['units'] -# for k in in_dict.keys(): -# if hasattr(reference_obj, k) and hasattr(test_obj, k): -# if isinstance(in_dict[k], dict): -# test_objs(getattr(obj, k), getattr(test_obj, k), in_dict[k]) -# assert getattr(obj, k) == getattr(dec, k) -# else: -# raise AttributeError(f"{k} not found in decoded object") -# test_objs(obj, dec, stripped_encode) -# -# -# @pytest.mark.parametrize(**skip_dict) -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[Descriptor], skip): -# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# full_d = { -# "name": "A", -# dp_kwargs['name']: data_dict -# } -# -# full_d = recursive_remove(full_d, skip) -# -# obj = A(**a_kw) -# -# enc = obj.encode(skip=skip, encoder=DataDictSerializer) -# expected_keys = set(full_d.keys()) -# obtained_keys = set(enc.keys()) -# -# dif = expected_keys.difference(obtained_keys) -# -# assert len(dif) == 0 -# -# check_dict(full_d, enc) -# -# -# @pytest.mark.parametrize('encoder', [None, DataDictSerializer], ids=['Default', 'DataDictSerializer']) -# @pytest.mark.parametrize(**dp_param_dict) -# def test_custom_class_encode_data(dp_kwargs: dict, dp_cls: Type[Descriptor], encoder): -# data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != '@'} -# -# a_kw = { -# data_dict['name']: dp_cls(**data_dict) -# } -# -# full_d = { -# "name": "A", -# dp_kwargs['name']: data_dict -# } -# -# obj = A(**a_kw) -# -# enc = obj.encode_data(encoder=encoder) -# expected_keys = set(full_d.keys()) -# obtained_keys = set(enc.keys()) -# -# dif = expected_keys.difference(obtained_keys) -# -# assert len(dif) == 0 -# -# check_dict(full_d, enc) + assert isinstance(d["b"], dict) \ No newline at end of file diff --git a/tests/unit_tests/utils/io_tests/test_json.py b/tests/unit_tests/utils/io_tests/test_json.py index cec6e4c0..54f9ccb9 100644 --- a/tests/unit_tests/utils/io_tests/test_json.py +++ b/tests/unit_tests/utils/io_tests/test_json.py @@ -11,7 +11,6 @@ from easyscience.Utils.io.json import JsonSerializer from easyscience.Objects.variable import DescriptorNumber -from .test_core import A from .test_core import check_dict from .test_core import dp_param_dict from .test_core import skip_dict @@ -93,78 +92,6 @@ def test_variable_DataDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNum check_dict(data_dict, enc_d) - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=JsonSerializer) - assert isinstance(enc, str) - - # We can test like this as we don't have "complex" objects yet - dec = json.loads(enc) - - expected_keys = set(full_d.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, dec) - - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_DataDictSerializer( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = {"name": "A", dp_kwargs["name"]: data_dict} - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=JsonDataSerializer) - dec = json.loads(enc) - - expected_keys = set(full_d.keys()) - obtained_keys = set(dec.keys()) - - dif = expected_keys.difference(obtained_keys) - - assert len(dif) == 0 - - check_dict(full_d, dec) - - # ######################################################################################################################## # # TESTING DECODING # ######################################################################################################################## diff --git a/tests/unit_tests/utils/io_tests/test_xml.py b/tests/unit_tests/utils/io_tests/test_xml.py index 2edb761e..b382bf89 100644 --- a/tests/unit_tests/utils/io_tests/test_xml.py +++ b/tests/unit_tests/utils/io_tests/test_xml.py @@ -11,7 +11,6 @@ from easyscience.Utils.io.xml import XMLSerializer from easyscience.Objects.variable import DescriptorNumber -from .test_core import A from .test_core import dp_param_dict from .test_core import skip_dict from easyscience import global_object @@ -65,39 +64,6 @@ def test_variable_XMLDictSerializer(dp_kwargs: dict, dp_cls: Type[DescriptorNumb assert data_xml.tag == "data" recursive_test(data_xml, ref_encode) - -@pytest.mark.parametrize(**skip_dict) -@pytest.mark.parametrize(**dp_param_dict) -def test_custom_class_XMLDictSerializer_encode( - dp_kwargs: dict, dp_cls: Type[DescriptorNumber], skip -): - data_dict = {k: v for k, v in dp_kwargs.items() if k[0] != "@"} - - a_kw = {data_dict["name"]: dp_cls(**data_dict)} - - full_d = { - "@module": A.__module__, - "@class": A.__name__, - "@version": None, - "name": "A", - dp_kwargs["name"]: deepcopy(dp_kwargs), - } - - if not isinstance(skip, list): - skip = [skip] - - full_d = recursive_remove(full_d, skip) - - obj = A(**a_kw) - - enc = obj.encode(skip=skip, encoder=XMLSerializer) - ref_encode = obj.encode(skip=skip) - assert isinstance(enc, str) - data_xml = ET.XML(enc) - assert data_xml.tag == "data" - recursive_test(data_xml, ref_encode) - - # ######################################################################################################################## # # TESTING DECODING # ########################################################################################################################