diff --git a/commandsv3/src/main/java/org/wpilib/command3/Scheduler.java b/commandsv3/src/main/java/org/wpilib/command3/Scheduler.java index 67495c3888..f55f5c3679 100644 --- a/commandsv3/src/main/java/org/wpilib/command3/Scheduler.java +++ b/commandsv3/src/main/java/org/wpilib/command3/Scheduler.java @@ -902,7 +902,7 @@ public final class Scheduler implements ProtobufSerializable { * @return the currently running commands */ public Collection getRunningCommands() { - return Collections.unmodifiableSet(m_runningCommands.keySet()); + return List.copyOf(m_runningCommands.keySet()); } /** diff --git a/commandsv3/src/main/java/org/wpilib/command3/StateMachine.java b/commandsv3/src/main/java/org/wpilib/command3/StateMachine.java new file mode 100644 index 0000000000..df3f026ff1 --- /dev/null +++ b/commandsv3/src/main/java/org/wpilib/command3/StateMachine.java @@ -0,0 +1,592 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.command3; + +import static org.wpilib.util.ErrorMessages.requireNonNullParam; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.BooleanSupplier; +import java.util.function.Supplier; +import org.wpilib.annotation.NoDiscard; +import org.wpilib.annotation.PostConstructionInitializer; + +/** + * A declarative state machine that can be used to implement complex command routines. State machine + * setup should be done in stages: first, a state machine is created and its name is set; second, + * states are added to the state machine using {@link #addState(Command)}; third, transitions + * between states can be specified using {@link State#switchTo(State)}: + * + *
{@code
+ * // Declare the state machine
+ * StateMachine stateMachine = new StateMachine("Example State Machine");
+ *
+ * // Declare states
+ * State state1 = stateMachine.addState(...);
+ * State state2 = stateMachine.addState(...);
+ * State state3 = stateMachine.addState(...);
+ *
+ * // Set initial state
+ * stateMachine.setInitialState(state1);
+ *
+ * // Declare transitions
+ * state1.switchTo(state2).when(...);
+ * state2.switchTo(state3).when(...);
+ * }
+ * + *

Every state in a state machine runs a single command. While a state's command is running, the + * state machine will continually check all transitions that can be triggered from that state. If a + * transition is triggered, the state machine will cancel the state's command and move to the next + * state as defined by that transition. If no transition is triggered by the time the command + * completes, the state machine will exit unless a {@link + * TransitionNeedsConditionStage#whenComplete()} transition was specified from that state: + * + *

{@code
+ * // switch from state1 to state2 when foo is true
+ * state1.switchTo(state2).when(() -> foo == true);
+ *
+ * // but if foo never becomes true, switch to state3 when state1 finishes
+ * state1.switchTo(state3).whenComplete();
+ *
+ * // no transitions are defined from state2 or state3,
+ * // so the state machine will exit when either state completes
+ * }
+ */ +public final class StateMachine implements Command { + private final String m_name; + private State m_initialState = null; + private final List m_states = new ArrayList<>(); + + /** + * Creates a new state machine. + * + * @param name The name of the state machine. Cannot be null. This will appear in telemetry as the + * {@link Command#name() name} of the state machine. + */ + public StateMachine(String name) { + requireNonNullParam(name, "name", "StateMachine"); + m_name = name; + } + + @Override + public String name() { + return m_name; + } + + @Override + public Set requirements() { + // The machine itself doesn't have any requirements. Commands bound to the various states that + // the machine moves through may have requirements, however. + return Set.of(); + } + + /** + * Adds a new state to the state machine. State transitions can be specified on the new state + * using {@link State#switchTo(State)}. + * + * @param command The command for the state to execute. Cannot be null. + * @return The newly created state. + */ + @NoDiscard + public State addState(Command command) { + requireNonNullParam(command, "command", "StateMachine.addState"); + var state = new State(this, command); + m_states.add(state); + return state; + } + + /** + * Sets up a transition from any of the given states to a specific state. If no states are given, + * the transition will apply to all states in the state machine at the time this method is + * called. + * + *
{@code
+   * stateMachine.switchFromAny(state1, state2, state3).to(state4).when(() -> foo == true);
+   *
+   * // Functionally equivalent to:
+   * state1.switchTo(state4).when(() -> foo == true);
+   * state2.switchTo(state4).when(() -> foo == true);
+   * state3.switchTo(state4).when(() -> foo == true);
+   *
+   * // Set up an early exit condition from any state
+   * stateMachine.switchFromAny().toExitStateMachine().when(() -> bar == true);
+   *
+   * // Functionally equivalent to:
+   * state1.exitStateMachine().when(() -> bar == true);
+   * state2.exitStateMachine().when(() -> bar == true);
+   * state3.exitStateMachine().when(() -> bar == true);
+   * state4.exitStateMachine().when(() -> bar == true);
+   * }
+ * + * @param states The states to transition from. + * @return A builder for the transition. + */ + public TransitionNeedsTargetStage switchFromAny(State... states) { + if (states.length == 0) { + return new TransitionNeedsTargetStage(List.copyOf(m_states)); + } else { + return new TransitionNeedsTargetStage(List.of(states)); + } + } + + /** + * Sets the initial state for the state machine. This must be called before the state machine is + * scheduled. Failure to do so will result in an {@link IllegalStateException} being thrown when + * the state machine is started. Usage of this method is enforced by the WPILib compiler plugin; + * creating a state machine and neglecting to call this method will result in a compilation error. + * + * @param initialState The new initial state. Cannot be null. + * @see PostConstructionInitializer + */ + @PostConstructionInitializer + public void setInitialState(State initialState) { + requireNonNullParam(initialState, "initialState", "StateMachine.setInitialState"); + if (!this.equals(initialState.m_stateMachine)) { + throw new IllegalArgumentException("Cannot set initial state in a different state machine"); + } + m_initialState = initialState; + } + + @Override + public void run(Coroutine coroutine) { + if (m_initialState == null) { + throw new IllegalStateException( + m_name + " does not have an initial state. Use .setInitialState() to provide one."); + } + + var currentState = m_initialState; + + outer_loop: + while (currentState != null) { + final var currentCommand = currentState.command(); + coroutine.fork(currentCommand); + currentState.runEnterCallbacks(); + boolean didYield = false; + + while (coroutine.scheduler().isRunning(currentCommand)) { + for (var transition : currentState.transitions()) { + if (transition.shouldTransition()) { + // Cancel the current state's command and move to the next state specified by the + // transition. Break the state loop early to avoid an unnecessary yield() call and + // allow the next state's command to start in the same loop iteration that the + // previous state completed. If the next state is null, the state machine will exit + // immediately. + // Note: to prevent infinite loops when states transition to themselves, we require + // the transition signal to be a rising edge on the user-supplied condition to ensure + // that the transition is only triggered once per loop iteration. + currentState.runExitCallbacks(); + coroutine.scheduler().cancel(currentCommand); + currentState = verifyState(transition.nextState()); + continue outer_loop; + } + } + + // Yield after checking all transitions. + // Note: this will be skipped if a transition is triggered. + coroutine.yield(); + didYield = true; + } + + // Move to the next configured state if no transition was hit before the command completed. + // We need to be careful about states with oneshot commands; they will complete immediately + // in the `fork()` call above and never enter the `while` loop and thus never yield. + // Therefore, we inject a yield call at the end here to ensure that the state machine will + // always yield once per state. This has a downside of adding extra loop cycles to states that + // may not need them (and has slightly different behavior to SequentialCommandGroup, which + // runs commands as fast as possible). + currentState.runExitCallbacks(); + currentState = verifyState(currentState.nextState()); + if (!didYield && currentState != null) { + // No need to yield if we're exiting the state machine + coroutine.yield(); + } + } + } + + private State verifyState(State next) { + if (next == null || this.equals(next.m_stateMachine)) { + // OK + return next; + } + + // Bad user setup + throw new IllegalStateException( + "The next state does not belong to this state machine. Check the state for " + + next.command().name()); + } + + /** + * A state in a state machine. Each state has a command that will be run when it is active. States + * can transition to other states when some condition is met when that state is active, or + * automatically transition to another state when it completes if no transition conditions were + * met. A state with no transitions will never transition to another state, and will cause the + * state machine to exit when the state completes; likewise, a state with no incoming transitions + * will never be active. + */ + public static final class State { + /** The state machine that this state belongs to. */ + private final StateMachine m_stateMachine; + + /** The command that will run when this state is active. */ + private final Command m_command; + + /** The possible states to transition to when this state completes. */ + private final List m_completions = new ArrayList<>(); + + /** The state to transition to by default when this state completes. May be null. */ + private Supplier m_defaultNextState = () -> null; + + /** + * The transitions that can be triggered from this state. If multiple transitions are triggered + * at once, the first transition in the list will be used. + */ + private final List m_transitions = new ArrayList<>(); + + private final List m_enterCallbacks = new ArrayList<>(); + private final List m_exitCallbacks = new ArrayList<>(); + + private State(StateMachine stateMachine, Command command) { + m_stateMachine = stateMachine; + m_command = command; + } + + private Command command() { + return m_command; + } + + private List transitions() { + return m_transitions; + } + + private void addTransition(Transition transition) { + m_transitions.add(transition); + } + + /** + * Sets the next state to transition to when this state completes without having fired a + * transition first, or if no conditional completion transition has been met. + * + * @param nextState A supplier for the next state to transition to. Cannot be null, but may + * return null. + */ + private void setNextState(Supplier nextState) { + m_defaultNextState = nextState; + } + + // Custom boolean supplier classes may override .equals to do boolean value comparisons, + // particularly in Kotlin code. Check reference equality instead to just remove bindings to + // the same condition object. + @SuppressWarnings("PMD.CompareObjectsWithEquals") + private void addCompletion(BooleanSupplier condition, Supplier next) { + // Remove any preexisting completion with the same condition + m_completions.removeIf(c -> c.getCondition() == condition); + m_completions.add(new Completion(next, condition)); + } + + private State nextState() { + for (var completion : m_completions) { + if (completion.shouldTransition()) { + return completion.nextState(); + } + } + + // No conditional transition has been met, use the default next state. + // If this was never set or was set to be null, the state machine will exit. + return m_defaultNextState.get(); + } + + private void runEnterCallbacks() { + m_enterCallbacks.forEach(Runnable::run); + } + + private void runExitCallbacks() { + m_exitCallbacks.forEach(Runnable::run); + } + + /** + * Adds a function to be called when this state is entered. Callbacks are invoked immediately + * after the state's command is scheduled, and are run in the same order they were added. + * + *

Note: if a callback schedules any commands, those commands will be scoped to the lifetime + * of the entire state machine, not this state's lifetime. + * + * @param callback The callback to run. Cannot be null. + */ + public void onEnter(Runnable callback) { + requireNonNullParam(callback, "callback", "State.onEnter"); + m_enterCallbacks.add(callback); + } + + /** + * Adds a function to be called when this state is exited. Callbacks are invoked immediately + * before the state's command is canceled, and are run in the order they were added. If the + * command finishes naturally, the callbacks are run immediately after it completes and before + * the next state is entered. + * + * @param callback The callback to run. Cannot be null. + */ + public void onExit(Runnable callback) { + requireNonNullParam(callback, "callback", "State.onExit"); + m_exitCallbacks.add(callback); + } + + /** + * Starts building a transition to the specified state. + * + * @param to The state to transition to. Cannot be null. + * @return A builder for the transition. + */ + public TransitionNeedsConditionStage switchTo(State to) { + requireNonNullParam(to, "to", "State.switchTo"); + if (!m_stateMachine.equals(to.m_stateMachine)) { + throw new IllegalArgumentException( + "Cannot transition to a state in a different state machine"); + } + return new TransitionNeedsTargetStage(List.of(this)).to(to); + } + + /** + * Starts build a transition to some dynamic state. The supplier will be evaluated at the time + * the transition's condition is met. + * + * @param dynamic The dynamic state supplier. Cannot be null. + * @return A builder for the transition. + */ + public TransitionNeedsConditionStage switchTo(Supplier dynamic) { + requireNonNullParam(dynamic, "dynamic", "State.switchTo"); + // Unfortunately, we can't check up front that the supplier will always return a state for + // this state machine. The output will need to be checked when the supplier is called + return new TransitionNeedsTargetStage(List.of(this)).to(dynamic); + } + + /** + * Starts building a transition that will exit the state machine when triggered, rather than + * moving to a different state. + * + * @return A builder for the transition. + */ + public TransitionNeedsConditionStage exitStateMachine() { + return new TransitionNeedsConditionStage(List.of(this), () -> null); + } + } + + /** + * A builder for a transition from one state to another. Use {@link #to(State)} to specify the + * target state to transition to. + */ + @NoDiscard("Use .to() to specify the target state") + public static final class TransitionNeedsTargetStage { + private final List m_from; + + private TransitionNeedsTargetStage(List from) { + m_from = from; + } + + /** + * Specifies the target state to transition to. + * + * @param to The state to transition to. Cannot be null. + * @return A builder to specify the transition condition. + */ + public TransitionNeedsConditionStage to(State to) { + requireNonNullParam(to, "to", "NeedsTargetTransitionBuilder.to"); + for (var state : m_from) { + if (!state.m_stateMachine.equals(to.m_stateMachine)) { + throw new IllegalArgumentException( + "Cannot transition to a state in a different state machine"); + } + } + return new TransitionNeedsConditionStage(m_from, () -> to); + } + + /** + * Specifies a dynamic target state to transition to. The supplier will be evaluated at the time + * the transition condition is met. + * + * @param dynamic A dynamic supplier for next states. Cannot be null. + * @return A builder to specify the transition condition. + */ + public TransitionNeedsConditionStage to(Supplier dynamic) { + requireNonNullParam(dynamic, "dynamic", "NeedsTargetTransitionBuilder.to"); + return new TransitionNeedsConditionStage(m_from, dynamic); + } + + /** + * Specifies the transition will exit the state machine when triggered, rather than moving to a + * different state. + * + * @return A builder to specify the transition condition. + */ + public TransitionNeedsConditionStage toExitStateMachine() { + return new TransitionNeedsConditionStage(m_from, () -> null); + } + } + + /** + * A builder to set conditions for a transition from one state to another. Use {@link + * #when(BooleanSupplier)} to make the transition occur when some external condition becomes true, + * or use {@link #whenComplete()} to make the transition occur when the originating state + * completes without having reached any other transitions first. + */ + @NoDiscard("Use .when() or .whenComplete() to specify the transition condition") + public static final class TransitionNeedsConditionStage { + private final List m_originatingStates; + + // Note: A null result from the supplier indicates that the transition will cause the state + // machine to exit + private final Supplier m_targetStateSupplier; + + private TransitionNeedsConditionStage(List from, Supplier to) { + requireNonNullParam(from, "from", "TransitionNeedsConditionStage"); + requireNonNullParam(to, "to", "TransitionNeedsConditionStage"); + + m_originatingStates = from; + m_targetStateSupplier = to; + } + + /** + * Adds a transition that will be triggered when the specified condition becomes true. + * + *

NOTE: this had no effect if the originating state is a one-shot command without a + * yield. Use {@link #whenComplete()} instead for transitions from one-shot commands. + * + *

If multiple transitions are triggered in the same scheduler loop iteration, the first + * transition will fire and the rest will be ignored. + * + *

{@code
+     * StateMachine stateMachine = new StateMachine("Example State Machine");
+     * State state1 = stateMachine.addState(...);
+     * State state2 = stateMachine.addState(...);
+     * State state3 = stateMachine.addState(...);
+     *
+     * state1.switchTo(state2).when(() -> foo == true);
+     *
+     * // never triggers because the first transition will be evaluated first
+     * state1.switchTo(state3).when(() -> foo == true);
+     * }
+ * + * @param condition The condition that will trigger the transition. Cannot be null. + */ + public void when(BooleanSupplier condition) { + requireNonNullParam(condition, "condition", "NeedsConditionTransitionBuilder.when"); + var transition = new Transition(m_targetStateSupplier, condition); + m_originatingStates.forEach(originatingState -> originatingState.addTransition(transition)); + } + + /** + * Adds a transition to the target state when the originating state completes without having + * triggered any other transitions first. If this is called multiple times for the same + * originating state, later calls will override the previous transitions. Any {@link + * #whenCompleteAnd} transitions will take precedence over {@code whenComplete} transitions if + * their conditions are met when the state exits. + * + *
{@code
+     * StateMachine stateMachine = new StateMachine("Example State Machine");
+     * State state1 = stateMachine.addState(...);
+     * State state2 = stateMachine.addState(...);
+     * State state3 = stateMachine.addState(...);
+     *
+     * state1.switchTo(state2).whenComplete();
+     * state1.switchTo(state3).whenComplete(); // Overrides the previous transition
+     * state1.exitStateMachine().whenCompleteAnd(...); // Takes precedence if the condition is met
+     * }
+ */ + public void whenComplete() { + m_originatingStates.forEach(state -> state.setNextState(m_targetStateSupplier)); + } + + /** + * Similar to {@link #when(BooleanSupplier)}, but only triggers when the originating state + * completes and some other condition is also met. {@code whenCompleteAnd} transitions + * will be evaluated in declaration order and take precedence over any {@link #whenComplete()} + * transitions that have been specified. + * + *
{@code
+     * StateMachine stateMachine = new StateMachine("Example State Machine");
+     * State state1 = stateMachine.addState(...);
+     * State state2 = stateMachine.addState(...);
+     * State state3 = stateMachine.addState(...);
+     *
+     * state1.switchTo(state2).whenComplete();
+     * state1.switchTo(state3).whenComplete(); // Overrides the previous transition
+     * state1.exitStateMachine().whenCompleteAnd(...); // Takes precedence if the condition is met
+     * }
+ * + * @param condition The condition that will trigger the transition. + */ + public void whenCompleteAnd(BooleanSupplier condition) { + requireNonNullParam(condition, "condition", "NeedsConditionTransitionBuilder.whenComplete"); + m_originatingStates.forEach(state -> state.addCompletion(condition, m_targetStateSupplier)); + } + } + + /** + * Similar to {@link Transition}, but does not track the state of the condition. This is intended + * to only be checked once, when the originating state completes. + */ + private static final class Completion { + private final Supplier m_nextSupplier; + private final BooleanSupplier m_condition; + + /** + * Creates a new completion object. + * + * @param next A supplier for the state to transition to when the originating state completes. + * @param condition The condition that will trigger the transition. + */ + private Completion(Supplier next, BooleanSupplier condition) { + m_nextSupplier = next; + m_condition = condition; + } + + private boolean shouldTransition() { + return m_condition.getAsBoolean(); + } + + public State nextState() { + return m_nextSupplier.get(); + } + + public BooleanSupplier getCondition() { + return m_condition; + } + } + + /** + * Similar to {@link Completion}, but tracks the state of the condition to avoid infinite loops. + * This is intended to be checked every loop while the originating state is active. + */ + private static final class Transition { + /** The state to transition to. */ + private final Supplier m_nextSupplier; + + /** The condition that will trigger the transition. */ + private final BooleanSupplier m_condition; + + private boolean m_previousSignal = false; + + private Transition(Supplier next, BooleanSupplier condition) { + m_nextSupplier = next; + m_condition = condition; + } + + /** Checks if the transition should be triggered. */ + private boolean shouldTransition() { + // Wrap the condition in a rising edge detector so that it will only trigger a single time per + // loop iteration. This prevents issues with a state transitioning to itself like so: + // state1.switchTo(state1).when(() -> foo == true); + // If the condition is itself a rising edge detector, this wrapping is redundant but harmless. + boolean currentValue = m_condition.getAsBoolean(); + boolean isRisingEdge = currentValue && !m_previousSignal; + m_previousSignal = currentValue; + return isRisingEdge; + } + + private State nextState() { + return m_nextSupplier.get(); + } + } +} diff --git a/commandsv3/src/test/java/org/wpilib/command3/CoroutineTest.java b/commandsv3/src/test/java/org/wpilib/command3/CoroutineTest.java index 69f5718dbb..546a0f2705 100644 --- a/commandsv3/src/test/java/org/wpilib/command3/CoroutineTest.java +++ b/commandsv3/src/test/java/org/wpilib/command3/CoroutineTest.java @@ -8,7 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; -import java.util.Set; +import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -157,6 +157,6 @@ class CoroutineTest extends CommandTestBase { assertTrue(ranAfterAwait.get()); // But only the outer command should still be running; secondInner should have been canceled - assertEquals(Set.of(outer), m_scheduler.getRunningCommands()); + assertEquals(List.of(outer), m_scheduler.getRunningCommands()); } } diff --git a/commandsv3/src/test/java/org/wpilib/command3/StateMachineTest.java b/commandsv3/src/test/java/org/wpilib/command3/StateMachineTest.java new file mode 100644 index 0000000000..c35375c545 --- /dev/null +++ b/commandsv3/src/test/java/org/wpilib/command3/StateMachineTest.java @@ -0,0 +1,738 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.command3; + +import static org.junit.jupiter.api.Assertions.assertAll; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.wpilib.command3.SchedulerEvent.Canceled; +import static org.wpilib.command3.SchedulerEvent.Mounted; +import static org.wpilib.command3.SchedulerEvent.Scheduled; +import static org.wpilib.command3.SchedulerEvent.Yielded; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; +import org.wpilib.annotation.PostConstructionInitializer; + +@SuppressWarnings("PMD.CompareObjectsWithEquals") +class StateMachineTest extends CommandTestBase { + @Test + @SuppressWarnings(PostConstructionInitializer.SUPPRESSION_KEY) + void errorsWithoutInitialState() { + Mechanism mech = new Mechanism("Mechanism", m_scheduler); + Command command1 = mech.run(Coroutine::park).named("Command 1"); + Command command2 = mech.run(Coroutine::park).named("Command 2"); + + StateMachine stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + // stateMachine.setInitialState(state1); // Oops, someone forgot to set the initial state! + state1.switchTo(state2).whenComplete(); + + m_scheduler.schedule(stateMachine); + + // Don't worry, it'll be caught at runtime. + // It would actually be caught at compile time, but we disabled the compiler check for this test + var exception = assertThrows(IllegalStateException.class, () -> m_scheduler.run()); + assertEquals( + "State Machine does not have an initial state. Use .setInitialState() to provide one.", + exception.getMessage()); + assertFalse(m_scheduler.isRunning(stateMachine), "State machine should not be running"); + } + + @Test + void initialStateCanBeOverridden() { + Mechanism mech = new Mechanism("Mechanism", m_scheduler); + Command command1 = mech.run(Coroutine::park).named("Command 1"); + Command command2 = mech.run(Coroutine::park).named("Command 2"); + + StateMachine stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + stateMachine.setInitialState(state2); + state2.switchTo(state1).whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(command2), "Command 2 should be running as the initial state"); + assertFalse(m_scheduler.isRunning(command1), "Command 1 should not be running"); + } + + @Test + void transitions() { + AtomicBoolean signalA = new AtomicBoolean(false); + AtomicBoolean signalB = new AtomicBoolean(false); + + Mechanism mech = new Mechanism("Mechanism", m_scheduler); + var command1 = mech.run(Coroutine::park).named("Command 1"); + var command2 = mech.run(Coroutine::park).named("Command 2"); + var command3 = mech.run(Coroutine::park).named("Command 3"); + + StateMachine stateMachine = new StateMachine("State Machine"); + + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + + state1.switchTo(state2).when(signalA::get); + state2.switchTo(state3).when(signalB::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertAll( + () -> assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"), + () -> assertTrue(m_scheduler.isRunning(command1), "Command 1 should be running"), + () -> assertFalse(m_scheduler.isRunning(command2), "Command 2 should not be running"), + () -> assertFalse(m_scheduler.isRunning(command3), "Command 3 should not be running")); + + signalA.set(true); + m_scheduler.run(); + assertAll( + () -> assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"), + () -> assertFalse(m_scheduler.isRunning(command1), "Command 1 should not be running"), + () -> assertTrue(m_scheduler.isRunning(command2), "Command 2 should be running"), + () -> assertFalse(m_scheduler.isRunning(command3), "Command 3 should not be running")); + + signalB.set(true); + m_scheduler.run(); + assertAll( + () -> assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"), + () -> assertFalse(m_scheduler.isRunning(command1), "Command 1 should not be running"), + () -> assertFalse(m_scheduler.isRunning(command2), "Command 2 should not be running"), + () -> assertTrue(m_scheduler.isRunning(command3), "Command 3 should be running")); + } + + @Test + void transitionsIfConditionIsAlreadyTrueWhenEntered() { + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + var signal = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + state1.switchTo(state2).when(signal::get); + + signal.set(true); + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertFalse(m_scheduler.isRunning(command1), "Command 1 should not be running"); + assertTrue(m_scheduler.isRunning(command2), "State 1 should have transitioned to State 2"); + } + + @Test + void commandExits() { + AtomicBoolean signal = new AtomicBoolean(false); + + var command1 = + Command.noRequirements().executing(co -> co.waitUntil(signal::get)).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + + stateMachine.setInitialState(state1); + + state1.switchTo(state2).whenComplete(); + state2.exitStateMachine().whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(command1), "Command 1 should be running"); + + signal.set(true); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertFalse(m_scheduler.isRunning(command1), "Command 1 should have ended"); + assertTrue(m_scheduler.isRunning(command2), "Command 2 should have started"); + } + + @Test + void stateTransitionsToSelf() { + AtomicBoolean signal = new AtomicBoolean(false); + AtomicInteger initCount = new AtomicInteger(0); + + var command = + Command.noRequirements() + .executing( + co -> { + initCount.incrementAndGet(); + co.park(); + }) + .named("Command"); + var stateMachine = new StateMachine("State Machine"); + var state = stateMachine.addState(command); + stateMachine.setInitialState(state); + state.switchTo(state).when(signal::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertEquals(1, initCount.get(), "Command should be initialized once"); + + signal.set(true); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should still be running"); + assertEquals(2, initCount.get(), "Command should have reinitialized"); + + assertEquals(14, m_events.size()); + assertAll( + // First run + () -> assertTrue(m_events.get(0) instanceof Scheduled s && s.command() == stateMachine), + () -> assertTrue(m_events.get(1) instanceof Mounted m && m.command() == stateMachine), + () -> assertTrue(m_events.get(2) instanceof Scheduled s && s.command() == command), + () -> assertTrue(m_events.get(3) instanceof Mounted m && m.command() == command), + () -> assertTrue(m_events.get(4) instanceof Yielded y && y.command() == command), + () -> assertTrue(m_events.get(5) instanceof Yielded y && y.command() == stateMachine), + () -> assertTrue(m_events.get(6) instanceof Mounted m && m.command() == command), + () -> assertTrue(m_events.get(7) instanceof Yielded y && y.command() == command), + // Second run + () -> assertTrue(m_events.get(8) instanceof Mounted m && m.command() == stateMachine), + () -> assertTrue(m_events.get(9) instanceof Canceled c && c.command() == command), + () -> assertTrue(m_events.get(10) instanceof Scheduled s && s.command() == command), + () -> assertTrue(m_events.get(11) instanceof Mounted m && m.command() == command), + () -> assertTrue(m_events.get(12) instanceof Yielded y && y.command() == command), + () -> assertTrue(m_events.get(13) instanceof Yielded y && y.command() == stateMachine)); + } + + @Test + void oneshotCommandTransitionsToSelfOnComplete() { + AtomicInteger count = new AtomicInteger(0); + var command = Command.noRequirements().executing(c -> count.incrementAndGet()).named("Command"); + var stateMachine = new StateMachine("State Machine"); + var state = stateMachine.addState(command); + stateMachine.setInitialState(state); + state.switchTo(state).whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertEquals(1, count.get(), "Command should have run once"); + } + + @Test + void onlyFirstExplicitTransitionFires() { + var signal = new AtomicBoolean(false); + + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::park).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + state1.switchTo(state2).when(signal::get); + state1.switchTo(state3).when(signal::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + signal.set(true); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command2), m_scheduler.getRunningCommands()); + } + + @Test + void onlyLastWhenCompleteTransitionFires() { + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::yield).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::yield).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + state1.switchTo(state2).whenComplete(); + state1.switchTo(state3).whenComplete(); // overrides the previous transition + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + m_scheduler.run(); + assertEquals(List.of(stateMachine, command3), m_scheduler.getRunningCommands()); + } + + @Test + void whenCompleteAndTakesPriorityOverWhenCompleteIfCalledLast() { + var signal = new AtomicBoolean(false); + + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::yield).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::yield).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + state1.switchTo(state2).whenComplete(); + state1.switchTo(state3).whenCompleteAnd(signal::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + signal.set(true); + m_scheduler.run(); + assertEquals( + List.of(stateMachine, command3), // would be command2 if `whenComplete` took precedence + m_scheduler.getRunningCommands()); + } + + @Test + void whenCompleteAndTakesPriorityOverWhenCompleteIfCalleFirst() { + var signal = new AtomicBoolean(false); + + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::yield).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::yield).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + state1.switchTo(state3).whenCompleteAnd(signal::get); + state1.switchTo(state2).whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + signal.set(true); + m_scheduler.run(); + assertEquals( + List.of(stateMachine, command3), // would be command3 if `whenCompleteAnd` took precedence + m_scheduler.getRunningCommands()); + } + + @Test + void composingComplete() { + AtomicBoolean signal = new AtomicBoolean(false); + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + + stateMachine.setInitialState(state1); + state1.exitStateMachine().whenComplete(); + state1.switchTo(state2).whenCompleteAnd(signal::get); + + // First run, signal is low - state machine exits on state completion + { + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(command1), "Command should be running"); + + m_scheduler.run(); + assertFalse(m_scheduler.isRunning(stateMachine), "State machine should have exited"); + } + + // Second run, signal goes high - state machine switches to state2 instead of exiting + { + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(command1), "Command should be running"); + + signal.set(true); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertFalse(m_scheduler.isRunning(command1), "Command should have ended"); + assertTrue(m_scheduler.isRunning(command2), "Command 2 should have started"); + } + } + + @Test + void switchFromAny() { + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::park).named("Command 3"); + + AtomicBoolean signal = new AtomicBoolean(false); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + + stateMachine.setInitialState(state1); + stateMachine.switchFromAny(state1, state2).to(state3).when(signal::get); + state1.switchTo(state2).whenComplete(); + + // transition from 1 -> 3 + { + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(command1), "Command 1 should be running"); + + signal.set(true); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertFalse(m_scheduler.isRunning(command1), "Command 1 should have ended"); + assertTrue(m_scheduler.isRunning(command3), "Command 3 should have started"); + } + + m_scheduler.cancel(stateMachine); + signal.set(false); + + // transition from 2 -> 3 + { + m_scheduler.schedule(stateMachine); + m_scheduler.run(); // yield 1 + assertEquals( + List.of("State Machine", "Command 1"), + m_scheduler.getRunningCommands().stream().map(Command::name).toList()); + + m_scheduler.run(); // transition 1 -> 2 + assertEquals( + List.of("State Machine", "Command 2"), + m_scheduler.getRunningCommands().stream().map(Command::name).toList()); + + signal.set(true); + m_scheduler.run(); // transition 2 -> 3 + assertEquals( + List.of("State Machine", "Command 3"), + m_scheduler.getRunningCommands().stream().map(Command::name).toList()); + } + } + + @Test + void switchToSupplierWhenComplete() { + AtomicInteger count = new AtomicInteger(0); + + var command1 = + Command.noRequirements() + .executing( + co -> { + count.incrementAndGet(); + co.yield(); + }) + .named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::park).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + stateMachine.setInitialState(state1); + state1 + .switchTo( + () -> { + if (count.get() == 1) { + return state2; + } else { + return state3; + } + }) + .whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); // command 1 increments the count and then yields + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + // command 1 completes, state machine moves to the next state + // if the supplier is checked at configuration time, the count would be 0 and return state3 + // if the supplier is checked at runtime, the count would be 1 and return state2 + m_scheduler.run(); + assertEquals(List.of(stateMachine, command2), m_scheduler.getRunningCommands()); + } + + @Test + void switchToSupplierWithCondition() { + AtomicInteger count = new AtomicInteger(0); + var command1 = + Command.noRequirements() + .executing( + co -> { + while (true) { + // Increment after yielding. Otherwise, the condition is checked and the state + // machine immediately switches to the next state all within the first cycle; + // the running command1 is never observed. + co.yield(); + count.incrementAndGet(); + } + }) + .named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + var command3 = Command.noRequirements().executing(Coroutine::park).named("Command 3"); + + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + var state3 = stateMachine.addState(command3); + stateMachine.setInitialState(state1); + state1 + .switchTo( + () -> { + if (count.get() == 1) { + return state2; + } else { + return state3; + } + }) + .when(() -> count.get() == 1); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(List.of(stateMachine, command1), m_scheduler.getRunningCommands()); + + m_scheduler.run(); + assertEquals(List.of(stateMachine, command2), m_scheduler.getRunningCommands()); + } + + @Test + void runsOnEnterForInitialState() { + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + AtomicInteger enterCount = new AtomicInteger(0); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + state1.onEnter(enterCount::incrementAndGet); + state1.switchTo(state2).whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(1, enterCount.get(), "onEnter should have been called once"); + } + + @Test + void runsOnExitOnTransition() { + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + AtomicInteger exitCount = new AtomicInteger(0); + AtomicBoolean signal = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + state1.onExit(exitCount::incrementAndGet); + state1.switchTo(state2).when(signal::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(0, exitCount.get(), "onExit should not have been called"); + + signal.set(true); + m_scheduler.run(); + assertEquals(1, exitCount.get(), "onExit should have been called"); + } + + @Test + void runsOnExitWhenComplete() { + var command1 = Command.noRequirements().executing(co -> {}).named("Command 1"); + + AtomicInteger exitCount = new AtomicInteger(0); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + stateMachine.setInitialState(state1); + state1.onExit(exitCount::incrementAndGet); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals(1, exitCount.get(), "onExit should have been called"); + assertFalse(m_scheduler.isRunning(command1), "State should have exited"); + } + + @Test + void onExitCanSchedule() { + var mech = new Mechanism("Mechanism", m_scheduler); + var mainMechCommand = mech.run(Coroutine::park).named("Main Mech Command"); + var backgroundMechCommand = mech.run(Coroutine::park).named("Background Mech Command"); + var nextStateCommand = Command.noRequirements().executing(Coroutine::park).named("Next"); + + AtomicBoolean signal = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(mainMechCommand); + var state2 = stateMachine.addState(nextStateCommand); + stateMachine.setInitialState(state1); + state1.switchTo(state2).when(signal::get); + state1.onExit(() -> m_scheduler.schedule(backgroundMechCommand)); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertTrue(m_scheduler.isRunning(mainMechCommand), "Main Mechanism should be running"); + + signal.set(true); + m_scheduler.run(); + assertTrue(m_scheduler.isRunning(stateMachine), "State machine should be running"); + assertFalse(m_scheduler.isRunning(mainMechCommand), "Main Mechanism should have ended"); + assertTrue( + m_scheduler.isRunning(backgroundMechCommand), "Background Mechanism should have started"); + assertTrue(m_scheduler.isRunning(nextStateCommand), "Next State should have started"); + } + + @Test + void runsOnEnterCallbacksInInsertionOrder() { + var command1 = Command.noRequirements().executing(co -> {}).named("Command 1"); + + List callbackInfo = new ArrayList<>(); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + stateMachine.setInitialState(state1); + state1.onEnter(() -> callbackInfo.add("onEnter 1")); + state1.onEnter(() -> callbackInfo.add("onEnter 2")); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals( + List.of("onEnter 1", "onEnter 2"), callbackInfo, "onEnter callbacks did not run correctly"); + } + + @Test + void runsOnExitCallbacksInInsertionOrder() { + // Make the command immediately exit + var command1 = Command.noRequirements().executing(co -> {}).named("Command 1"); + + List callbackInfo = new ArrayList<>(); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + stateMachine.setInitialState(state1); + state1.onExit(() -> callbackInfo.add("onExit 1")); + state1.onExit(() -> callbackInfo.add("onExit 2")); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertEquals( + List.of("onExit 1", "onExit 2"), callbackInfo, "onExit callbacks did not run correctly"); + } + + @Test + void onEnterSeesNewCommand() { + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + + AtomicBoolean sawCommand1OnEnter = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + stateMachine.setInitialState(state1); + state1.onEnter(() -> sawCommand1OnEnter.set(m_scheduler.isRunning(command1))); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + assertTrue(sawCommand1OnEnter.get(), "onEnter should have seen the command running"); + } + + @Test + void onExitWithTransitionSeesExitedCommand() { + var command1 = Command.noRequirements().executing(Coroutine::park).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + AtomicBoolean sawCommand1OnExit = new AtomicBoolean(false); + AtomicBoolean signal = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + state1.onExit(() -> sawCommand1OnExit.set(m_scheduler.isRunning(command1))); + state1.switchTo(state2).when(signal::get); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); + + signal.set(true); + m_scheduler.run(); + assertTrue(sawCommand1OnExit.get(), "onExit should have seen the exiting command"); + } + + // Because completion is defined as the command finishing on its own, callbacks will never + // be able to see the command running in the scheduler because they're invoked _after_ the + // command has finished. + @Test + void onExitWithCompleteCannotSeeExitedCommand() { + var command1 = Command.noRequirements().executing(Coroutine::yield).named("Command 1"); + var command2 = Command.noRequirements().executing(Coroutine::park).named("Command 2"); + + AtomicBoolean onExitCalled = new AtomicBoolean(false); + AtomicBoolean sawCommand1OnExit = new AtomicBoolean(false); + var stateMachine = new StateMachine("State Machine"); + var state1 = stateMachine.addState(command1); + var state2 = stateMachine.addState(command2); + stateMachine.setInitialState(state1); + state1.onExit( + () -> { + onExitCalled.set(true); + sawCommand1OnExit.set(m_scheduler.isRunning(command1)); + }); + state1.switchTo(state2).whenComplete(); + + m_scheduler.schedule(stateMachine); + m_scheduler.run(); // command yields... + assertFalse(onExitCalled.get(), "onExit should not have been called yet"); + + m_scheduler.run(); // ...then exits here + assertTrue(onExitCalled.get(), "onExit should have been called"); + assertFalse(sawCommand1OnExit.get(), "exiting command should be invisible"); + } + + @Test + void ledStateMachine() { + var leds = + new Mechanism("LEDs", m_scheduler) { + Command idleAnimation() { + return run(Coroutine::park).withPriority(-1).named("Default Animation"); + } + + Command infoAnimation() { + return run(Coroutine::yield).withPriority(0).named("Info"); + } + + Command warningAnimation() { + return run(Coroutine::yield).withPriority(1).named("Warning"); + } + }; + + Trigger normalPriorityEvent = new Trigger(() -> true); + Trigger highPriorityEvent = new Trigger(() -> true); + + StateMachine stateMachine = new StateMachine("State Machine"); + + var idleState = stateMachine.addState(leds.idleAnimation()); + var infoState = stateMachine.addState(leds.infoAnimation()); + var warningState = stateMachine.addState(leds.warningAnimation()); + + stateMachine.setInitialState(idleState); + + idleState.switchTo(infoState).when(normalPriorityEvent.and(highPriorityEvent.negate())); + idleState.switchTo(warningState).when(highPriorityEvent); + + warningState.switchTo(infoState).whenCompleteAnd(normalPriorityEvent); + infoState.switchTo(warningState).whenCompleteAnd(highPriorityEvent); + + stateMachine.switchFromAny().to(warningState).when(highPriorityEvent); + stateMachine.switchFromAny().to(idleState).whenComplete(); + } +} diff --git a/design-docs/commands-v3-state-machines.md b/design-docs/commands-v3-state-machines.md new file mode 100644 index 0000000000..df2c9eea74 --- /dev/null +++ b/design-docs/commands-v3-state-machines.md @@ -0,0 +1,273 @@ +# State Machines in WPILib Commands Version 3 + +- See [Commands v3](commands-v3.md) for details on the commands framework + +## Problem Statement {#problem-statement} + +Coroutines are a powerful way to express low- to high-complexity behaviors. However, they become unwieldy at +representing highly complex behaviors where phases may be repeated or skipped to at any point in the sequence. State +machines excel at this by providing ways to transition from any arbitrary state to any other arbitrary state, flattening +the declarative structure of a coroutine into a linear sequence of states and transitions. + +Example: consider a FRC game like 2022 Rapid React or 2017 Steamworks. The robot has a drivetrain, a hopper to store +balls, a turret to aim at a goal, and a flywheel shooter to launch balls at the goal. We want an autonomous mode to +drive to a known position on the field for optimal scoring, then aim at the goal, fire balls until the hopper is empty, +and finally play an LED animation to indicate the end of the autonomous sequence. If the robot is moved away from the +scoring location, the scoring portion of the sequence should stop and the robot should move back into position, and then +resume the scoring sequence. + +```java +public Command autoWithStateMachine() { + // Declare the state machine + StateMachine stateMachine = new StateMachine("Auto With State Machine"); + + // Define states + State getInPosition = stateMachine.addState(drivetrain.driveToScoringLocation()); + State aiming = stateMachine.addState(turret.aimAtGoal()); + State scoring = stateMachine.addState(shooter.fireOnce()); + State celebrating = stateMachine.addState(leds.celebrate()); + + // Set the initial state. Neglecting this will cause a runtime exception when the state machine starts. + stateMachine.setInitialState(getInPosition); + + // Switch to aiming when we reach the scoring location. + getInPosition.switchTo(aiming).whenComplete(); + // Set the swerve wheels in an X shape after reaching the scoring location to resist being pushed away. + getInPosition.onExit(() -> Scheduler.getDefault().fork(drivetrain.setX())); + + // Then start scoring once the turret is aimed at the goal. + aiming.switchTo(scoring).when(turret::aimedAtGoal); + + // Loop the scoring state as long as the hopper has a ball. + scoring.switchTo(scoring).whenCompleteAnd(() -> hopper.hasBall()); + + // Automatically interrupt any part of the aiming or scoring sequence if + // the robot is moved away from the scoring location and move back into position. + stateMachine.switchFromAny(aiming, scoring).to(getInPosition).when(atScoringLocation.negate()); + + // Start celebrating once the final ball has been scored. + scoring.switchTo(celebrating).whenCompleteAnd(() -> !hopper.hasBall()); + + return stateMachine; +} +``` + +```java +Command autoWithCoroutines() { + return Command.noRequirements().executing(coroutine -> { + // Automatically score while the robot is in scoring position. + // This will be canceled if the robot is bumped away from the scoring location. + atScoringLocation.whileTrue( + turret.aimAtGoal() + .andThen(shooter.fireOnce().repeatWhile(hopper::hasBall)) + .andThen(leds.celebrate()) + .withAutomaticName() + ); + // Move back into scoring position if the robot is bumped away from the scoring location. + atScoringLocation.onFalse(drivetrain.driveToScoringLocation()); + + coroutine.await(drivetrain.driveToScoringLocation()); + + // Park to allow the triggered commands to run in the background. + // We assume the command will be canceled at the end of the autonomous period. + coroutine.park(); + }).named("Auto With Coroutines"); +} +``` + +## Implementation Details {#implementation-details} + +### Overview + +The public entry point is `org.wpilib.commands3.StateMachine` which implements `Command`. + +Each state machine is named; `name()` returns the provided name for telemetry and debugging. State machine names are +specified in the constructor; there is no dedicated builder like `NeedsNameBuilderStage.named(...)` for regular +commands. + +`requirements()` returns an empty set; the machine itself does not own any Mechanism. The commands that back states may +have requirements, which will be inherited by the state machine while those states are active, just like a normal +command with nested children. + +### Constructing a state machine + +State machines are created with `new StateMachine(String)`. The name cannot be null. The `StateMachine` class is final +and cannot be subclassed; v1-style group creation that does setup in a subclass' constructor is not supported: + +```java +// Not allowed +class CustomStateMachine extends StateMachine { + // ... +} +``` + +States are defined with `addState(Command)`. Users need to manually wire states together using transition builders after +defining the states. A `State` object wraps the underlying command and is responsible for tracking the possible +transitions out of that state. + +State machines have no initial state, which must be set explicitly: + +```java +StateMachine stateMachine = new StateMachine("Example"); +State initialState = stateMachine.addState(...); +stateMachine.setInitialState(initialState); +``` + +`setInitialState` throws an `NullPointerException` if given a null input. It may be called multiple times to override +the initial state before running. + +`setInitialState` and all transitions require that both states belong to the same `StateMachine` +object; otherwise an `IllegalArgumentException` is thrown. + +### State Machine Loop + +A state machine is a regular command that manages the state lifecycle in its `run()` method. The lifecycle is managed by +a loop, where in each iteration the current state's command is scheduled, and then enters an inner loop that continues +to yield as long as the command is running, similar to `Coroutine.waitUntil`. However, the inner loop also checks for +state transitions before calling `yield()`; if a transition is determined to be active, the command is canceled and the +state moves to the transition's target state. The main loop is then restarted with the new state. + +If a state's command finishes without triggering any transitions, the state machine checks for state completions. If +a completion is found, the state machine immediately moves to the completion's target state, and (to prevent a potential +infinite loop) conditionally inserts a `yield()` at the very end of the loop in case the command was a one-shot. + +Pseudocode: + +``` +currentState = initialState + +state_loop: +while currentState is not null: + currentState.onEnter() + fork currentState.command + didYield = false + + while currentState.command is running + for each transition in currentState.transitions + if transition.check() + currentState.onExit() + currentState = transition.targetState + restart state_loop + + didYield = true + yield + + currentState.onExit() + currentState = currentState.completions.find(completion -> completion.active())?.targetState + + if didYield is false and currentState is not null + yield +``` + +Note that state completions are different from transitions: transitions are only active on rising edge, while state +completions are active on every loop iteration (but are only checked once, when the state exits). + +### Transitions + +There are two kinds of transitions: conditional transitions (checked while the originating state's +command is running) and completion transitions (taken after the originating state's command finishes +on its own, if no conditional transition was taken). + +Transitions are configured using a staged builder setup similar to command builders. The initial builder stage starts +with one or more originating states, then moves to a stage for specifying a target state (which may be null, indicating +that the state machine should exit), and finally to a stage for specifying the condition that triggers the transition. + +Transitions start from one or more non-null originating states, and end with a single target state (which may be null, +indicating that the state machine should exit), and a condition that triggers the transition. Transitions are stored on +the originating states, rather than being stored on the state machine (this simplifies the implementation). Each +originating state gets its own copy of the transition. + +Transitions can be defined starting from the state machine itself with `StateMachine.switchFromAny(...)`, or from a +specific starting state with `State.switchTo(...)`: + +``` +stateMachine.switchFromAny(state1, state2).to(state3).when(...) + +// Identical to: +state1.switchTo(state3).when(...) +state2.switchTo(state3).when(...) +``` + +Builders have private constructors, so they cannot be instantiated directly. Users must use the fluent builder factories +with `StateMachine.switchFromAny(...)` or `State.switchTo(...)`. + +Transitions must be resilient to commands that transition to themselves like `state.switchTo(state).when(...)`. If the +condition is naively checked, the state machine will enter an infinite loop before it would naturally yield (check +transition -> cancel command -> enter new state -> check transition -> ...). To avoid +this, the condition is checked with rising-edge logic so that the transition is only triggered once per internal loop +iteration: + +```java +class Transition { + State targetState; + BooleanSupplier signal; + boolean previousSignal; + + boolean shouldTransition() { + boolean currentSignal = signal.getAsBoolean(); + boolean signalBecameTrue = currentSignal && !previousSignal; + previousSignal = currentSignal; + return signalBecameTrue; + } +} +``` + +#### Exiting the state machine explicitly + +`State.switchTo(State)` is to be used for state-to-state transitions and cannot accept `null` as an input. To support +exiting the state machine explicitly, there are two additional methods available on the builders: + +`State.exitStateMachine()` builds a transition that bypasses the null check and makes the user's intent clear, versus +`State.switchTo(null)` which may be confusing. + +`TransitionNeedsTargetStage.exitStateMachine()` is a convenience method that returns a transition to `null`. + +``` +state.switchTo(null).when(...) // NullPointerException +state.exitStateMachine().when(...) // OK +stateMachine.switchFromAny(state1, state2).toExitStateMachine().when(...) +``` + +### Callbacks on state entry/exit + +States maintain a list of `Runnable` callbacks that are run when the state is entered, and a separate list of callbacks +that are run when the state is exited. Callbacks are run in the order they were added. + +Entry callbacks are run immediately after the state's command is forked, so it can see the current command. However, +one-shot commands will complete in the `fork` call, so entry callbacks will not see it. + +Exit callbacks are run immediately before canceling the state's command (if the exit was caused by a transition). For +states that complete without a transition, the exit callbacks are run immediately after the state's command finishes and +before the next state is selected. + +### Runtime semantics + +- When a state becomes active: + 1) The state's command is scheduled via `coroutine.fork(state.command)`. + 2) `onEnter` callbacks run. + 3) While the command is running each scheduler iteration: + - All conditional transitions from this state are checked in insertion order; each evaluates + `shouldTransition()`. The first to trigger wins. + - If a transition triggers: `onExit` callbacks run, the command is canceled, and the next state is set. The + machine immediately begins the next loop iteration with the new state without an extra yield; the next state's + command can start in the same scheduler run. If the next state is null, the machine exits. + - If no transition triggers: the coroutine yields once for this iteration. +- If the command stops running without any conditional transition firing: + - `onExit` callbacks run. + - The next state is selected from completion transitions in insertion order. If none match, the machine exits (next + state is null). + - To ensure fairness and prevent tight looping with one-shot commands, the machine guarantees at least one yield per + state. If the state command finished without ever yielding (one-shot), the machine yields once before starting the + next state's command (unless exiting). + +### Edge cases and guarantees + +- Self-transition is supported; the rising-edge guard ensures only a single re-entry per loop when the condition rises. + The exiting command is canceled and then immediately re-scheduled. +- One-shot commands should use completion transitions to continue the flow; conditional transitions cannot trigger for + them because the commands exit before conditional transitions can be checked. +- If multiple transitions are configured with the same condition on the same state, only the first will ever trigger in + a given loop iteration. +- Transitions cannot target states in a different state machine; an exception is thrown if attempted. +- The initial state must be set explicitly; otherwise the machine throws on first run and will not + remain scheduled. diff --git a/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerListener.java b/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerListener.java new file mode 100644 index 0000000000..233e92c25b --- /dev/null +++ b/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerListener.java @@ -0,0 +1,316 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.javacplugin; + +import com.sun.source.tree.AssignmentTree; +import com.sun.source.tree.BlockTree; +import com.sun.source.tree.CompilationUnitTree; +import com.sun.source.tree.ExpressionTree; +import com.sun.source.tree.MemberSelectTree; +import com.sun.source.tree.MethodInvocationTree; +import com.sun.source.tree.NewClassTree; +import com.sun.source.util.JavacTask; +import com.sun.source.util.TaskEvent; +import com.sun.source.util.TaskListener; +import com.sun.source.util.TreePath; +import com.sun.source.util.TreeScanner; +import com.sun.source.util.Trees; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.SequencedSet; +import java.util.Set; +import java.util.stream.Collectors; +import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.lang.model.type.DeclaredType; +import javax.lang.model.type.TypeKind; +import javax.tools.Diagnostic; +import org.wpilib.annotation.PostConstructionInitializer; + +/** + * Ensures methods tagged with {@link PostConstructionInitializer} are called after the owning + * object is constructed. + */ +public class PostConstructionInitializerListener implements TaskListener { + private final JavacTask m_task; + private final Set m_visitedCUs = new HashSet<>(); + + public PostConstructionInitializerListener(JavacTask task) { + m_task = task; + } + + @Override + public void finished(TaskEvent e) { + // We override `finished` instead of `started` because we want to run after the + // ANALYZE attribution phase has completed and assigned types to elements in the AST + // Track the visited CUs to avoid re-processing the same CU multiple times when we call + // `Trees.getElement()` on a tree path. + var compilationUnit = e.getCompilationUnit(); + if (e.getKind() == TaskEvent.Kind.ANALYZE && m_visitedCUs.add(compilationUnit)) { + var state = new State(); + compilationUnit.accept(new Scanner(compilationUnit), state); + + if (state.m_initializedObjects.isEmpty()) { + // Good! No partially initialized objects were detected. + return; + } + + var trees = Trees.instance(m_task); + + for (InitializedObject partiallyInitializedObject : state.m_initializedObjects.values()) { + var object = partiallyInitializedObject.object(); + var uncalledInitializers = partiallyInitializedObject.initializers(); + trees.printMessage( + Diagnostic.Kind.ERROR, + "Partially-initialized object `%s` is missing %s %s" + .formatted( + object.getSimpleName(), + uncalledInitializers.size() == 1 + ? "a call to initializer method" + : "calls to " + uncalledInitializers.size() + " initializer methods:", + uncalledInitializers.stream() + .map(i -> "`" + i.initializer().getSimpleName() + "()`") + .collect(Collectors.joining(", "))), + trees.getTree(object), + compilationUnit); + } + } + } + + /** + * Get all methods declared by the object's type, any supertypes, and any interfaces, filtering + * only those annotated with {@link PostConstructionInitializer}. + */ + private Set getRequiredInitializers(VariableElement object) { + var type = object.asType(); + + if (type.getKind() != TypeKind.DECLARED) { + return Set.of(); + } + + TypeElement typeElement = (TypeElement) ((DeclaredType) type).asElement(); + if (typeElement == null) { + return Set.of(); + } + + // Use a LinkedHashSet to maintain stable iteration order and deduplicate methods + SequencedSet methods = new LinkedHashSet<>(); + + // Elements#getAllMembers returns all members including inherited ones (classes + interfaces) + for (Element member : m_task.getElements().getAllMembers(typeElement)) { + if (member instanceof ExecutableElement method) { + if (method.getAnnotation(PostConstructionInitializer.class) == null) { + continue; + } + + methods.add(method); + } + } + + if (methods.isEmpty()) { + return Set.of(); + } + + Set result = new LinkedHashSet<>(); + for (ExecutableElement m : methods) { + result.add(new RequiredInitializer(m)); + } + + return result; + } + + private final class State { + private final Map m_initializedObjects = new HashMap<>(); + + void addMaybeInitializedObject(VariableElement object) { + var requiredInitializers = getRequiredInitializers(object); + if (requiredInitializers.isEmpty()) { + return; + } + m_initializedObjects.put(object, new InitializedObject(object, requiredInitializers)); + } + + boolean isTracking(VariableElement e) { + return m_initializedObjects.containsKey(e); + } + + void removeFullyInitializedObjects() { + m_initializedObjects + .values() + .removeIf(initializedObject -> initializedObject.initializers().isEmpty()); + } + + void removeInitializer(VariableElement object, ExecutableElement initializer) { + if (!m_initializedObjects.containsKey(object)) { + return; + } + + m_initializedObjects.get(object).initializers().removeIf(i -> i.is(initializer)); + } + + void merge(State otherState) { + if (otherState == null) { + return; + } + + otherState.m_initializedObjects.forEach( + (object, initializedObject) -> { + m_initializedObjects.putIfAbsent(object, initializedObject); + m_initializedObjects + .get(object) + .initializers() + .addAll(initializedObject.initializers()); + }); + } + } + + /** + * Tracks what initializer methods still need to be called for a given object. Elements are + * removed from the {@link #initializers} set as they are found; once the set is empty, the object + * is considered fully initialized and is removed from the tracking set. + * + * @param object The object to track + * @param initializers The set of initializer methods that still need to be called on the object. + * This is mutable! + */ + private record InitializedObject(VariableElement object, Set initializers) {} + + private record RequiredInitializer(ExecutableElement initializer) { + boolean is(ExecutableElement check) { + return initializer.equals(check); + } + } + + private final class Scanner extends TreeScanner { + private final CompilationUnitTree m_root; + private final Trees m_trees; + + Scanner(CompilationUnitTree compilationUnit) { + m_root = compilationUnit; + m_trees = Trees.instance(m_task); + } + + @Override + public State reduce(State r1, State r2) { + if (r1 == null) { + return r2; + } + r1.merge(r2); + r1.removeFullyInitializedObjects(); + return r1; + } + + @Override + public State visitBlock(BlockTree node, State localState) { + // Always operate on a non-null state + State workingState = localState != null ? localState : new State(); + + super.visitBlock(node, workingState); + + // Remove any objects that are now fully initialized within this block. + workingState.removeFullyInitializedObjects(); + return workingState; + } + + @Override + public State visitNewClass(NewClassTree node, State localState) { + // Always operate on a non-null state + State workingState = localState != null ? localState : new State(); + + TreePath path = m_trees.getPath(m_root, node); + + if (Suppressions.hasSuppression(m_trees, path, PostConstructionInitializer.SUPPRESSION_KEY)) { + // Warnings are suppressed in this context, ignore + return super.visitNewClass(node, workingState); + } + + var parentElement = m_trees.getElement(path.getParentPath()); + if (parentElement instanceof VariableElement v) { + workingState.addMaybeInitializedObject(v); + } else if (path.getParentPath().getLeaf() instanceof AssignmentTree assignment) { + var lhsElement = m_trees.getElement(m_trees.getPath(m_root, assignment.getVariable())); + if (lhsElement instanceof VariableElement v) { + workingState.addMaybeInitializedObject(v); + } + } + + super.visitNewClass(node, workingState); + return workingState; + } + + @Override + public State visitMethodInvocation(MethodInvocationTree node, State localState) { + // Always operate on a non-null state + State workingState = localState != null ? localState : new State(); + + TreePath path = m_trees.getPath(m_root, node); + var invokedElement = m_trees.getElement(path); + if (!(invokedElement instanceof ExecutableElement executableElement)) { + super.visitMethodInvocation(node, workingState); + return workingState; + } + + // The invoked method doesn't have our annotation, skip. It's not an initializer method. + if (executableElement.getAnnotation(PostConstructionInitializer.class) == null) { + super.visitMethodInvocation(node, workingState); + return workingState; + } + + if (node.getMethodSelect() instanceof MemberSelectTree variableTree) { + var element = m_trees.getElement(m_trees.getPath(m_root, variableTree.getExpression())); + switch (element) { + case VariableElement v when workingState.isTracking(v) -> { + workingState.removeInitializer(v, executableElement); + } + case TypeElement t -> { + // Static method call, check for a variable that's in scope that's passed to this + // method. If the method accepts multiple parameters of this type, then check for a + // parameter with the @PostConstructionInitializer.InitializedParam annotation and only + // look at the variable passed as that parameter. + List possibleParameters = + getAnnotatedParameters(executableElement, t); + + if (possibleParameters.size() != 1) { + // This condition is enforced by the annotation processor, which runs before this + // plugin. If there's an error with the setup, users will already see a compiler error + break; + } + + VariableElement param = possibleParameters.get(0); + // Find the argument at the same index as the parameter and, if it refers to a + // tracked variable/field, mark its initializer as called. + int paramIndex = executableElement.getParameters().indexOf(param); + if (paramIndex >= 0 && paramIndex < node.getArguments().size()) { + ExpressionTree argument = node.getArguments().get(paramIndex); + Element argElement = m_trees.getElement(m_trees.getPath(m_root, argument)); + if (argElement instanceof VariableElement v && workingState.isTracking(v)) { + workingState.removeInitializer(v, executableElement); + } + } + } + default -> { + // Ignore + } + } + } + workingState.removeFullyInitializedObjects(); + + super.visitMethodInvocation(node, workingState); + return workingState; + } + + private List getAnnotatedParameters( + ExecutableElement executableElement, TypeElement requiredType) { + return executableElement.getParameters().stream() + .filter(p -> p.asType().equals(requiredType.asType())) + .toList(); + } + } +} diff --git a/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerProcessor.java b/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerProcessor.java new file mode 100644 index 0000000000..c6e0a77676 --- /dev/null +++ b/javacPlugin/src/main/java/org/wpilib/javacplugin/PostConstructionInitializerProcessor.java @@ -0,0 +1,104 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.javacplugin; + +import java.util.Set; +import javax.annotation.processing.AbstractProcessor; +import javax.annotation.processing.RoundEnvironment; +import javax.lang.model.SourceVersion; +import javax.lang.model.element.Element; +import javax.lang.model.element.ExecutableElement; +import javax.lang.model.element.Modifier; +import javax.lang.model.element.Name; +import javax.lang.model.element.TypeElement; +import javax.lang.model.element.VariableElement; +import javax.tools.Diagnostic; +import org.wpilib.annotation.PostConstructionInitializer; +import org.wpilib.annotation.PostConstructionInitializer.InitializedParam; + +/** + * Sanity checks for {@link PostConstructionInitializer}-annotated methods. This does not check for + * usages of the annotated method; that is handled by the {@link + * PostConstructionInitializerListener} compiler plugin. + */ +public class PostConstructionInitializerProcessor extends AbstractProcessor { + @Override + public SourceVersion getSupportedSourceVersion() { + return SourceVersion.latestSupported(); + } + + @Override + public Set getSupportedAnnotationTypes() { + return Set.of( + "org.wpilib.annotation.PostConstructionInitializer", + "org.wpilib.annotation.PostConstructionInitializer.InitializedParam"); + } + + @Override + public boolean process(Set annotations, RoundEnvironment roundEnv) { + var annotatedElements = roundEnv.getElementsAnnotatedWith(PostConstructionInitializer.class); + + for (Element element : annotatedElements) { + // Check static initializer methods. + // Static initializers must either take exactly one parameter of the type of the class they're + // in, or have exactly one parameter of the type that's annotated with @InitializedParam + if (element instanceof ExecutableElement exec + && exec.getModifiers().contains(Modifier.STATIC) + && exec.getEnclosingElement() instanceof TypeElement type) { + Name typeName = type.getQualifiedName(); + + var typedParameters = + exec.getParameters().stream() + .filter( + p -> + processingEnv + .getTypeUtils() + .isSameType(p.asType(), exec.getEnclosingElement().asType())) + .toList(); + + switch (typedParameters.size()) { + case 0 -> printErrorForNoParams(exec, typeName); + case 1 -> { + // No ambiguity + } + default -> { + // Multiple parameters. + // Require exactly one with a @PostConstructionInitializer.InitializedParam annotation, + // for disambiguation + var taggedParameters = + typedParameters.stream() + .filter(p -> p.getAnnotation(InitializedParam.class) != null) + .toList(); + if (taggedParameters.isEmpty()) { + printTaggedParameterCountError(exec, typeName); + } else if (taggedParameters.size() > 1) { + for (VariableElement taggedParameter : taggedParameters) { + printTaggedParameterCountError(taggedParameter, typeName); + } + } + } + } + } + } + + return false; + } + + private void printErrorForNoParams(Element errorNode, Name typeName) { + String message = + "Static @PostConstructionInitializer method must take a parameter of type " + typeName; + + processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, message, errorNode); + } + + private void printTaggedParameterCountError(Element errorNode, Name typeName) { + String message = + "Static @PostConstructionInitializer method must take exactly one parameter of type " + + typeName + + " with a @PostConstructionInitializer.InitializedParam annotation"; + + processingEnv.getMessager().printMessage(Diagnostic.Kind.ERROR, message, errorNode); + } +} diff --git a/javacPlugin/src/main/java/org/wpilib/javacplugin/ReturnValueUsedListener.java b/javacPlugin/src/main/java/org/wpilib/javacplugin/ReturnValueUsedListener.java index 4481e933fa..3c24693bbf 100644 --- a/javacPlugin/src/main/java/org/wpilib/javacplugin/ReturnValueUsedListener.java +++ b/javacPlugin/src/main/java/org/wpilib/javacplugin/ReturnValueUsedListener.java @@ -74,23 +74,8 @@ public class ReturnValueUsedListener implements TaskListener { private void checkIgnoredExpression(Tree node) { var path = m_trees.getPath(m_root, node); - // Walk the tree upwards to see if the node is directly or indirectly annotated with - // @SuppressWarnings("NoDiscard") or @SuppressWarnings("all"). If so, then we ignore any - // @NoDiscard messages for this node - for (var currentPath = path; currentPath != null; currentPath = currentPath.getParentPath()) { - var element = m_trees.getElement(currentPath); - if (element == null) { - continue; - } - - if (element.getAnnotation(SuppressWarnings.class) != null) { - String[] suppressions = element.getAnnotation(SuppressWarnings.class).value(); - for (String suppression : suppressions) { - if ("NoDiscard".equals(suppression) || "all".equals(suppression)) { - return; - } - } - } + if (Suppressions.hasSuppression(m_trees, path, "NoDiscard")) { + return; } var parentPath = (path == null) ? null : path.getParentPath(); diff --git a/javacPlugin/src/main/java/org/wpilib/javacplugin/WPILibJavacPlugin.java b/javacPlugin/src/main/java/org/wpilib/javacplugin/WPILibJavacPlugin.java index 0a8d5874ba..52a93a08eb 100644 --- a/javacPlugin/src/main/java/org/wpilib/javacplugin/WPILibJavacPlugin.java +++ b/javacPlugin/src/main/java/org/wpilib/javacplugin/WPILibJavacPlugin.java @@ -19,6 +19,7 @@ public class WPILibJavacPlugin implements Plugin { @Override public void init(JavacTask task, String... args) { + task.addTaskListener(new PostConstructionInitializerListener(task)); task.addTaskListener(new ReturnValueUsedListener(task)); task.addTaskListener(new MaxLengthDetector(task)); task.addTaskListener(new OpModeAnnotationValidator(task)); diff --git a/javacPlugin/src/main/resources/META-INF/services/javax.annotation.processing.Processor b/javacPlugin/src/main/resources/META-INF/services/javax.annotation.processing.Processor new file mode 100644 index 0000000000..80c4712ba4 --- /dev/null +++ b/javacPlugin/src/main/resources/META-INF/services/javax.annotation.processing.Processor @@ -0,0 +1 @@ +org.wpilib.javacplugin.PostConstructionInitializerProcessor diff --git a/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerListenerTest.java b/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerListenerTest.java new file mode 100644 index 0000000000..1c25f47d36 --- /dev/null +++ b/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerListenerTest.java @@ -0,0 +1,506 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.javacplugin; + +import static com.google.testing.compile.CompilationSubject.assertThat; +import static com.google.testing.compile.Compiler.javac; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.javacplugin.CompileTestUtils.kJavaVersionOptions; + +import com.google.testing.compile.Compilation; +import com.google.testing.compile.JavaFileObjects; +import org.junit.jupiter.api.Test; + +class PostConstructionInitializerListenerTest { + @Test + void instanceInitializerIsUsed() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + static void usage() { + var example = new Example(); + example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void instanceInitializerIsNotUsed() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + static void usage() { + Example example = new Example(); + // example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing a call to initializer method `init()`", + error.getMessage(null)); + } + + @Test + void instanceInitializerIsUsedInFactory() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + static Example makeExample() { + var example = new Example(); + example.init(); + return example; + } + + static void usage() { + var example = makeExample(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void instanceInitializerCalledInInnerBlock() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + static void usage() { + var example = new Example(); + if (false) { + // Will never actually run, but the plugin doesn't ignore dead branches + example.init(); + } + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void instanceInitializerInConstructorDoesNotCount() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + public Example() { + init(); + } + + @PostConstructionInitializer + void init() { } + + static void usage() { + var example = new Example(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing a call to initializer method `init()`", + error.getMessage(null)); + } + + @Test + void staticInitializerIsUsed() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + static void init(Example e) { } + + static void usage() { + var example = new Example(); + Example.init(example); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void staticInitializerIsNotUsed() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + static void init(Example e) { } + + static void usage() { + var example = new Example(); + // Example.init(example); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing a call to initializer method `init()`", + error.getMessage(null)); + } + + @Test + void checksForInitializersFromInterfaces() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + interface I1 { + @PostConstructionInitializer + default void i1Init() {} + } + + interface I2 extends I1 { + @PostConstructionInitializer + default void i2Init() {} + } + + class Example implements I2 { + @PostConstructionInitializer + void init() { } + + static void usage() { + var example = new Example(); + example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing calls to 2 initializer methods: " + + "`i1Init()`, `i2Init()`", + error.getMessage(null)); + } + + @Test + void initializerCalledInOtherContext() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + } + + class User { + Example example = new Example(); + + void later() { + example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void initializerCalledAfterConstructor() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + } + + class User { + Example example; + + User() { + example = new Example(); + } + + void later() { + example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void initializerNotCalledAfterConstructor() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + } + + class User { + Example example; + + User() { + example = new Example(); + } + + void later() { + // example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing a call to initializer method `init()`", + error.getMessage(null)); + } + + @Test + void initializerCalledOnAccessorAfterConstructor() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + } + + class User { + Example example; + + User() { + example = new Example(); + } + + Example getExample() { + return example; + } + + void later() { + // The plugin can't detect calls from accessor methods. + // Initializers MUST be called on the variable directly. + getExample().init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Partially-initialized object `example` is missing a call to initializer method `init()`", + error.getMessage(null)); + } + + @Test + void suppressWarningsOnConstructorCall() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + static void usage() { + @SuppressWarnings("PostConstructionInitializer") + var example = new Example(); + // example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void suppressWarningsOnCallerMethod() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + void init() { } + + @SuppressWarnings("PostConstructionInitializer") + static void usage() { + var example = new Example(); + // example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void suppressWarningsOnClass() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + @SuppressWarnings("PostConstructionInitializer") + class Example { + @PostConstructionInitializer + void init() { } + + static void usage() { + var example = new Example(); + // example.init(); + } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } +} diff --git a/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerProcessorTest.java b/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerProcessorTest.java new file mode 100644 index 0000000000..4b9f9536d8 --- /dev/null +++ b/javacPlugin/src/test/java/org/wpilib/javacplugin/PostConstructionInitializerProcessorTest.java @@ -0,0 +1,184 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.javacplugin; + +import static com.google.testing.compile.CompilationSubject.assertThat; +import static com.google.testing.compile.Compiler.javac; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.wpilib.javacplugin.CompileTestUtils.kJavaVersionOptions; + +import com.google.testing.compile.Compilation; +import com.google.testing.compile.JavaFileObjects; +import org.junit.jupiter.api.Test; + +class PostConstructionInitializerProcessorTest { + @Test + void staticInitializerWithNoParameters() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + static void init() { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Static @PostConstructionInitializer method must take a parameter of type " + + "frc.robot.Example", + error.getMessage(null)); + } + + @Test + void staticInitializerWithOneParameter() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + + class Example { + @PostConstructionInitializer + static void init(Example e) { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void staticInitializerWithOneAnnotatedParameter() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + import org.wpilib.annotation.PostConstructionInitializer.InitializedParam; + + class Example { + @PostConstructionInitializer + static void init(@InitializedParam Example e) { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void staticInitializerWithOneAnnotatedParameterWithUnannotatedParameter() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + import org.wpilib.annotation.PostConstructionInitializer.InitializedParam; + + class Example { + @PostConstructionInitializer + static void init(@InitializedParam Example dst, Example src) { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).succeededWithoutWarnings(); + } + + @Test + void staticInitializerWithMultipleAnnotatedParameters() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + import org.wpilib.annotation.PostConstructionInitializer.InitializedParam; + + class Example { + @PostConstructionInitializer + static void init(@InitializedParam Example a, @InitializedParam Example b) { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(2, compilation.errors().size()); + var error1 = compilation.errors().get(0); + assertEquals( + "Static @PostConstructionInitializer method must take exactly one parameter of type " + + "frc.robot.Example with a @PostConstructionInitializer.InitializedParam annotation", + error1.getMessage(null)); + + var error2 = compilation.errors().get(1); + assertEquals( + "Static @PostConstructionInitializer method must take exactly one parameter of type " + + "frc.robot.Example with a @PostConstructionInitializer.InitializedParam annotation", + error2.getMessage(null)); + } + + @Test + void staticInitializerAcceptingBaseType() { + String source = + """ + package frc.robot; + + import org.wpilib.annotation.PostConstructionInitializer; + import org.wpilib.annotation.PostConstructionInitializer.InitializedParam; + + class Base {} + + class Example extends Base { + @PostConstructionInitializer + static void init(@InitializedParam Base e) { } + } + """; + + Compilation compilation = + javac() + .withOptions(kJavaVersionOptions) + .withProcessors(new PostConstructionInitializerProcessor()) + .compile(JavaFileObjects.forSourceString("frc.robot.Example", source)); + + assertThat(compilation).failed(); + assertEquals(1, compilation.errors().size()); + var error = compilation.errors().get(0); + assertEquals( + "Static @PostConstructionInitializer method must take a parameter of type " + + "frc.robot.Example", + error.getMessage(null)); + } +} diff --git a/wpiannotations/src/main/java/org/wpilib/annotation/PostConstructionInitializer.java b/wpiannotations/src/main/java/org/wpilib/annotation/PostConstructionInitializer.java new file mode 100644 index 0000000000..61980a6d08 --- /dev/null +++ b/wpiannotations/src/main/java/org/wpilib/annotation/PostConstructionInitializer.java @@ -0,0 +1,91 @@ +// Copyright (c) FIRST and other WPILib contributors. +// Open Source Software; you can modify and/or share it under the terms of +// the WPILib BSD license file in the root directory of this project. + +package org.wpilib.annotation; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Marks a method as a post-construction initializer. The WPILib compiler plugin will check for uses + * of methods with this annotation and report a compiler error if the method is not called after the + * object is constructed. + * + *

Limitations of this annotation: + * + *

    + *
  • Initializer methods must be called on the variable directly. They cannot be detected if + * called indirectly (e.g., on an object returned by a method) + *
    {@code
    + * // This is OK
    + * Foo foo = new Foo();
    + * foo.init();
    + *
    + * // This is not OK
    + * Box box = new Box(new Foo());
    + * box.getFoo().init();
    + *
    + * }
    + *
  • Static initializer methods must accept exactly one parameter of the type that defines the + * static method (they cannot accept a parameter of a supertype or derived type). + *
  • Static initializer methods with multiple parameters of the initialized type must annotate + * one of them with {@link InitializedParam} to disambiguate for the compiler. + *
+ * + *

Errors reported by the compiler plugin may be suppressed by annotating the offending method + * with {@code SuppressWarnings("PostConstructionInitializer")} or {@code + * SuppressWarnings(PostConstructionInitializer.SUPPRESSION_KEY)}. This is intended to be used in + * tests to allow runtime error handling code to be tested, but may also be used to suppress + * spurious warnings in production code. + */ +@Target(ElementType.METHOD) +@Retention(RetentionPolicy.RUNTIME) +public @interface PostConstructionInitializer { + /** + * The string key to use in {@link SuppressWarnings} annotations to suppress compiler error + * messages related to this annotation. + */ + String SUPPRESSION_KEY = "PostConstructionInitializer"; + + /** + * Marks a specific parameter in a static initializer method as being the initialized object. This + * disambiguates situations where static initializer methods accept multiple arguments of the same + * initialize-required type; for example: + * + *

{@code
+   * class Foo {
+   *   @PostConstructionInitializer
+   *   static void copy(Foo src, @InitializedParam Foo dst) {
+   *     // ...
+   *   }
+   * }
+   * }
+ * + *

Static initializer methods must have a parameter of the exact type that defines the static + * method. + * + *

{@code
+   * interface I {
+   *   @PostConstructionInitializer
+   *   static void init(I object) { ... }
+   * }
+   *
+   * class Foo implements I {
+   *   @PostConstructionInitializer
+   *   static void initFoo(Foo foo) { ... } // OK
+   *
+   *   @PostConstructionInitializer
+   *   static void initI(I object) { ... } // ERROR: I is not Foo
+   *
+   *   @PostConstructionInitializer
+   *   static void initOther(SomeOtherType o) { ... } // ERROR: SomeOtherType is not Foo
+   * }
+   * }
+ */ + @Target(ElementType.PARAMETER) + @Retention(RetentionPolicy.RUNTIME) + @interface InitializedParam {} +}