Skip to content

Add @ shorthand for CoordinateSystem methods coords_to_point (c2p) and point_to_coords (p2c) #3754

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
May 27, 2024
24 changes: 18 additions & 6 deletions manim/mobject/graphing/coordinate_systems.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from manim.mobject.graphing.functions import ImplicitFunction, ParametricFunction
from manim.mobject.graphing.number_line import NumberLine
from manim.mobject.graphing.scale import LinearBase
from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_compatibility import ConvertToOpenGL
from manim.mobject.opengl.opengl_surface import OpenGLSurface
from manim.mobject.text.tex_mobject import MathTex
Expand Down Expand Up @@ -96,10 +97,10 @@ def construct(self):
)

# Extra lines and labels for point (1,1)
graphs += grid.get_horizontal_line(grid.c2p(1, 1, 0), color=BLUE)
graphs += grid.get_vertical_line(grid.c2p(1, 1, 0), color=BLUE)
graphs += Dot(point=grid.c2p(1, 1, 0), color=YELLOW)
graphs += Tex("(1,1)").scale(0.75).next_to(grid.c2p(1, 1, 0))
graphs += grid.get_horizontal_line(grid @ (1, 1, 0), color=BLUE)
graphs += grid.get_vertical_line(grid @ (1, 1, 0), color=BLUE)
graphs += Dot(point=grid @ (1, 1, 0), color=YELLOW)
graphs += Tex("(1,1)").scale(0.75).next_to(grid @ (1, 1, 0))
title = Title(
# spaces between braces to prevent SyntaxError
r"Graphs of $y=x^{ {1}\over{n} }$ and $y=x^n (n=1,2,3,...,20)$",
Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(
self.y_length = y_length
self.num_sampled_graph_points_per_tick = 10

def coords_to_point(self, *coords: Sequence[ManimFloat]):
def coords_to_point(self, *coords: ManimFloat):
raise NotImplementedError()

def point_to_coords(self, point: Point3D):
Expand Down Expand Up @@ -570,7 +571,7 @@ def get_horizontal_line(self, point: Sequence[float], **kwargs) -> Line:
class GetHorizontalLineExample(Scene):
def construct(self):
ax = Axes().add_coordinates()
point = ax.c2p(-4, 1.5)
point = ax @ (-4, 1.5)

dot = Dot(point)
line = ax.get_horizontal_line(point, line_func=Line)
Expand Down Expand Up @@ -1790,6 +1791,14 @@ def construct(self):

return T_label_group

def __matmul__(self, coord: Point3D | Mobject):
if isinstance(coord, Mobject):
coord = coord.get_center()
return self.coords_to_point(*coord)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return self.coords_to_point(*coord)
return self.coords_to_point(coord)

?
Probably a test would be nice

Copy link
Member Author

@JasonGrace2282 JasonGrace2282 May 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem with removing the * is that ax @ (1, 0, 0) then returns a 3x3 array, which I felt was kinda unintuitive.
As for the test, I already have a doctest in Axes. Is another test really necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the interest of getting this merged, I went ahead and added a test for the behavior.


def __rmatmul__(self, point: Point3D):
return self.point_to_coords(point)


class Axes(VGroup, CoordinateSystem, metaclass=ConvertToOpenGL):
"""Creates a set of axes.
Expand Down Expand Up @@ -1990,6 +1999,7 @@ def coords_to_point(
self, *coords: float | Sequence[float] | Sequence[Sequence[float]] | np.ndarray
) -> np.ndarray:
"""Accepts coordinates from the axes and returns a point with respect to the scene.
Equivalent to `ax @ (coord1)`

Parameters
----------
Expand Down Expand Up @@ -2018,6 +2028,8 @@ def coords_to_point(
>>> ax = Axes()
>>> np.around(ax.coords_to_point(1, 0, 0), 2)
array([0.86, 0. , 0. ])
>>> np.around(ax @ (1, 0, 0), 2)
array([0.86, 0. , 0. ])
>>> np.around(ax.coords_to_point([[0, 1], [1, 1], [1, 0]]), 2)
array([[0. , 0.75, 0. ],
[0.86, 0.75, 0. ],
Expand Down
13 changes: 13 additions & 0 deletions manim/mobject/graphing/number_line.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from manim.mobject.mobject import Mobject
from manim.mobject.opengl.opengl_vectorized_mobject import OpenGLVMobject

__all__ = ["NumberLine", "UnitInterval"]
Expand All @@ -12,6 +13,7 @@

if TYPE_CHECKING:
from manim.mobject.geometry.tips import ArrowTip
from manim.typing import Point3D

import numpy as np

Expand Down Expand Up @@ -344,6 +346,7 @@ def get_tick_range(self) -> np.ndarray:
def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
"""Accepts a value along the number line and returns a point with
respect to the scene.
Equivalent to `NumberLine @ number`

Parameters
----------
Expand All @@ -364,6 +367,8 @@ def number_to_point(self, number: float | np.ndarray) -> np.ndarray:
array([0., 0., 0.])
>>> number_line.number_to_point(1)
array([1., 0., 0.])
>>> number_line @ 1
array([1., 0., 0.])
>>> number_line.number_to_point([1, 2, 3])
array([[1., 0., 0.],
[2., 0., 0.],
Expand Down Expand Up @@ -642,6 +647,14 @@ def _decimal_places_from_step(step) -> int:
return 0
return len(step.split(".")[-1])

def __matmul__(self, other: float):
return self.n2p(other)

def __rmatmul__(self, other: Point3D | Mobject):
if isinstance(other, Mobject):
other = other.get_center()
return self.p2n(other)


class UnitInterval(NumberLine):
def __init__(
Expand Down
12 changes: 10 additions & 2 deletions tests/module/mobject/graphing/test_coordinate_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from manim import LEFT, ORIGIN, PI, UR, Axes, Circle, ComplexPlane
from manim import CoordinateSystem as CS
from manim import NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig
from manim import Dot, NumberPlane, PolarPlane, ThreeDAxes, config, tempconfig


def test_initial_config():
Expand Down Expand Up @@ -119,7 +119,15 @@ def test_coords_to_point():

# a point with respect to the axes
c2p_coord = np.around(ax.coords_to_point(2, 2), decimals=4)
np.testing.assert_array_equal(c2p_coord, (1.7143, 1.5, 0))
c2p_coord_matmul = np.around(ax @ (2, 2), decimals=4)

expected = (1.7143, 1.5, 0)

np.testing.assert_array_equal(c2p_coord, expected)
np.testing.assert_array_equal(c2p_coord_matmul, c2p_coord)

mob = Dot().move_to((2, 2, 0))
np.testing.assert_array_equal(np.around(ax @ mob, decimals=4), expected)


def test_coords_to_point_vectorized():
Expand Down
Loading