From 3a9e0bca2399a2c19e3bbca04f5adfeec99051ac Mon Sep 17 00:00:00 2001 From: Chadwick Stryker Date: Sun, 21 Jan 2024 12:16:27 -0800 Subject: [PATCH] Ported the ProfiledPIDSubsystem from the wpilib java source to Python Co-authored-by: Dustin Spicuzza --- commands2/__init__.py | 2 + commands2/profiledpidsubsystem.py | 78 +++++++++++++++++++ tests/test_profiledpidsubsystem.py | 119 +++++++++++++++++++++++++++++ 3 files changed, 199 insertions(+) create mode 100644 commands2/profiledpidsubsystem.py create mode 100644 tests/test_profiledpidsubsystem.py diff --git a/commands2/__init__.py b/commands2/__init__.py index 445638be..01efe61a 100644 --- a/commands2/__init__.py +++ b/commands2/__init__.py @@ -16,6 +16,7 @@ from .pidcommand import PIDCommand from .pidsubsystem import PIDSubsystem from .printcommand import PrintCommand +from .profiledpidsubsystem import ProfiledPIDSubsystem from .proxycommand import ProxyCommand from .repeatcommand import RepeatCommand from .runcommand import RunCommand @@ -51,6 +52,7 @@ "PIDCommand", "PIDSubsystem", "PrintCommand", + "ProfiledPIDSubsystem", "ProxyCommand", "RepeatCommand", "RunCommand", diff --git a/commands2/profiledpidsubsystem.py b/commands2/profiledpidsubsystem.py new file mode 100644 index 00000000..f2c7069e --- /dev/null +++ b/commands2/profiledpidsubsystem.py @@ -0,0 +1,78 @@ +# Copyright (c) FIRST and other WPILib contributors. +# Open Source Software; you can modify and/or share it under the terms of +# the WPILib BSD license file in the root directory of this project. + +from typing import Union, cast + +from wpimath.trajectory import TrapezoidProfile + +from .subsystem import Subsystem + + +class ProfiledPIDSubsystem(Subsystem): + """ + A subsystem that uses a :class:`wpimath.controller.ProfiledPIDController` + or :class:`wpimath.controller.ProfiledPIDControllerRadians` to + control an output. The controller is run synchronously from the subsystem's + :meth:`.periodic` method. + """ + + def __init__( + self, + controller, + initial_position: float = 0, + ): + """Creates a new PIDSubsystem.""" + super().__init__() + self._controller = controller + self._enabled = False + self.setGoal(initial_position) + + def periodic(self): + """Updates the output of the controller.""" + if self._enabled: + self.useOutput( + self._controller.calculate(self.getMeasurement()), + self._controller.getSetpoint(), + ) + + def getController( + self, + ): + """Returns the controller""" + return self._controller + + def setGoal(self, goal): + """ + Sets the goal state for the subsystem. + """ + self._controller.setGoal(goal) + + def useOutput(self, output: float, setpoint: TrapezoidProfile.State): + """ + Uses the output from the controller object. + """ + raise NotImplementedError(f"{self.__class__} must implement useOutput") + + def getMeasurement(self) -> float: + """ + Returns the measurement of the process variable used by the + controller object. + """ + raise NotImplementedError(f"{self.__class__} must implement getMeasurement") + + def enable(self): + """Enables the PID control. Resets the controller.""" + self._enabled = True + self._controller.reset(self.getMeasurement()) + + def disable(self): + """Disables the PID control. Sets output to zero.""" + self._enabled = False + self.useOutput(0, TrapezoidProfile.State()) + + def isEnabled(self) -> bool: + """ + Returns whether the controller is enabled. + """ + return self._enabled diff --git a/tests/test_profiledpidsubsystem.py b/tests/test_profiledpidsubsystem.py new file mode 100644 index 00000000..8b896f3c --- /dev/null +++ b/tests/test_profiledpidsubsystem.py @@ -0,0 +1,119 @@ +from types import MethodType +from typing import Any + +import pytest +from wpimath.controller import ProfiledPIDController, ProfiledPIDControllerRadians +from wpimath.trajectory import TrapezoidProfile, TrapezoidProfileRadians + +from commands2 import ProfiledPIDSubsystem + +MAX_VELOCITY = 30 # Radians per second +MAX_ACCELERATION = 500 # Radians per sec squared +PID_KP = 50 + + +class EvalSubsystem(ProfiledPIDSubsystem): + def __init__(self, controller, state_factory): + self._state_factory = state_factory + super().__init__(controller, 0) + + +def simple_use_output(self, output: float, setpoint: Any): + """A simple useOutput method that saves the current state of the controller.""" + self._output = output + self._setpoint = setpoint + + +def simple_get_measurement(self) -> float: + """A simple getMeasurement method that returns zero (frozen or stuck plant).""" + return 0.0 + + +controller_types = [ + ( + ProfiledPIDControllerRadians, + TrapezoidProfileRadians.Constraints, + TrapezoidProfileRadians.State, + ), + (ProfiledPIDController, TrapezoidProfile.Constraints, TrapezoidProfile.State), +] +controller_ids = ["radians", "dimensionless"] + + +@pytest.fixture(params=controller_types, ids=controller_ids) +def subsystem(request): + """ + Fixture that returns an EvalSubsystem object for each type of controller. + """ + controller, profile_factory, state_factory = request.param + profile = profile_factory(MAX_VELOCITY, MAX_ACCELERATION) + pid = controller(PID_KP, 0, 0, profile) + return EvalSubsystem(pid, state_factory) + + +def test_profiled_pid_subsystem_init(subsystem): + """ + Verify that the ProfiledPIDSubsystem can be initialized using + all supported profiled PID controller / trapezoid profile types. + """ + assert isinstance(subsystem, EvalSubsystem) + + +def test_profiled_pid_subsystem_not_implemented_get_measurement(subsystem): + """ + Verify that the ProfiledPIDSubsystem.getMeasurement method + raises NotImplementedError. + """ + with pytest.raises(NotImplementedError): + subsystem.getMeasurement() + + +def test_profiled_pid_subsystem_not_implemented_use_output(subsystem): + """ + Verify that the ProfiledPIDSubsystem.useOutput method raises + NotImplementedError. + """ + with pytest.raises(NotImplementedError): + subsystem.useOutput(0, subsystem._state_factory()) + + +@pytest.mark.parametrize("use_float", [True, False]) +def test_profiled_pid_subsystem_set_goal(subsystem, use_float): + """ + Verify that the ProfiledPIDSubsystem.setGoal method sets the goal. + """ + if use_float: + subsystem.setGoal(1.0) + assert subsystem.getController().getGoal().position == 1.0 + assert subsystem.getController().getGoal().velocity == 0.0 + else: + subsystem.setGoal(subsystem._state_factory(1.0, 2.0)) + assert subsystem.getController().getGoal().position == 1.0 + assert subsystem.getController().getGoal().velocity == 2.0 + + +def test_profiled_pid_subsystem_enable_subsystem(subsystem): + """ + Verify the subsystem can be enabled. + """ + # Dynamically add useOutput and getMeasurement methods so the + # system can be enabled + setattr(subsystem, "useOutput", MethodType(simple_use_output, subsystem)) + setattr(subsystem, "getMeasurement", MethodType(simple_get_measurement, subsystem)) + # Enable the subsystem + subsystem.enable() + assert subsystem.isEnabled() + + +def test_profiled_pid_subsystem_disable_subsystem(subsystem): + """ + Verify the subsystem can be disabled. + """ + # Dynamically add useOutput and getMeasurement methods so the + # system can be enabled + setattr(subsystem, "useOutput", MethodType(simple_use_output, subsystem)) + setattr(subsystem, "getMeasurement", MethodType(simple_get_measurement, subsystem)) + # Enable and then disable the subsystem + subsystem.enable() + subsystem.disable() + assert not subsystem.isEnabled()