diff --git a/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandBase.java b/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandBase.java index 619f3301be..5b577bd3cc 100644 --- a/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandBase.java +++ b/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandBase.java @@ -4,6 +4,8 @@ package edu.wpi.first.wpilibj2.command; +import static edu.wpi.first.util.ErrorMessages.requireNonNullParam; + import edu.wpi.first.util.sendable.Sendable; import edu.wpi.first.util.sendable.SendableBuilder; import edu.wpi.first.util.sendable.SendableRegistry; @@ -29,7 +31,9 @@ public abstract class CommandBase implements Sendable, Command { * @param requirements the requirements to add */ public final void addRequirements(Subsystem... requirements) { - m_requirements.addAll(Set.of(requirements)); + for (Subsystem requirement : requirements) { + m_requirements.add(requireNonNullParam(requirement, "requirement", "addRequirements")); + } } @Override diff --git a/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandScheduler.java b/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandScheduler.java index 4bd927f7f6..f68f0ce49b 100644 --- a/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandScheduler.java +++ b/wpilibNewCommands/src/main/java/edu/wpi/first/wpilibj2/command/CommandScheduler.java @@ -13,6 +13,7 @@ import edu.wpi.first.networktables.NTSendable; import edu.wpi.first.networktables.NTSendableBuilder; import edu.wpi.first.networktables.NetworkTableEntry; import edu.wpi.first.util.sendable.SendableRegistry; +import edu.wpi.first.wpilibj.DriverStation; import edu.wpi.first.wpilibj.RobotBase; import edu.wpi.first.wpilibj.RobotState; import edu.wpi.first.wpilibj.TimedRobot; @@ -26,6 +27,7 @@ import java.util.Iterator; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.function.Consumer; @@ -153,7 +155,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { */ @Deprecated(since = "2023") public void addButton(Runnable button) { - m_activeButtonLoop.bind(() -> true, button); + m_activeButtonLoop.bind(() -> true, requireNonNullParam(button, "button", "addButton")); } /** @@ -176,10 +178,10 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { private void initCommand(Command command, boolean interruptible, Set requirements) { CommandState scheduledCommand = new CommandState(interruptible); m_scheduledCommands.put(command, scheduledCommand); - command.initialize(); for (Subsystem requirement : requirements) { m_requirements.put(requirement, command); } + command.initialize(); for (Consumer action : m_initActions) { action.accept(command); } @@ -193,10 +195,15 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * using those requirements have been scheduled as interruptible. If this is the case, they will * be interrupted and the command will be scheduled. * - * @param interruptible whether this command can be interrupted - * @param command the command to schedule + * @param interruptible whether this command can be interrupted. + * @param command the command to schedule. If null, no-op. */ private void schedule(boolean interruptible, Command command) { + if (command == null) { + DriverStation.reportWarning("Tried to schedule a null command", true); + return; + } + if (m_inRunLoop) { m_toSchedule.put(command, interruptible); return; @@ -211,7 +218,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { // run when disabled, or the command is already scheduled. if (m_disabled || RobotState.isDisabled() && !command.runsWhenDisabled() - || m_scheduledCommands.containsKey(command)) { + || isScheduled(command)) { return; } @@ -224,14 +231,18 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { // Else check if the requirements that are in use have all have interruptible commands, // and if so, interrupt those commands and schedule the new command. for (Subsystem requirement : requirements) { - if (m_requirements.containsKey(requirement) - && !m_scheduledCommands.get(m_requirements.get(requirement)).isInterruptible()) { + Command requiring = requiring(requirement); + if (requiring != null + && !Optional.ofNullable(m_scheduledCommands.get(requiring)) + .map(CommandState::isInterruptible) + .orElse(true)) { return; } } for (Subsystem requirement : requirements) { - if (m_requirements.containsKey(requirement)) { - cancel(m_requirements.get(requirement)); + Command requiring = requiring(requirement); + if (requiring != null) { + cancel(requiring); } } initCommand(command, interruptible, requirements); @@ -245,7 +256,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * they will be interrupted and the command will be scheduled. * * @param interruptible whether the commands should be interruptible - * @param commands the commands to schedule + * @param commands the commands to schedule. No-op if null. */ public void schedule(boolean interruptible, Command... commands) { for (Command command : commands) { @@ -257,7 +268,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * Schedules multiple commands for execution, with interruptible defaulted to true. Does nothing * if the command is already scheduled. * - * @param commands the commands to schedule + * @param commands the commands to schedule. No-op on null. */ public void schedule(Command... commands) { schedule(true, commands); @@ -370,6 +381,10 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { */ public void registerSubsystem(Subsystem... subsystems) { for (Subsystem subsystem : subsystems) { + if (subsystem == null) { + DriverStation.reportWarning("Tried to register a null subsystem", true); + continue; + } m_subsystems.put(subsystem, null); } } @@ -395,6 +410,15 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * @param defaultCommand the default command to associate with the subsystem */ public void setDefaultCommand(Subsystem subsystem, Command defaultCommand) { + if (subsystem == null) { + DriverStation.reportWarning("Tried to set a default command for a null subsystem", true); + return; + } + if (defaultCommand == null) { + DriverStation.reportWarning("Tried to set a null default command", true); + return; + } + if (!defaultCommand.getRequirements().contains(subsystem)) { throw new IllegalArgumentException("Default commands must require their subsystem!"); } @@ -433,16 +457,20 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { } for (Command command : commands) { - if (!m_scheduledCommands.containsKey(command)) { + if (command == null) { + DriverStation.reportWarning("Tried to cancel a null command", true); + continue; + } + if (!isScheduled(command)) { continue; } + m_scheduledCommands.remove(command); + m_requirements.keySet().removeAll(command.getRequirements()); command.end(true); for (Consumer action : m_interruptActions) { action.accept(command); } - m_scheduledCommands.remove(command); - m_requirements.keySet().removeAll(command.getRequirements()); m_watchdog.addEpoch(command.getName() + ".end(true)"); } } @@ -511,7 +539,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * @param action the action to perform */ public void onCommandInitialize(Consumer action) { - m_initActions.add(action); + m_initActions.add(requireNonNullParam(action, "action", "onCommandInitialize")); } /** @@ -520,7 +548,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * @param action the action to perform */ public void onCommandExecute(Consumer action) { - m_executeActions.add(action); + m_executeActions.add(requireNonNullParam(action, "action", "onCommandExecute")); } /** @@ -529,7 +557,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * @param action the action to perform */ public void onCommandInterrupt(Consumer action) { - m_interruptActions.add(action); + m_interruptActions.add(requireNonNullParam(action, "action", "onCommandInterrupt")); } /** @@ -538,7 +566,7 @@ public final class CommandScheduler implements NTSendable, AutoCloseable { * @param action the action to perform */ public void onCommandFinish(Consumer action) { - m_finishActions.add(action); + m_finishActions.add(requireNonNullParam(action, "action", "onCommandFinish")); } @Override diff --git a/wpilibNewCommands/src/main/native/cpp/frc2/command/CommandScheduler.cpp b/wpilibNewCommands/src/main/native/cpp/frc2/command/CommandScheduler.cpp index 1a1bd699fd..3f538792d7 100644 --- a/wpilibNewCommands/src/main/native/cpp/frc2/command/CommandScheduler.cpp +++ b/wpilibNewCommands/src/main/native/cpp/frc2/command/CommandScheduler.cpp @@ -151,11 +151,11 @@ void CommandScheduler::Schedule(bool interruptible, Command* command) { Cancel(cmdToCancel); } } - command->Initialize(); m_impl->scheduledCommands[command] = CommandState{interruptible}; for (auto&& requirement : requirements) { m_impl->requirements[requirement] = command; } + command->Initialize(); for (auto&& action : m_impl->initActions) { action(*command); } @@ -336,17 +336,17 @@ void CommandScheduler::Cancel(Command* command) { if (find == m_impl->scheduledCommands.end()) { return; } - command->End(true); - for (auto&& action : m_impl->interruptActions) { - action(*command); - } - m_watchdog.AddEpoch(command->GetName() + ".End(true)"); m_impl->scheduledCommands.erase(find); for (auto&& requirement : m_impl->requirements) { if (requirement.second == command) { m_impl->requirements.erase(requirement.first); } } + command->End(true); + for (auto&& action : m_impl->interruptActions) { + action(*command); + } + m_watchdog.AddEpoch(command->GetName() + ".End(true)"); } void CommandScheduler::Cancel(wpi::span commands) { diff --git a/wpilibNewCommands/src/main/native/include/frc2/command/CommandScheduler.h b/wpilibNewCommands/src/main/native/include/frc2/command/CommandScheduler.h index d6458bb3c4..63ddb81f8d 100644 --- a/wpilibNewCommands/src/main/native/include/frc2/command/CommandScheduler.h +++ b/wpilibNewCommands/src/main/native/include/frc2/command/CommandScheduler.h @@ -367,5 +367,8 @@ class CommandScheduler final : public nt::NTSendable, frc::Watchdog m_watchdog; friend class CommandTestBase; + + template + friend class CommandTestBaseWithParam; }; } // namespace frc2 diff --git a/wpilibNewCommands/src/test/java/edu/wpi/first/wpilibj2/command/SchedulingRecursionTest.java b/wpilibNewCommands/src/test/java/edu/wpi/first/wpilibj2/command/SchedulingRecursionTest.java new file mode 100644 index 0000000000..6e6e5c9887 --- /dev/null +++ b/wpilibNewCommands/src/test/java/edu/wpi/first/wpilibj2/command/SchedulingRecursionTest.java @@ -0,0 +1,193 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package edu.wpi.first.wpilibj2.command; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +class SchedulingRecursionTest extends CommandTestBase { + /** + * wpilibsuite/allwpilib#4259. + */ + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void cancelFromInitialize(boolean interruptible) { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicBoolean hasOtherRun = new AtomicBoolean(); + Subsystem requirement = new SubsystemBase() {}; + Command selfCancels = + new CommandBase() { + { + addRequirements(requirement); + } + + @Override + public void initialize() { + scheduler.cancel(this); + } + }; + Command other = new RunCommand(() -> hasOtherRun.set(true), requirement); + + assertDoesNotThrow( + () -> { + scheduler.schedule(interruptible, selfCancels); + scheduler.run(); + // interruptibility of new arrival isn't checked + scheduler.schedule(other); + }); + assertFalse(scheduler.isScheduled(selfCancels)); + assertTrue(scheduler.isScheduled(other)); + scheduler.run(); + assertTrue(hasOtherRun.get()); + } + } + + @ValueSource(booleans = {true, false}) + @ParameterizedTest + void defaultCommand(boolean interruptible) { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicBoolean hasOtherRun = new AtomicBoolean(); + Subsystem requirement = new SubsystemBase() {}; + Command selfCancels = + new CommandBase() { + { + addRequirements(requirement); + } + + @Override + public void initialize() { + scheduler.cancel(this); + } + }; + Command other = new RunCommand(() -> hasOtherRun.set(true), requirement); + scheduler.setDefaultCommand(requirement, other); + + assertDoesNotThrow( + () -> { + scheduler.schedule(interruptible, selfCancels); + scheduler.run(); + }); + scheduler.run(); + assertFalse(scheduler.isScheduled(selfCancels)); + assertTrue(scheduler.isScheduled(other)); + scheduler.run(); + assertTrue(hasOtherRun.get()); + } + } + + @Test + void cancelFromEnd() { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicInteger counter = new AtomicInteger(); + Command selfCancels = + new CommandBase() { + @Override + public void end(boolean interrupted) { + counter.incrementAndGet(); + scheduler.cancel(this); + } + }; + scheduler.schedule(selfCancels); + + assertDoesNotThrow(() -> scheduler.cancel(selfCancels)); + assertEquals(1, counter.get()); + assertFalse(scheduler.isScheduled(selfCancels)); + } + } + + @Test + void scheduleFromEndCancel() { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicInteger counter = new AtomicInteger(); + Subsystem requirement = new SubsystemBase() {}; + InstantCommand other = new InstantCommand(() -> {}, requirement); + Command selfCancels = + new CommandBase() { + { + addRequirements(requirement); + } + + @Override + public void end(boolean interrupted) { + counter.incrementAndGet(); + scheduler.schedule(other); + } + }; + + scheduler.schedule(selfCancels); + + assertDoesNotThrow(() -> scheduler.cancel(selfCancels)); + assertEquals(1, counter.get()); + assertFalse(scheduler.isScheduled(selfCancels)); + } + } + + @Test + void scheduleFromEndInterrupt() { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicInteger counter = new AtomicInteger(); + Subsystem requirement = new SubsystemBase() {}; + InstantCommand other = new InstantCommand(() -> {}, requirement); + Command selfCancels = + new CommandBase() { + { + addRequirements(requirement); + } + + @Override + public void end(boolean interrupted) { + counter.incrementAndGet(); + scheduler.schedule(other); + } + }; + + scheduler.schedule(selfCancels); + + assertDoesNotThrow(() -> scheduler.schedule(other)); + assertEquals(1, counter.get()); + assertFalse(scheduler.isScheduled(selfCancels)); + assertTrue(scheduler.isScheduled(other)); + } + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void scheduleInitializeFromDefaultCommand(boolean interruptible) { + try (CommandScheduler scheduler = new CommandScheduler()) { + AtomicInteger counter = new AtomicInteger(); + Subsystem requirement = new SubsystemBase() {}; + Command other = new InstantCommand(() -> {}, requirement); + Command defaultCommand = + new CommandBase() { + { + addRequirements(requirement); + } + + @Override + public void initialize() { + counter.incrementAndGet(); + scheduler.schedule(interruptible, other); + } + }; + + scheduler.setDefaultCommand(requirement, defaultCommand); + + scheduler.run(); + scheduler.run(); + scheduler.run(); + assertEquals(3, counter.get()); + assertFalse(scheduler.isScheduled(defaultCommand)); + assertTrue(scheduler.isScheduled(other)); + } + } +} diff --git a/wpilibNewCommands/src/test/native/cpp/frc2/command/CommandTestBase.h b/wpilibNewCommands/src/test/native/cpp/frc2/command/CommandTestBase.h index a3890f6ed6..fee626efed 100644 --- a/wpilibNewCommands/src/test/native/cpp/frc2/command/CommandTestBase.h +++ b/wpilibNewCommands/src/test/native/cpp/frc2/command/CommandTestBase.h @@ -18,6 +18,7 @@ #include "make_vector.h" namespace frc2 { + class CommandTestBase : public ::testing::Test { public: CommandTestBase(); @@ -93,4 +94,91 @@ class CommandTestBase : public ::testing::Test { void SetDSEnabled(bool enabled); }; + +template +class CommandTestBaseWithParam : public ::testing::TestWithParam { + public: + CommandTestBaseWithParam() { + auto& scheduler = CommandScheduler::GetInstance(); + scheduler.CancelAll(); + scheduler.Enable(); + scheduler.GetActiveButtonLoop()->Clear(); + } + + class TestSubsystem : public SubsystemBase {}; + + protected: + class MockCommand : public Command { + public: + MOCK_CONST_METHOD0(GetRequirements, wpi::SmallSet()); + MOCK_METHOD0(IsFinished, bool()); + MOCK_CONST_METHOD0(RunsWhenDisabled, bool()); + MOCK_METHOD0(Initialize, void()); + MOCK_METHOD0(Execute, void()); + MOCK_METHOD1(End, void(bool interrupted)); + + MockCommand() { + m_requirements = {}; + EXPECT_CALL(*this, GetRequirements()) + .WillRepeatedly(::testing::Return(m_requirements)); + EXPECT_CALL(*this, IsFinished()).WillRepeatedly(::testing::Return(false)); + EXPECT_CALL(*this, RunsWhenDisabled()) + .WillRepeatedly(::testing::Return(true)); + } + + MockCommand(std::initializer_list requirements, + bool finished = false, bool runWhenDisabled = true) { + m_requirements.insert(requirements.begin(), requirements.end()); + EXPECT_CALL(*this, GetRequirements()) + .WillRepeatedly(::testing::Return(m_requirements)); + EXPECT_CALL(*this, IsFinished()) + .WillRepeatedly(::testing::Return(finished)); + EXPECT_CALL(*this, RunsWhenDisabled()) + .WillRepeatedly(::testing::Return(runWhenDisabled)); + } + + MockCommand(MockCommand&& other) { + EXPECT_CALL(*this, IsFinished()) + .WillRepeatedly(::testing::Return(other.IsFinished())); + EXPECT_CALL(*this, RunsWhenDisabled()) + .WillRepeatedly(::testing::Return(other.RunsWhenDisabled())); + std::swap(m_requirements, other.m_requirements); + EXPECT_CALL(*this, GetRequirements()) + .WillRepeatedly(::testing::Return(m_requirements)); + } + + MockCommand(const MockCommand& other) : Command{other} {} + + void SetFinished(bool finished) { + EXPECT_CALL(*this, IsFinished()) + .WillRepeatedly(::testing::Return(finished)); + } + + ~MockCommand() { // NOLINT + auto& scheduler = CommandScheduler::GetInstance(); + scheduler.Cancel(this); + } + + protected: + std::unique_ptr TransferOwnership() && { // NOLINT + return std::make_unique(std::move(*this)); + } + + private: + wpi::SmallSet m_requirements; + }; + + CommandScheduler GetScheduler() { return CommandScheduler(); } + + void SetUp() override { frc::sim::DriverStationSim::SetEnabled(true); } + + void TearDown() override { + CommandScheduler::GetInstance().GetActiveButtonLoop()->Clear(); + } + + void SetDSEnabled(bool enabled) { + frc::sim::DriverStationSim::SetEnabled(enabled); + } +}; + } // namespace frc2 diff --git a/wpilibNewCommands/src/test/native/cpp/frc2/command/SchedulingRecursionTest.cpp b/wpilibNewCommands/src/test/native/cpp/frc2/command/SchedulingRecursionTest.cpp new file mode 100644 index 0000000000..6a6eae2a48 --- /dev/null +++ b/wpilibNewCommands/src/test/native/cpp/frc2/command/SchedulingRecursionTest.cpp @@ -0,0 +1,97 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +#include "CommandTestBase.h" +#include "frc2/command/CommandHelper.h" +#include "frc2/command/RunCommand.h" +#include "gtest/gtest.h" + +using namespace frc2; + +class SchedulingRecursionTest : public CommandTestBaseWithParam {}; + +class SelfCancellingCommand + : public CommandHelper { + public: + SelfCancellingCommand(CommandScheduler* scheduler, Subsystem* requirement) + : m_scheduler(scheduler) { + AddRequirements(requirement); + } + + void Initialize() override { m_scheduler->Cancel(this); } + + private: + CommandScheduler* m_scheduler; +}; + +/** + * Checks wpilibsuite/allwpilib#4259. + */ +TEST_P(SchedulingRecursionTest, CancelFromInitialize) { + CommandScheduler scheduler = GetScheduler(); + bool hasOtherRun = false; + TestSubsystem requirement; + SelfCancellingCommand selfCancels{&scheduler, &requirement}; + RunCommand other = + RunCommand([&hasOtherRun] { hasOtherRun = true; }, {&requirement}); + + scheduler.Schedule(GetParam(), &selfCancels); + scheduler.Run(); + // interruptibility of new arrival isn't checked + scheduler.Schedule(&other); + + EXPECT_FALSE(scheduler.IsScheduled(&selfCancels)); + EXPECT_TRUE(scheduler.IsScheduled(&other)); + scheduler.Run(); + EXPECT_TRUE(hasOtherRun); +} + +TEST_P(SchedulingRecursionTest, DefaultCommand) { + CommandScheduler scheduler = GetScheduler(); + bool hasOtherRun = false; + TestSubsystem requirement; + SelfCancellingCommand selfCancels{&scheduler, &requirement}; + RunCommand other = + RunCommand([&hasOtherRun] { hasOtherRun = true; }, {&requirement}); + scheduler.SetDefaultCommand(&requirement, std::move(other)); + + scheduler.Schedule(GetParam(), &selfCancels); + scheduler.Run(); + scheduler.Run(); + EXPECT_FALSE(scheduler.IsScheduled(&selfCancels)); + EXPECT_TRUE(scheduler.IsScheduled(scheduler.GetDefaultCommand(&requirement))); + scheduler.Run(); + EXPECT_TRUE(hasOtherRun); +} + +class CancelEndCommand : public CommandHelper { + public: + CancelEndCommand(CommandScheduler* scheduler, int& counter) + : m_scheduler(scheduler), m_counter(counter) {} + + void End(bool interrupted) override { + m_counter++; + m_scheduler->Cancel(this); + } + + private: + CommandScheduler* m_scheduler; + int& m_counter; +}; + +TEST_F(SchedulingRecursionTest, CancelFromEnd) { + CommandScheduler scheduler = GetScheduler(); + int counter = 0; + CancelEndCommand selfCancels{&scheduler, counter}; + + scheduler.Schedule(&selfCancels); + + EXPECT_NO_THROW({ scheduler.Cancel(&selfCancels); }); + EXPECT_EQ(1, counter); + EXPECT_FALSE(scheduler.IsScheduled(&selfCancels)); +} + +INSTANTIATE_TEST_SUITE_P(SchedulingRecursionTests, SchedulingRecursionTest, + testing::Bool());