diff --git a/manim/mobject/mobject.py b/manim/mobject/mobject.py index a914da57e0..53adade1c5 100644 --- a/manim/mobject/mobject.py +++ b/manim/mobject/mobject.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """Base classes for objects that can be displayed.""" @@ -16,19 +18,19 @@ from pathlib import Path from typing import ( TYPE_CHECKING, + Any, Callable, - Dict, - Iterable, - List, - Optional, + Mapping, + Protocol, Sequence, - Tuple, - Type, TypeVar, Union, + cast, ) import numpy as np + +# TODO(types): colour has no type annotations. from colour import Color from .. import config @@ -51,11 +53,39 @@ # TODO: Explain array_attrs -Updater = Union[Callable[["Mobject"], None], Callable[["Mobject", float], None]] -T = TypeVar("T", bound="Mobject") - if TYPE_CHECKING: + from PIL.Image import Image + from ..animation.animation import Animation + from ..camera.camera import Camera + from ..utils.bezier import Interpolable + + # Copy from typeshed. + class SupportsLessThan(Protocol): + def __lt__(self, __other: Any) -> bool: + ... + + +# Generic type +T = TypeVar("T") + +Updater = Union[Callable[["Mobject"], None], Callable[["Mobject", float], None]] +# Any subclass of Mobject. Most Mobject methods `return self` so this is used +# so that e.g. Circle.scale(2) returns a Circle, not an Mobject. +MOS = TypeVar("MOS", bound="Mobject") +# A method on Mobject that takes any parameters and returns an Animation. +# Ideally this would be Callable[[Mobject, ...], Animation] but that is not +# actually valid syntax yet. +AnimationMethod = Callable[..., Animation] + +# A colour. You might think this should be Colors, but actually that enum is +# really only used as a container for strings. +ColorStr = str + +# Function that modifies a point. +# TODO(types): This can't be quite correct because R3_func returns a list, +# not an np.ndarray. +PointFunction = Callable[[np.ndarray], np.ndarray] class Mobject: @@ -78,20 +108,44 @@ class Mobject: """ + color: Color | ColorStr | None + name: str + dim: int + target: Mobject | None + z_index: float + point_hash: int | None + submobjects: list[Mobject] + updaters: list[Updater] + updating_suspended: bool + points: np.ndarray + + # Note, this attribute is only present if the object has been cloned. + original_id: str + animation_overrides = {} + # This attribute is only present if save_state() has been called. + saved_state: Mobject | None + @classmethod - def __init_subclass__(cls, **kwargs): + def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) - cls.animation_overrides: Dict[ - Type["Animation"], - Callable[["Mobject"], "Animation"], + cls.animation_overrides: dict[ + type[Animation], + AnimationMethod, ] = {} cls._add_intrinsic_animation_overrides() cls._original__init__ = cls.__init__ - def __init__(self, color=WHITE, name=None, dim=3, target=None, z_index=0): + def __init__( + self, + color: ColorStr | None = WHITE, + name: str | None = None, + dim: int = 3, + target: Any | None = None, + z_index: int = 0, + ): self.color = Color(color) if color else None self.name = self.__class__.__name__ if name is None else name self.dim = dim @@ -108,8 +162,8 @@ def __init__(self, color=WHITE, name=None, dim=3, target=None, z_index=0): @classmethod def animation_override_for( cls, - animation_class: Type["Animation"], - ) -> "Optional[Callable[[Mobject, ...], Animation]]": + animation_class: type[Animation], + ) -> AnimationMethod | None: """Returns the function defining a specific animation override for this class. Parameters @@ -129,7 +183,7 @@ def animation_override_for( return None @classmethod - def _add_intrinsic_animation_overrides(cls): + def _add_intrinsic_animation_overrides(cls) -> None: """Initializes animation overrides marked with the :func:`~.override_animation` decorator. """ @@ -146,9 +200,9 @@ def _add_intrinsic_animation_overrides(cls): @classmethod def add_animation_override( cls, - animation_class: Type["Animation"], - override_func: "Callable[[Mobject, ...], Animation]", - ): + animation_class: type[Animation], + override_func: AnimationMethod, + ) -> None: """Add an animation override. This does not apply to subclasses. @@ -221,7 +275,7 @@ def construct(self): cls.__init__ = cls._original__init__ @property - def animate(self): + def animate(self) -> _AnimationBuilder: """Used to animate the application of a method. .. warning:: @@ -308,7 +362,7 @@ def construct(self): """ return _AnimationBuilder(self) - def __deepcopy__(self, clone_from_id): + def __deepcopy__(self: MOS, clone_from_id: dict[int, Mobject]) -> MOS: cls = self.__class__ result = cls.__new__(cls) clone_from_id[id(self)] = result @@ -317,17 +371,20 @@ def __deepcopy__(self, clone_from_id): result.original_id = str(id(self)) return result - def __repr__(self): + def __repr__(self) -> str: + # TODO(types): config has no type annotations. if config["renderer"] == "opengl": return super().__repr__() else: return str(self.name) - def reset_points(self): + def reset_points(self) -> None: """Sets :attr:`points` to be an empty array.""" + # TODO(types): Numpy does have type hints but the type is partially + # unknown. self.points = np.zeros((0, self.dim)) - def init_colors(self): + def init_colors(self) -> None: """Initializes the colors. Gets called upon creation. This is an empty method that can be implemented by @@ -335,7 +392,7 @@ def init_colors(self): """ pass - def generate_points(self): + def generate_points(self) -> None: """Initializes :attr:`points` and therefore the shape. Gets called upon creation. This is an empty method that can be implemented by @@ -343,7 +400,7 @@ def generate_points(self): """ pass - def add(self, *mobjects: "Mobject") -> "Mobject": + def add(self: MOS, *mobjects: Mobject) -> MOS: """Add mobjects as submobjects. The mobjects are added to :attr:`submobjects`. @@ -403,20 +460,24 @@ def add(self, *mobjects: "Mobject") -> "Mobject": """ for m in mobjects: - if not isinstance(m, Mobject): + # This type check is not required according to the type annotations + # but it was clearly useful in the past (not everyone uses + # types properly), so I'm leaving it. The type: ignore is to + # suppress warnings about it being unnecessary. + if not isinstance(m, Mobject): # type: ignore raise TypeError("All submobjects must be of type Mobject") if m is self: raise ValueError("Mobject cannot contain self") self.submobjects = list_update(self.submobjects, mobjects) return self - def __add__(self, mobject): + def __add__(self, mobject: Mobject) -> Mobject: raise NotImplementedError - def __iadd__(self, mobject): + def __iadd__(self, mobject: Mobject) -> Mobject: raise NotImplementedError - def add_to_back(self, *mobjects: "Mobject") -> "Mobject": + def add_to_back(self: MOS, *mobjects: Mobject) -> MOS: """Add all passed mobjects to the back of the submobjects. If :attr:`submobjects` already contains the given mobjects, they just get moved @@ -464,7 +525,7 @@ def add_to_back(self, *mobjects: "Mobject") -> "Mobject": raise ValueError("A mobject shouldn't contain itself") for mobject in mobjects: - if not isinstance(mobject, Mobject): + if not isinstance(mobject, Mobject): # type: ignore raise TypeError("All submobjects must be of type Mobject") self.remove(*mobjects) @@ -472,7 +533,7 @@ def add_to_back(self, *mobjects: "Mobject") -> "Mobject": self.submobjects = list(dict.fromkeys(mobjects)) + self.submobjects return self - def remove(self, *mobjects: "Mobject") -> "Mobject": + def remove(self: MOS, *mobjects: Mobject) -> MOS: """Remove :attr:`submobjects`. The mobjects are removed from :attr:`submobjects`, if they exist. @@ -499,13 +560,13 @@ def remove(self, *mobjects: "Mobject") -> "Mobject": self.submobjects.remove(mobject) return self - def __sub__(self, other): + def __sub__(self, other: Mobject) -> Mobject: raise NotImplementedError - def __isub__(self, other): + def __isub__(self, other: Mobject) -> Mobject: raise NotImplementedError - def set(self, **kwargs) -> "Mobject": + def set(self: MOS, **kwargs: Any) -> MOS: """Sets attributes. Mainly to be used along with :attr:`animate` to @@ -560,7 +621,7 @@ def set(self, **kwargs) -> "Mobject": return self - def __getattr__(self, attr): + def __getattr__(self, attr: str): # Add automatic compatibility layer # between properties and get_* and set_* # methods. @@ -572,7 +633,7 @@ def __getattr__(self, attr): # Remove the "get_" prefix to_get = attr[4:] - def getter(self): + def getter(self: Mobject): warnings.warn( "This method is not guaranteed to stay around. Please prefer " "getting the attribute normally.", @@ -589,7 +650,7 @@ def getter(self): # Remove the "set_" prefix to_set = attr[4:] - def setter(self, value): + def setter(self: Mobject, value: Any): warnings.warn( "This method is not guaranteed to stay around. Please prefer " "setting the attribute normally or with Mobject.set().", @@ -608,7 +669,7 @@ def setter(self, value): raise AttributeError(f"{type(self).__name__} object has no attribute '{attr}'") @property - def width(self): + def width(self) -> float: """The width of the mobject. Returns @@ -641,11 +702,11 @@ def construct(self): return self.length_over_dim(0) @width.setter - def width(self, value): + def width(self, value: float): self.scale_to_fit_width(value) @property - def height(self): + def height(self) -> float: """The height of the mobject. Returns @@ -678,11 +739,11 @@ def construct(self): return self.length_over_dim(1) @height.setter - def height(self, value): + def height(self, value: float): self.scale_to_fit_height(value) @property - def depth(self): + def depth(self) -> float: """The depth of the mobject. Returns @@ -699,38 +760,39 @@ def depth(self): return self.length_over_dim(2) @depth.setter - def depth(self, value): + def depth(self, value: float): self.scale_to_fit_depth(value) - def get_array_attrs(self): + def get_array_attrs(self) -> list[str]: return ["points"] - def apply_over_attr_arrays(self, func): + def apply_over_attr_arrays(self: MOS, func: Callable[[Any], Any]) -> MOS: for attr in self.get_array_attrs(): setattr(self, attr, func(getattr(self, attr))) return self # Displaying - def get_image(self, camera=None): + def get_image(self, camera: Camera | None = None) -> Image: if camera is None: from ..camera.camera import Camera + # TODO(types): Camera has no type hints. camera = Camera() camera.capture_mobject(self) return camera.get_image() - def show(self, camera=None): + def show(self, camera: Camera | None = None) -> None: self.get_image(camera=camera).show() - def save_image(self, name=None): + def save_image(self, name: str | None = None) -> None: """Saves an image of only this :class:`Mobject` at its position to a png file.""" self.get_image().save( Path(config.get_dir("video_dir")).joinpath((name or str(self)) + ".png"), ) - def copy(self: T) -> T: + def copy(self: MOS) -> MOS: """Create and return an identical copy of the :class:`Mobject` including all :attr:`submobjects`. @@ -745,7 +807,7 @@ def copy(self: T) -> T: """ return copy.deepcopy(self) - def generate_target(self, use_deepcopy=False): + def generate_target(self: MOS, use_deepcopy: bool = False) -> MOS: self.target = None # Prevent unbounded linear recursion if use_deepcopy: self.target = copy.deepcopy(self) @@ -755,7 +817,7 @@ def generate_target(self, use_deepcopy=False): # Updating - def update(self, dt: float = 0, recursive: bool = True) -> "Mobject": + def update(self: MOS, dt: float = 0, recursive: bool = True) -> MOS: """Apply all updaters. Does nothing if updating is suspended. @@ -783,6 +845,8 @@ def update(self, dt: float = 0, recursive: bool = True) -> "Mobject": return self for updater in self.updaters: parameters = get_parameters(updater) + # TODO(types): updater may take 1 or 2 parameters but we don't check + # that here. if "dt" in parameters: updater(self, dt) else: @@ -792,7 +856,7 @@ def update(self, dt: float = 0, recursive: bool = True) -> "Mobject": submob.update(dt, recursive) return self - def get_time_based_updaters(self) -> List[Updater]: + def get_time_based_updaters(self) -> list[Updater]: """Return all updaters using the ``dt`` parameter. The updaters use this parameter as the input for difference in time. @@ -824,9 +888,12 @@ def has_time_based_updater(self) -> bool: :meth:`get_time_based_updaters` """ - return any("dt" in get_parameters(updater) for updater in self.updaters) + for updater in self.updaters: + if "dt" in get_parameters(updater): + return True + return False - def get_updaters(self) -> List[Updater]: + def get_updaters(self) -> list[Updater]: """Return all updaters. Returns @@ -842,15 +909,15 @@ def get_updaters(self) -> List[Updater]: """ return self.updaters - def get_family_updaters(self): + def get_family_updaters(self) -> list[Updater]: return list(it.chain(*(sm.get_updaters() for sm in self.get_family()))) def add_updater( - self, + self: MOS, update_function: Updater, - index: Optional[int] = None, + index: int | None = None, call_updater: bool = False, - ) -> "Mobject": + ) -> MOS: """Add an update function to this mobject. Update functions, or updaters in short, are functions that are applied to the @@ -917,10 +984,11 @@ def construct(self): else: self.updaters.insert(index, update_function) if call_updater: + # TODO(types): update_function may take only 1 parameter. update_function(self, 0) return self - def remove_updater(self, update_function: Updater) -> "Mobject": + def remove_updater(self: MOS, update_function: Updater) -> MOS: """Remove an updater. If the same updater is applied multiple times, every instance gets removed. @@ -947,7 +1015,7 @@ def remove_updater(self, update_function: Updater) -> "Mobject": self.updaters.remove(update_function) return self - def clear_updaters(self, recursive: bool = True) -> "Mobject": + def clear_updaters(self: MOS, recursive: bool = True) -> MOS: """Remove every updater. Parameters @@ -973,7 +1041,7 @@ def clear_updaters(self, recursive: bool = True) -> "Mobject": submob.clear_updaters() return self - def match_updaters(self, mobject: "Mobject") -> "Mobject": + def match_updaters(self: MOS, mobject: Mobject) -> MOS: """Match the updaters of the given mobject. Parameters @@ -1003,7 +1071,7 @@ def match_updaters(self, mobject: "Mobject") -> "Mobject": self.add_updater(updater) return self - def suspend_updating(self, recursive: bool = True) -> "Mobject": + def suspend_updating(self: MOS, recursive: bool = True) -> MOS: """Disable updating from updaters and animations. @@ -1030,7 +1098,7 @@ def suspend_updating(self, recursive: bool = True) -> "Mobject": submob.suspend_updating(recursive) return self - def resume_updating(self, recursive: bool = True) -> "Mobject": + def resume_updating(self: MOS, recursive: bool = True) -> MOS: """Enable updating from updaters and animations. Parameters @@ -1058,7 +1126,7 @@ def resume_updating(self, recursive: bool = True) -> "Mobject": # Transforming operations - def apply_to_family(self, func: Callable[["Mobject"], None]) -> "Mobject": + def apply_to_family(self, func: Callable[[Mobject], None]) -> None: """Apply a function to ``self`` and every submobject with points recursively. Parameters @@ -1080,7 +1148,7 @@ def apply_to_family(self, func: Callable[["Mobject"], None]) -> "Mobject": for mob in self.family_members_with_points(): func(mob) - def shift(self, *vectors: np.ndarray) -> "Mobject": + def shift(self: MOS, *vectors: np.ndarray) -> MOS: """Shift by the given vectors. Parameters @@ -1106,7 +1174,7 @@ def shift(self, *vectors: np.ndarray) -> "Mobject": return self - def scale(self, scale_factor: float, **kwargs) -> "Mobject": + def scale(self: MOS, scale_factor: float, **kwargs: np.ndarray | None) -> MOS: r"""Scale the size by a factor. Default behavior is to scale about the center of the mobject. @@ -1152,17 +1220,21 @@ def construct(self): ) return self - def rotate_about_origin(self, angle, axis=OUT, axes=[]): + def rotate_about_origin( + self: MOS, + angle: float, + axis: np.ndarray = OUT, + ) -> MOS: """Rotates the :class:`~.Mobject` about the ORIGIN, which is at [0,0,0].""" return self.rotate(angle, axis, about_point=ORIGIN) def rotate( - self, - angle, - axis=OUT, - about_point: Optional[Sequence[float]] = None, - **kwargs, - ): + self: MOS, + angle: float, + axis: np.ndarray = OUT, + about_point: np.ndarray | None = None, + **kwargs: np.ndarray | None, + ) -> MOS: """Rotates the :class:`~.Mobject` about a certain point.""" rot_matrix = rotation_matrix(angle, axis) self.apply_points_function_about_point( @@ -1170,7 +1242,7 @@ def rotate( ) return self - def flip(self, axis=UP, **kwargs): + def flip(self: MOS, axis: np.ndarray = UP, **kwargs: np.ndarray | None) -> MOS: """Flips/Mirrors an mobject about its center. Examples @@ -1189,15 +1261,17 @@ def construct(self): """ return self.rotate(TAU / 2, axis, **kwargs) - def stretch(self, factor, dim, **kwargs): - def func(points): + def stretch(self: MOS, factor: float, dim: int, **kwargs: np.ndarray | None) -> MOS: + def func(points: np.ndarray) -> np.ndarray: points[:, dim] *= factor return points self.apply_points_function_about_point(func, **kwargs) return self - def apply_function(self, function, **kwargs): + def apply_function( + self: MOS, function: PointFunction, **kwargs: np.ndarray | None + ) -> MOS: # Default to applying matrix about the origin, not mobjects center if len(kwargs) == 0: kwargs["about_point"] = ORIGIN @@ -1206,16 +1280,19 @@ def apply_function(self, function, **kwargs): ) return self - def apply_function_to_position(self, function): + def apply_function_to_position(self: MOS, function: PointFunction) -> MOS: self.move_to(function(self.get_center())) return self - def apply_function_to_submobject_positions(self, function): + def apply_function_to_submobject_positions( + self: MOS, + function: PointFunction, + ) -> MOS: for submob in self.submobjects: submob.apply_function_to_position(function) return self - def apply_matrix(self, matrix, **kwargs): + def apply_matrix(self: MOS, matrix: np.ndarray, **kwargs: np.ndarray | None) -> MOS: # Default to applying matrix about the origin, not mobjects center if ("about_point" not in kwargs) and ("about_edge" not in kwargs): kwargs["about_point"] = ORIGIN @@ -1227,7 +1304,9 @@ def apply_matrix(self, matrix, **kwargs): ) return self - def apply_complex_function(self, function, **kwargs): + def apply_complex_function( + self: MOS, function: Callable[[complex], complex], **kwargs: np.ndarray | None + ) -> MOS: """Applies a complex function to a :class:`Mobject`. The x and y coordinates correspond to the real and imaginary parts respectively. @@ -1254,14 +1333,19 @@ def construct(self): self.play(t.animate.set_value(TAU), run_time=3) """ - def R3_func(point): + def R3_func(point: np.ndarray) -> list[float]: x, y, z = point xy_complex = function(complex(x, y)) return [xy_complex.real, xy_complex.imag, z] return self.apply_function(R3_func) - def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): + def wag( + self: MOS, + direction: np.ndarray = RIGHT, + axis: np.ndarray = DOWN, + wag_factor: float = 1.0, + ) -> MOS: for mob in self.family_members_with_points(): alphas = np.dot(mob.points, np.transpose(axis)) alphas -= min(alphas) @@ -1273,15 +1357,15 @@ def wag(self, direction=RIGHT, axis=DOWN, wag_factor=1.0): ) return self - def reverse_points(self): + def reverse_points(self: MOS) -> MOS: for mob in self.family_members_with_points(): mob.apply_over_attr_arrays(lambda arr: np.array(list(reversed(arr)))) return self - def repeat(self, count: int): + def repeat(self: MOS, count: int) -> MOS: """This can make transition animations nicer""" - def repeat_array(array): + def repeat_array(array: np.ndarray) -> np.ndarray: return reduce(lambda a1, a2: np.append(a1, a2, axis=0), [array] * count) for mob in self.family_members_with_points(): @@ -1293,11 +1377,11 @@ def repeat_array(array): # above methods def apply_points_function_about_point( - self, - func, - about_point=None, - about_edge=None, - ): + self: MOS, + func: PointFunction, + about_point: np.ndarray | None = None, + about_edge: np.ndarray | None = None, + ) -> MOS: if about_point is None: if about_edge is None: about_edge = ORIGIN @@ -1313,7 +1397,7 @@ def apply_points_function_about_point( until="v0.12.0", replacement="rotate", ) - def rotate_in_place(self, angle, axis=OUT): + def rotate_in_place(self: MOS, angle: float, axis: np.ndarray = OUT) -> MOS: # redundant with default behavior of rotate now. return self.rotate(angle, axis=axis) @@ -1322,7 +1406,9 @@ def rotate_in_place(self, angle, axis=OUT): until="v0.12.0", replacement="scale", ) - def scale_in_place(self, scale_factor, **kwargs): + def scale_in_place( + self: MOS, scale_factor: float, **kwargs: np.ndarray | None + ) -> MOS: # Redundant with default behavior of scale now. return self.scale(scale_factor, **kwargs) @@ -1331,21 +1417,25 @@ def scale_in_place(self, scale_factor, **kwargs): until="v0.12.0", replacement="scale", ) - def scale_about_point(self, scale_factor, point): + def scale_about_point(self: MOS, scale_factor: float, point: np.ndarray) -> MOS: # Redundant with default behavior of scale now. return self.scale(scale_factor, about_point=point) - def pose_at_angle(self, **kwargs): + def pose_at_angle(self: MOS, **kwargs: np.ndarray | None) -> MOS: self.rotate(TAU / 14, RIGHT + UP, **kwargs) return self # Positioning methods - def center(self): + def center(self: MOS) -> MOS: self.shift(-self.get_center()) return self - def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def align_on_border( + self: MOS, + direction: np.ndarray, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> MOS: """Direction just needs to be a vector pointing towards side or corner in the 2d plane. """ @@ -1360,22 +1450,30 @@ def align_on_border(self, direction, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): self.shift(shift_val) return self - def to_corner(self, corner=LEFT + DOWN, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_corner( + self: MOS, + corner: np.ndarray = LEFT + DOWN, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> MOS: return self.align_on_border(corner, buff) - def to_edge(self, edge=LEFT, buff=DEFAULT_MOBJECT_TO_EDGE_BUFFER): + def to_edge( + self: MOS, + edge: np.ndarray = LEFT, + buff: float = DEFAULT_MOBJECT_TO_EDGE_BUFFER, + ) -> MOS: return self.align_on_border(edge, buff) def next_to( - self, - mobject_or_point, - direction=RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - aligned_edge=ORIGIN, - submobject_to_align=None, - index_of_submobject_to_align=None, - coor_mask=np.array([1, 1, 1]), - ): + self: MOS, + mobject_or_point: Mobject | np.ndarray, + direction: np.ndarray = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + aligned_edge: np.ndarray = ORIGIN, + submobject_to_align: Mobject | None = None, + index_of_submobject_to_align: int | None = None, + coor_mask: np.ndarray = np.array([1, 1, 1]), + ) -> MOS: """Move this :class:`~.Mobject` next to another's :class:`~.Mobject` or coordinate. Examples @@ -1415,10 +1513,14 @@ def construct(self): self.shift((target_point - point_to_align + buff * direction) * coor_mask) return self - def shift_onto_screen(self, **kwargs): - space_lengths = [config["frame_x_radius"], config["frame_y_radius"]] + def shift_onto_screen(self: MOS, **kwargs) -> MOS: + space_lengths: list[float] = [ + config["frame_x_radius"], + config["frame_y_radius"], + ] for vect in UP, DOWN, LEFT, RIGHT: dim = np.argmax(np.abs(vect)) + # TODO(types): Why not just use a keyword parameter? buff = kwargs.get("buff", DEFAULT_MOBJECT_TO_EDGE_BUFFER) max_val = space_lengths[dim] - buff edge_center = self.get_edge_center(vect) @@ -1426,7 +1528,7 @@ def shift_onto_screen(self, **kwargs): self.to_edge(vect, **kwargs) return self - def is_off_screen(self): + def is_off_screen(self) -> bool: if self.get_left()[0] > config["frame_x_radius"]: return True if self.get_right()[0] < -config["frame_x_radius"]: @@ -1437,7 +1539,12 @@ def is_off_screen(self): return True return False - def stretch_about_point(self, factor, dim, point): + def stretch_about_point( + self: MOS, + factor: float, + dim: int, + point: np.ndarray, + ) -> MOS: return self.stretch(factor, dim, about_point=point) @deprecated( @@ -1445,11 +1552,17 @@ def stretch_about_point(self, factor, dim, point): until="v0.12.0", replacement="stretch", ) - def stretch_in_place(self, factor, dim): + def stretch_in_place(self: MOS, factor: float, dim: int) -> MOS: # Now redundant with stretch return self.stretch(factor, dim) - def rescale_to_fit(self, length, dim, stretch=False, **kwargs): + def rescale_to_fit( + self: MOS, + length: float, + dim: int, + stretch: bool = False, + **kwargs: np.ndarray | None, + ) -> MOS: old_length = self.length_over_dim(dim) if old_length == 0: return self @@ -1459,7 +1572,7 @@ def rescale_to_fit(self, length, dim, stretch=False, **kwargs): self.scale(length / old_length, **kwargs) return self - def scale_to_fit_width(self, width, **kwargs): + def scale_to_fit_width(self: MOS, width: float, **kwargs: np.ndarray | None) -> MOS: """Scales the :class:`~.Mobject` to fit a width while keeping height/depth proportional. Returns @@ -1485,7 +1598,9 @@ def scale_to_fit_width(self, width, **kwargs): return self.rescale_to_fit(width, 0, stretch=False, **kwargs) - def stretch_to_fit_width(self, width, **kwargs): + def stretch_to_fit_width( + self: MOS, width: float, **kwargs: np.ndarray | None + ) -> MOS: """Stretches the :class:`~.Mobject` to fit a width, not keeping height/depth proportional. Returns @@ -1511,7 +1626,9 @@ def stretch_to_fit_width(self, width, **kwargs): return self.rescale_to_fit(width, 0, stretch=True, **kwargs) - def scale_to_fit_height(self, height, **kwargs): + def scale_to_fit_height( + self: MOS, height: float, **kwargs: np.ndarray | None + ) -> MOS: """Scales the :class:`~.Mobject` to fit a height while keeping width/depth proportional. Returns @@ -1537,7 +1654,9 @@ def scale_to_fit_height(self, height, **kwargs): return self.rescale_to_fit(height, 1, stretch=False, **kwargs) - def stretch_to_fit_height(self, height, **kwargs): + def stretch_to_fit_height( + self: MOS, height: float, **kwargs: np.ndarray | None + ) -> MOS: """Stretches the :class:`~.Mobject` to fit a height, not keeping width/depth proportional. Returns @@ -1563,47 +1682,56 @@ def stretch_to_fit_height(self, height, **kwargs): return self.rescale_to_fit(height, 1, stretch=True, **kwargs) - def scale_to_fit_depth(self, depth, **kwargs): + def scale_to_fit_depth(self: MOS, depth: float, **kwargs: np.ndarray | None) -> MOS: """Scales the :class:`~.Mobject` to fit a depth while keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=False, **kwargs) - def stretch_to_fit_depth(self, depth, **kwargs): + def stretch_to_fit_depth( + self: MOS, depth: float, **kwargs: np.ndarray | None + ) -> MOS: """Stretches the :class:`~.Mobject` to fit a depth, not keeping width/height proportional.""" return self.rescale_to_fit(depth, 2, stretch=True, **kwargs) - def set_coord(self, value, dim, direction=ORIGIN): + def set_coord( + self: MOS, + value: float, + dim: int, + direction: np.ndarray = ORIGIN, + ) -> MOS: curr = self.get_coord(dim, direction) shift_vect = np.zeros(self.dim) shift_vect[dim] = value - curr self.shift(shift_vect) return self - def set_x(self, x, direction=ORIGIN): + def set_x(self: MOS, x: float, direction: np.ndarray = ORIGIN) -> MOS: """Set x value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(x, 0, direction) - def set_y(self, y, direction=ORIGIN): + def set_y(self: MOS, y: float, direction: np.ndarray = ORIGIN) -> MOS: """Set y value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(y, 1, direction) - def set_z(self, z, direction=ORIGIN): + def set_z(self: MOS, z: float, direction: np.ndarray = ORIGIN) -> MOS: """Set z value of the center of the :class:`~.Mobject` (``int`` or ``float``)""" return self.set_coord(z, 2, direction) - def space_out_submobjects(self, factor=1.5, **kwargs): + def space_out_submobjects( + self: MOS, factor: float = 1.5, **kwargs: np.ndarray | None + ) -> MOS: self.scale(factor, **kwargs) for submob in self.submobjects: submob.scale(1.0 / factor) return self def move_to( - self, - point_or_mobject, - aligned_edge=ORIGIN, - coor_mask=np.array([1, 1, 1]), - ): + self: MOS, + point_or_mobject: Mobject | np.ndarray, + aligned_edge: np.ndarray = ORIGIN, + coor_mask: np.ndarray = np.array([1, 1, 1]), + ) -> MOS: """Move center of the :class:`~.Mobject` to certain coordinate.""" if isinstance(point_or_mobject, Mobject): target = point_or_mobject.get_critical_point(aligned_edge) @@ -1613,7 +1741,12 @@ def move_to( self.shift((target - point_to_align) * coor_mask) return self - def replace(self, mobject, dim_to_match=0, stretch=False): + def replace( + self: MOS, + mobject: Mobject, + dim_to_match: int = 0, + stretch: bool = False, + ) -> MOS: if not mobject.get_num_points() and not mobject.submobjects: raise Warning("Attempting to replace mobject with no points") if stretch: @@ -1629,18 +1762,22 @@ def replace(self, mobject, dim_to_match=0, stretch=False): return self def surround( - self, - mobject: "Mobject", - dim_to_match=0, - stretch=False, - buff=MED_SMALL_BUFF, - ): + self: MOS, + mobject: Mobject, + dim_to_match: int = 0, + stretch: bool = False, + buff: float = MED_SMALL_BUFF, + ) -> MOS: self.replace(mobject, dim_to_match, stretch) length = mobject.length_over_dim(dim_to_match) - self.scale((length + buff) / length) + self.scale_in_place((length + buff) / length) return self - def put_start_and_end_on(self, start, end): + def put_start_and_end_on( + self: MOS, + start: Sequence[float], + end: Sequence[float], + ) -> MOS: curr_start, curr_end = self.get_start_and_end() curr_vect = curr_end - curr_start if np.all(curr_vect == 0): @@ -1664,9 +1801,10 @@ def put_start_and_end_on(self, start, end): return self # Background rectangle + # TODO(types): Don't use kwargs, so it can be typed. def add_background_rectangle( - self, color: Colors = BLACK, opacity: float = 0.75, **kwargs - ): + self: MOS, color: ColorStr = BLACK, opacity: float = 0.75, **kwargs + ) -> MOS: """Add a BackgroundRectangle as submobject. The BackgroundRectangle is added behind other submobjects. @@ -1705,19 +1843,25 @@ def add_background_rectangle( self.add_to_back(self.background_rectangle) return self - def add_background_rectangle_to_submobjects(self, **kwargs): + def add_background_rectangle_to_submobjects(self: MOS, **kwargs) -> MOS: for submobject in self.submobjects: submobject.add_background_rectangle(**kwargs) return self - def add_background_rectangle_to_family_members_with_points(self, **kwargs): + def add_background_rectangle_to_family_members_with_points( + self: MOS, **kwargs + ) -> MOS: for mob in self.family_members_with_points(): mob.add_background_rectangle(**kwargs) return self # Color functions - def set_color(self, color: Color = YELLOW_C, family: bool = True): + def set_color( + self: MOS, + color: ColorStr = YELLOW_C, + family: bool = True, + ) -> MOS: """Condition is function which takes in one arguments, (x, y, z). Here it just recurses to submobjects, but in subclasses this should be further implemented based on the the inner workings @@ -1729,17 +1873,17 @@ def set_color(self, color: Color = YELLOW_C, family: bool = True): self.color = Color(color) return self - def set_color_by_gradient(self, *colors): + def set_color_by_gradient(self: MOS, *colors: ColorStr) -> MOS: self.set_submobject_colors_by_gradient(*colors) return self def set_colors_by_radial_gradient( - self, - center=None, - radius=1, - inner_color=WHITE, - outer_color=BLACK, - ): + self: MOS, + center: np.ndarray | None = None, + radius: float = 1, + inner_color: ColorStr = WHITE, + outer_color: ColorStr = BLACK, + ) -> MOS: self.set_submobject_colors_by_radial_gradient( center, radius, @@ -1748,11 +1892,11 @@ def set_colors_by_radial_gradient( ) return self - def set_submobject_colors_by_gradient(self, *colors): + def set_submobject_colors_by_gradient(self: MOS, *colors: ColorStr) -> MOS: if len(colors) == 0: raise ValueError("Need at least one color") elif len(colors) == 1: - return self.set_color(*colors) + return self.set_color(colors[0]) mobs = self.family_members_with_points() new_colors = color_gradient(colors, len(mobs)) @@ -1762,12 +1906,12 @@ def set_submobject_colors_by_gradient(self, *colors): return self def set_submobject_colors_by_radial_gradient( - self, - center=None, - radius=1, - inner_color=WHITE, - outer_color=BLACK, - ): + self: MOS, + center: np.ndarray | None = None, + radius: float = 1, + inner_color: ColorStr = WHITE, + outer_color: ColorStr = BLACK, + ) -> MOS: if center is None: center = self.get_center() @@ -1779,11 +1923,11 @@ def set_submobject_colors_by_radial_gradient( return self - def to_original_color(self): + def to_original_color(self: MOS) -> MOS: self.set_color(self.color) return self - def fade_to(self, color, alpha, family=True): + def fade_to(self: MOS, color: ColorStr, alpha: float, family: bool = True) -> MOS: if self.get_num_points() > 0: new_color = interpolate_color(self.get_color(), color, alpha) self.set_color(new_color, family=False) @@ -1792,19 +1936,19 @@ def fade_to(self, color, alpha, family=True): submob.fade_to(color, alpha) return self - def fade(self, darkness=0.5, family=True): + def fade(self: MOS, darkness: float = 0.5, family: bool = True) -> MOS: if family: for submob in self.submobjects: submob.fade(darkness, family) return self - def get_color(self): + def get_color(self) -> ColorStr | None: """Returns the color of the :class:`~.Mobject`""" return self.color ## - def save_state(self): + def save_state(self: MOS) -> MOS: """Save the current state (position, color & size). Can be restored with :meth:`~.Mobject.restore`.""" if hasattr(self, "saved_state"): # Prevent exponential growth of data @@ -1813,7 +1957,7 @@ def save_state(self): return self - def restore(self): + def restore(self: MOS) -> MOS: """Restores the state that was previously saved with :meth:`~.Mobject.save_state`.""" if not hasattr(self, "saved_state") or self.save_state is None: raise Exception("Trying to restore without having saved") @@ -1822,7 +1966,8 @@ def restore(self): ## - def reduce_across_dimension(self, points_func, reduce_func, dim): + # TODO(types): Add type annotations. + def reduce_across_dimension(self, points_func, reduce_func, dim: int): points = self.get_all_points() if points is None or len(points) == 0: # Note, this default means things like empty VGroups @@ -1831,32 +1976,37 @@ def reduce_across_dimension(self, points_func, reduce_func, dim): values = points_func(points[:, dim]) return reduce_func(values) - def nonempty_submobjects(self): + def nonempty_submobjects(self) -> list[Mobject]: return [ submob for submob in self.submobjects if len(submob.submobjects) != 0 or len(submob.points) != 0 ] - def get_merged_array(self, array_attr): + def get_merged_array(self, array_attr: str) -> np.ndarray: result = getattr(self, array_attr) for submob in self.submobjects: result = np.append(result, submob.get_merged_array(array_attr), axis=0) submob.get_merged_array(array_attr) return result - def get_all_points(self): + def get_all_points(self) -> np.ndarray: return self.get_merged_array("points") # Getters - def get_points_defining_boundary(self): + def get_points_defining_boundary(self) -> np.ndarray: return self.get_all_points() - def get_num_points(self): + def get_num_points(self) -> int: return len(self.points) - def get_extremum_along_dim(self, points=None, dim=0, key=0): + def get_extremum_along_dim( + self, + points: np.ndarray | None = None, + dim: int = 0, + key: int = 0, + ) -> float: if points is None: points = self.get_points_defining_boundary() values = points[:, dim] @@ -1867,7 +2017,7 @@ def get_extremum_along_dim(self, points=None, dim=0, key=0): else: return np.max(values) - def get_critical_point(self, direction): + def get_critical_point(self, direction: np.ndarray) -> np.ndarray: """Picture a box bounding the :class:`~.Mobject`. Such a box has 9 'critical points': 4 corners, 4 edge center, the center. This returns one of them, along the given direction. @@ -1896,11 +2046,11 @@ def get_critical_point(self, direction): # Pseudonyms for more general get_critical_point method - def get_edge_center(self, direction) -> np.ndarray: + def get_edge_center(self, direction: np.ndarray) -> np.ndarray: """Get edge coordinates for certain direction.""" return self.get_critical_point(direction) - def get_corner(self, direction) -> np.ndarray: + def get_corner(self, direction: np.ndarray) -> np.ndarray: """Get corner coordinates for certain direction.""" return self.get_critical_point(direction) @@ -1908,10 +2058,10 @@ def get_center(self) -> np.ndarray: """Get center coordinates""" return self.get_critical_point(np.zeros(self.dim)) - def get_center_of_mass(self): + def get_center_of_mass(self) -> np.ndarray: return np.apply_along_axis(np.mean, 0, self.get_all_points()) - def get_boundary_point(self, direction): + def get_boundary_point(self, direction: np.ndarray) -> np.ndarray: all_points = self.get_points_defining_boundary() index = np.argmax(np.dot(all_points, np.array(direction).T)) return all_points[index] @@ -1963,7 +2113,7 @@ def get_nadir(self) -> np.ndarray: """Get nadir (opposite the zenith) coordinates of a box bounding a 3D :class:`~.Mobject`.""" return self.get_edge_center(IN) - def length_over_dim(self, dim): + def length_over_dim(self, dim: int) -> float: """Measure the length of an :class:`~.Mobject` in a certain direction.""" return ( self.reduce_across_dimension( @@ -1974,43 +2124,40 @@ def length_over_dim(self, dim): - self.reduce_across_dimension(np.min, np.min, dim) ) - def get_coord(self, dim, direction=ORIGIN): + def get_coord(self, dim: int, direction: np.ndarray = ORIGIN) -> float: """Meant to generalize ``get_x``, ``get_y`` and ``get_z``""" return self.get_extremum_along_dim(dim=dim, key=direction[dim]) - def get_x(self, direction=ORIGIN) -> np.float64: + def get_x(self, direction: np.ndarray = ORIGIN) -> float: """Returns x coordinate of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(0, direction) - def get_y(self, direction=ORIGIN) -> np.float64: + def get_y(self, direction: np.ndarray = ORIGIN) -> float: """Returns y coordinate of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(1, direction) - def get_z(self, direction=ORIGIN) -> np.float64: + def get_z(self, direction: np.ndarray = ORIGIN) -> float: """Returns z coordinate of the center of the :class:`~.Mobject` as ``float``""" return self.get_coord(2, direction) - def get_start(self): + def get_start(self) -> np.ndarray: """Returns the point, where the stroke that surrounds the :class:`~.Mobject` starts.""" self.throw_error_if_no_points() return np.array(self.points[0]) - def get_end(self): + def get_end(self) -> np.ndarray: """Returns the point, where the stroke that surrounds the :class:`~.Mobject` ends.""" self.throw_error_if_no_points() return np.array(self.points[-1]) - def get_start_and_end(self): + def get_start_and_end(self) -> tuple[np.ndarray, np.ndarray]: """Returns starting and ending point of a stroke as a ``tuple``.""" return self.get_start(), self.get_end() - def point_from_proportion(self, alpha): - raise NotImplementedError("Please override in a child class.") - - def proportion_from_point(self, point): + def point_from_proportion(self, alpha: float) -> np.ndarray: raise NotImplementedError("Please override in a child class.") - def get_pieces(self, n_pieces): + def get_pieces(self, n_pieces: int) -> Group: template = self.copy() template.submobjects = [] alphas = np.linspace(0, 1, n_pieces + 1) @@ -2021,7 +2168,7 @@ def get_pieces(self, n_pieces): ) ) - def get_z_index_reference_point(self): + def get_z_index_reference_point(self) -> np.ndarray: # TODO, better place to define default z_index_group? z_index_group = getattr(self, "z_index_group", self) return z_index_group.get_center() @@ -2036,27 +2183,34 @@ def has_no_points(self) -> bool: # Match other mobject properties - def match_color(self, mobject: "Mobject"): + def match_color(self: MOS, mobject: Mobject) -> MOS: """Match the color with the color of another :class:`~.Mobject`.""" return self.set_color(mobject.get_color()) - def match_dim_size(self, mobject: "Mobject", dim, **kwargs): + def match_dim_size( + self: MOS, mobject: Mobject, dim: int, **kwargs: np.ndarray | None + ) -> MOS: """Match the specified dimension with the dimension of another :class:`~.Mobject`.""" return self.rescale_to_fit(mobject.length_over_dim(dim), dim, **kwargs) - def match_width(self, mobject: "Mobject", **kwargs): + def match_width(self: MOS, mobject: Mobject, **kwargs: np.ndarray | None) -> MOS: """Match the width with the width of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 0, **kwargs) - def match_height(self, mobject: "Mobject", **kwargs): + def match_height(self: MOS, mobject: Mobject, **kwargs: np.ndarray | None) -> MOS: """Match the height with the height of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 1, **kwargs) - def match_depth(self, mobject: "Mobject", **kwargs): + def match_depth(self: MOS, mobject: Mobject, **kwargs: np.ndarray | None) -> MOS: """Match the depth with the depth of another :class:`~.Mobject`.""" return self.match_dim_size(mobject, 2, **kwargs) - def match_coord(self, mobject: "Mobject", dim, direction=ORIGIN): + def match_coord( + self: MOS, + mobject: Mobject, + dim: int, + direction: np.ndarray = ORIGIN, + ) -> MOS: """Match the coordinates with the coordinates of another :class:`~.Mobject`.""" return self.set_coord( mobject.get_coord(dim, direction), @@ -2064,24 +2218,25 @@ def match_coord(self, mobject: "Mobject", dim, direction=ORIGIN): direction=direction, ) - def match_x(self, mobject: "Mobject", direction=ORIGIN): + def match_x(self: MOS, mobject: Mobject, direction: np.ndarray = ORIGIN) -> MOS: """Match x coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 0, direction) - def match_y(self, mobject: "Mobject", direction=ORIGIN): + def match_y(self: MOS, mobject: Mobject, direction: np.ndarray = ORIGIN) -> MOS: """Match y coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 1, direction) - def match_z(self, mobject: "Mobject", direction=ORIGIN): + def match_z(self: MOS, mobject: Mobject, direction: np.ndarray = ORIGIN) -> MOS: """Match z coord. to the x coord. of another :class:`~.Mobject`.""" return self.match_coord(mobject, 2, direction) + # TODO: alignment_vect is unused. def align_to( - self, - mobject_or_point: Union["Mobject", np.ndarray, List], - direction=ORIGIN, - alignment_vect=UP, - ): + self: MOS, + mobject_or_point: Mobject | np.ndarray | Sequence[float], + direction: np.ndarray = ORIGIN, + alignment_vect: np.ndarray = UP, + ) -> MOS: """Aligns mobject to another :class:`~.Mobject` in a certain direction. Examples: @@ -2120,26 +2275,28 @@ def __len__(self): def get_group_class(self): return Group - def split(self): - result = [self] if len(self.points) > 0 else [] + def split(self) -> list[Mobject]: + # Annotation needed because https://github.com/microsoft/pyright/issues/2333 + result: list[Mobject] = [self] if len(self.points) > 0 else [] return result + self.submobjects - def get_family(self, recurse=True): + def get_family(self, recurse: bool = True) -> list[Mobject]: sub_families = list(map(Mobject.get_family, self.submobjects)) all_mobjects = [self] + list(it.chain(*sub_families)) + # TODO(types): remove_list_redundancies has no type annotations. return remove_list_redundancies(all_mobjects) def family_members_with_points(self): return [m for m in self.get_family() if m.get_num_points() > 0] def arrange( - self, - direction: Sequence[float] = RIGHT, - buff=DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, - center=True, + self: MOS, + direction: np.ndarray = RIGHT, + buff: float = DEFAULT_MOBJECT_TO_MOBJECT_BUFFER, + center: bool = True, **kwargs, - ): - """Sorts :class:`~.Mobject` next to each other on screen. + ) -> MOS: + """Sorts :class:`~.Mobject` - to each other on screen. Examples -------- @@ -2162,19 +2319,20 @@ def construct(self): self.center() return self + # TODO: kwargs is not used. def arrange_in_grid( - self, - rows: Optional[int] = None, - cols: Optional[int] = None, - buff: Union[float, Tuple[float, float]] = MED_SMALL_BUFF, + self: MOS, + rows: int | None = None, + cols: int | None = None, + buff: float | tuple[float, float] = MED_SMALL_BUFF, cell_alignment: np.ndarray = ORIGIN, - row_alignments: Optional[str] = None, # "ucd" - col_alignments: Optional[str] = None, # "lcr" - row_heights: Optional[Iterable[Optional[float]]] = None, - col_widths: Optional[Iterable[Optional[float]]] = None, + row_alignments: str | None = None, # "ucd" + col_alignments: str | None = None, # "lcr" + row_heights: Sequence[float | None] | None = None, + col_widths: Sequence[float | None] | None = None, flow_order: str = "rd", **kwargs, - ) -> "Mobject": + ) -> MOS: """Arrange submobjects in a grid. Parameters @@ -2268,7 +2426,11 @@ def construct(self): start_pos = self.get_center() # get cols / rows values if given (implicitly) - def init_size(num, alignments, sizes): + def init_size( + num: int | None, + alignments: str | None, + sizes: Any | None, + ) -> int | None: if num is not None: return num if alignments is not None: @@ -2287,7 +2449,7 @@ def init_size(num, alignments, sizes): # This is favored over rows>cols since in general # the sceene is wider than high. if rows is None: - rows = ceil(len(mobs) / cols) + rows = ceil(len(mobs) / cast(int, cols)) if cols is None: cols = ceil(len(mobs) / rows) if rows * cols < len(mobs): @@ -2301,25 +2463,29 @@ def init_size(num, alignments, sizes): buff_x = buff_y = buff # Initialize alignments correctly - def init_alignments(alignments, num, mapping, name, dir): + def init_alignments( + alignments: str | None, + num: int, + mapping: Mapping[str, np.ndarray], + name: str, + dir: np.ndarray, + ) -> list[np.ndarray]: if alignments is None: # Use cell_alignment as fallback return [cell_alignment * dir] * num if len(alignments) != num: raise ValueError(f"{name}_alignments has a mismatching size.") - alignments = list(alignments) - for i in range(num): - alignments[i] = mapping[alignments[i]] - return alignments + alignments_dirs = [mapping[alignments[i]] for i in range(num)] + return alignments_dirs - row_alignments = init_alignments( + row_alignments_dirs = init_alignments( row_alignments, rows, {"u": UP, "c": ORIGIN, "d": DOWN}, "row", RIGHT, ) - col_alignments = init_alignments( + col_alignments_dirs = init_alignments( col_alignments, cols, {"l": LEFT, "c": ORIGIN, "r": RIGHT}, @@ -2328,7 +2494,7 @@ def init_alignments(alignments, num, mapping, name, dir): ) # Now row_alignment[r] + col_alignment[c] is the alignment in cell [r][c] - mapper = { + mapper: dict[str, Callable[[int, int], int]] = { "dr": lambda r, c: (rows - r - 1) + c * rows, "dl": lambda r, c: (rows - r - 1) + (cols - c - 1) * rows, "ur": lambda r, c: r + c * rows, @@ -2342,18 +2508,15 @@ def init_alignments(alignments, num, mapping, name, dir): raise ValueError( 'flow_order must be one of the following values: "dr", "rd", "ld" "dl", "ru", "ur", "lu", "ul".', ) - flow_order = mapper[flow_order] + flow_order_func = mapper[flow_order] - # Reverse row_alignments and row_heights. Necessary since the + # Reverse row_alignments_dirs and row_heights. Necessary since the # grid filling is handled bottom up for simplicity reasons. - def reverse(maybe_list): - if maybe_list is not None: - maybe_list = list(maybe_list) - maybe_list.reverse() - return maybe_list + if row_heights is not None: + row_heights = list(row_heights) + row_heights.reverse() - row_alignments = reverse(row_alignments) - row_heights = reverse(row_heights) + row_alignments_dirs.reverse() placeholder = Mobject() # Used to fill up the grid temporarily, doesn't get added to the scene. @@ -2361,7 +2524,7 @@ def reverse(maybe_list): # properties of 0. mobs.extend([placeholder] * (rows * cols - len(mobs))) - grid = [[mobs[flow_order(r, c)] for c in range(cols)] for r in range(rows)] + grid = [[mobs[flow_order_func(r, c)] for c in range(cols)] for r in range(rows)] measured_heigths = [ max(grid[r][c].height for c in range(cols)) for r in range(rows) @@ -2371,9 +2534,15 @@ def reverse(maybe_list): ] # Initialize row_heights / col_widths correctly using measurements as fallback - def init_sizes(sizes, num, measures, name): + def init_sizes( + sizes: Sequence[float | None] | None, + num: int, + measures: list[float], + name: str, + ) -> list[float | None]: if sizes is None: - sizes = [None] * num + none: list[float | None] = [None] + sizes = none * num if len(sizes) != num: raise ValueError(f"{name} has a mismatching size.") return [ @@ -2388,7 +2557,7 @@ def init_sizes(sizes, num, measures, name): x = 0 for c in range(cols): if grid[r][c] is not placeholder: - alignment = row_alignments[r] + col_alignments[c] + alignment = row_alignments_dirs[r] + col_alignments_dirs[c] line = Line( x * RIGHT + y * UP, (x + widths[c]) * RIGHT + (y + heights[r]) * UP, @@ -2398,30 +2567,33 @@ def init_sizes(sizes, num, measures, name): # includes. grid[r][c].move_to(line, alignment) + # TODO(types): widths[c] and heights[r] might be None. x += widths[c] + buff_x y += heights[r] + buff_y self.move_to(start_pos) return self - def sort(self, point_to_num_func=lambda p: p[0], submob_func=None): + def sort( + self: MOS, + point_to_num_func: Callable[[np.ndarray], float] = lambda p: p[0], + submob_func: Callable[[Mobject], SupportsLessThan] | None = None, + ) -> MOS: """Sorts the list of :attr:`submobjects` by a function defined by ``submob_func``.""" if submob_func is None: - - def submob_func(m): - return point_to_num_func(m.get_center()) + submob_func = lambda m: point_to_num_func(m.get_center()) self.submobjects.sort(key=submob_func) return self - def shuffle(self, recursive=False): + def shuffle(self, recursive: bool = False) -> None: """Shuffles the list of :attr:`submobjects`.""" if recursive: for submob in self.submobjects: submob.shuffle(recursive=True) random.shuffle(self.submobjects) - def invert(self, recursive=False): + def invert(self, recursive: bool = False) -> None: """Inverts the list of :attr:`submobjects`. Parameters @@ -2445,7 +2617,7 @@ def construct(self): if recursive: for submob in self.submobjects: submob.invert(recursive=True) - list.reverse(self.submobjects) + self.submobjects.reverse() # Just here to keep from breaking old scenes. def arrange_submobjects(self, *args, **kwargs): @@ -2492,7 +2664,7 @@ def construct(self): return self.shuffle(*args, **kwargs) # Alignment - def align_data(self, mobject: "Mobject"): + def align_data(self, mobject: Mobject) -> None: self.null_point_align(mobject) self.align_submobjects(mobject) self.align_points(mobject) @@ -2500,14 +2672,14 @@ def align_data(self, mobject: "Mobject"): for m1, m2 in zip(self.submobjects, mobject.submobjects): m1.align_data(m2) - def get_point_mobject(self, center=None): + def get_point_mobject(self, center: np.ndarray | None = None) -> Mobject: """The simplest :class:`~.Mobject` to be transformed to or from self. Should by a point of the appropriate type """ msg = f"get_point_mobject not implemented for {self.__class__.__name__}" raise NotImplementedError(msg) - def align_points(self, mobject): + def align_points(self: MOS, mobject: Mobject) -> MOS: count1 = self.get_num_points() count2 = mobject.get_num_points() if count1 < count2: @@ -2516,10 +2688,10 @@ def align_points(self, mobject): mobject.align_points_with_larger(self) return self - def align_points_with_larger(self, larger_mobject): + def align_points_with_larger(self, larger_mobject: Mobject) -> None: raise NotImplementedError("Please override in a child class.") - def align_submobjects(self, mobject): + def align_submobjects(self: MOS, mobject: Mobject) -> MOS: mob1 = self mob2 = mobject n1 = len(mob1.submobjects) @@ -2528,7 +2700,7 @@ def align_submobjects(self, mobject): mob2.add_n_more_submobjects(max(0, n1 - n2)) return self - def null_point_align(self, mobject: "Mobject") -> "Mobject": + def null_point_align(self: MOS, mobject: Mobject) -> MOS: """If a :class:`~.Mobject` with points is being aligned to one without, treat both as groups, and push the one with points into its own submobjects @@ -2539,29 +2711,29 @@ def null_point_align(self, mobject: "Mobject") -> "Mobject": m2.push_self_into_submobjects() return self - def push_self_into_submobjects(self): + def push_self_into_submobjects(self: MOS) -> MOS: copy = self.copy() copy.submobjects = [] self.reset_points() self.add(copy) return self - def add_n_more_submobjects(self, n): + def add_n_more_submobjects(self: MOS, n: int) -> MOS: if n == 0: - return + return self curr = len(self.submobjects) if curr == 0: # If empty, simply add n point mobjects - self.submobjects = [self.get_point_mobject() for k in range(n)] - return + self.submobjects = [self.get_point_mobject() for _ in range(n)] + return self target = curr + n # TODO, factor this out to utils so as to reuse # with VMobject.insert_n_curves repeat_indices = (np.arange(target) * curr) // target split_factors = [sum(repeat_indices == i) for i in range(curr)] - new_submobs = [] + new_submobs: list[Mobject] = [] for submob, sf in zip(self.submobjects, split_factors): new_submobs.append(submob) for _ in range(1, sf): @@ -2569,10 +2741,19 @@ def add_n_more_submobjects(self, n): self.submobjects = new_submobs return self - def repeat_submobject(self, submob): + def repeat_submobject(self, submob: MOS) -> MOS: return submob.copy() - def interpolate(self, mobject1, mobject2, alpha, path_func=straight_path): + def interpolate( + self: MOS, + mobject1: Mobject, + mobject2: Mobject, + alpha: float, + path_func: Callable[ + [Interpolable, Interpolable, float], + Interpolable, + ] = straight_path, + ) -> MOS: """Turns this :class:`~.Mobject` into an interpolation between ``mobject1`` and ``mobject2``. @@ -2597,10 +2778,16 @@ def construct(self): self.interpolate_color(mobject1, mobject2, alpha) return self - def interpolate_color(self, mobject1, mobject2, alpha): + # TODO: Subclasses are inconsistent about returning self or None. + def interpolate_color( + self, + mobject1: Mobject, + mobject2: Mobject, + alpha: float, + ) -> None: raise NotImplementedError("Please override in a child class.") - def become(self, mobject: "Mobject", copy_submobjects: bool = True): + def become(self: MOS, mobject: Mobject, copy_submobjects: bool = True) -> MOS: """Edit points, colors and submobjects to be identical to another :class:`~.Mobject` @@ -2623,7 +2810,7 @@ def construct(self): sm1.interpolate_color(sm1, sm2, 1) return self - def match_points(self, mobject: "Mobject", copy_submobjects: bool = True): + def match_points(self: MOS, mobject: Mobject, copy_submobjects: bool = True) -> MOS: """Edit points, positions, and submobjects to be identical to another :class:`~.Mobject`, while keeping the style unchanged. @@ -2646,7 +2833,7 @@ def construct(self): return self # Errors - def throw_error_if_no_points(self): + def throw_error_if_no_points(self) -> None: if self.has_no_points(): caller_name = sys._getframe(1).f_code.co_name raise Exception( @@ -2655,10 +2842,10 @@ def throw_error_if_no_points(self): # About z-index def set_z_index( - self, + self: MOS, z_index_value: float, family: bool = True, - ) -> "VMobject": + ) -> MOS: """Sets the :class:`~.Mobject`'s :attr:`z_index` to the value specified in `z_index_value`. Parameters @@ -2697,7 +2884,7 @@ def construct(self): self.z_index = z_index_value return self - def set_z_index_by_z_coordinate(self): + def set_z_index_by_z_coordinate(self: MOS) -> MOS: """Sets the :class:`~.Mobject`'s z coordinate to the value of :attr:`z_index`. Returns @@ -2713,13 +2900,20 @@ def set_z_index_by_z_coordinate(self): class Group(Mobject, metaclass=ConvertToOpenGL): """Groups together multiple :class:`Mobjects <.Mobject>`.""" - def __init__(self, *mobjects, **kwargs): + def __init__(self, *mobjects: Mobject, **kwargs): super().__init__(**kwargs) self.add(*mobjects) class _AnimationBuilder: - def __init__(self, mobject): + mobject: Mobject + overriden_animation: Any | None # TODO + is_chaining: bool + methods: list[Callable[..., Any]] # TODO + cannot_pass_args: bool + anim_args: dict[str, Any] + + def __init__(self, mobject: Mobject): self.mobject = mobject self.mobject.generate_target() @@ -2742,7 +2936,7 @@ def __call__(self, **kwargs): return self - def __getattr__(self, method_name): + def __getattr__(self, method_name: str): method = getattr(self.mobject.target, method_name) self.methods.append(method) has_overridden_animation = hasattr(method, "_override_animate") @@ -2784,7 +2978,13 @@ def build(self): return anim -def override_animate(method): +# Python's static type hinting is very weak for Callable. `method` is passed +# an Mobject and some other parameters (see where _override_animate() is called). +# The second and third ...'s should be a generic parameter but there's no +# syntax for that. +def override_animate( + method: Callable[..., Animation], +) -> Callable[[Callable[..., T]], Callable[..., T]]: r"""Decorator for overriding method animations. This allows to specify a method (returning an :class:`~.Animation`) @@ -2836,7 +3036,7 @@ def construct(self): """ - def decorator(animation_method): + def decorator(animation_method: Callable[..., T]) -> Callable[..., T]: method._override_animate = animation_method return animation_method diff --git a/manim/utils/bezier.py b/manim/utils/bezier.py index 800737a256..1d05c3ee58 100644 --- a/manim/utils/bezier.py +++ b/manim/utils/bezier.py @@ -22,6 +22,7 @@ from functools import reduce import numpy as np +from _typeshed import SupportsLessThanT from scipy import linalg from ..utils.simple_functions import choose @@ -102,14 +103,16 @@ def curve(t): # Linear interpolation variants +Interpolable = typing.TypeVar("Interpolable", float, np.ndarray) -def interpolate(start: int, end: int, alpha: float) -> float: + +def interpolate(start: Interpolable, end: Interpolable, alpha: float) -> Interpolable: return (1 - alpha) * start + alpha * end def integer_interpolate( - start: float, - end: float, + start: int, + end: int, alpha: float, ) -> typing.Tuple[int, float]: """ diff --git a/manim/utils/deprecation.py b/manim/utils/deprecation.py index 8509cc1b24..cfce1e9e2a 100644 --- a/manim/utils/deprecation.py +++ b/manim/utils/deprecation.py @@ -5,12 +5,14 @@ import inspect import re -from typing import Any, Callable, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union from decorator import decorate, decorator from .. import logger +T = TypeVar("T", bound=Callable[..., Any]) + def _get_callable_info(callable: Callable) -> Tuple[str, str]: """Returns type and name of a callable. @@ -67,12 +69,12 @@ def _deprecation_text_component( def deprecated( - func: Callable = None, + func: Optional[T] = None, since: Optional[str] = None, until: Optional[str] = None, replacement: Optional[str] = None, message: Optional[str] = "", -) -> Callable: +) -> T: """Decorator to mark a callable as deprecated. The decorated callable will cause a warning when used. The docstring of the diff --git a/manim/utils/space_ops.py b/manim/utils/space_ops.py index f9bc29c789..a06c8cdafe 100644 --- a/manim/utils/space_ops.py +++ b/manim/utils/space_ops.py @@ -341,7 +341,7 @@ def angle_of_vector(vector: Sequence[float]) -> float: return np.angle(complex(*vector[:2])) -def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> np.ndarray: +def angle_between_vectors(v1: np.ndarray, v2: np.ndarray) -> float: """Returns the angle between two vectors. This angle will always be between 0 and pi