[wpimath] Fix crash in KF latency compensator (#3888)

It would crash in C++ if the global measurement was sooner than all the
snapshots.

Align Java with the changes and better document computation approach.
This commit is contained in:
Tyler Veness
2022-01-09 23:01:04 -08:00
committed by GitHub
parent a3a0334fad
commit ba0908216c
2 changed files with 75 additions and 31 deletions

View File

@@ -74,15 +74,18 @@ public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O exte
return;
}
// This index starts at one because we use the previous state later on, and we always want to
// have a "previous state".
int maxIdx = m_pastObserverSnapshots.size() - 1;
int low = 1;
int high = Math.max(maxIdx, 1);
// Use a less verbose name for timestamp
double timestamp = timestampSeconds;
int maxIdx = m_pastObserverSnapshots.size() - 1;
int low = 0;
int high = maxIdx;
// Perform a binary search to find the index of first snapshot whose
// timestamp is greater than or equal to the global measurement timestamp
while (low != high) {
int mid = (low + high) / 2;
if (m_pastObserverSnapshots.get(mid).getKey() < timestampSeconds) {
if (m_pastObserverSnapshots.get(mid).getKey() < timestamp) {
// This index and everything under it are less than the requested timestamp. Therefore, we
// can discard them.
low = mid + 1;
@@ -93,16 +96,37 @@ public class KalmanFilterLatencyCompensator<S extends Num, I extends Num, O exte
}
}
// We are simply assigning this index to a new variable to avoid confusion
// with variable names.
int index = low;
double timestamp = timestampSeconds;
int indexOfClosestEntry =
Math.abs(timestamp - m_pastObserverSnapshots.get(index - 1).getKey())
<= Math.abs(
timestamp - m_pastObserverSnapshots.get(Math.min(index, maxIdx)).getKey())
? index - 1
: index;
int indexOfClosestEntry;
if (low == 0) {
// If the global measurement is older than any snapshot, throw out the
// measurement because there's no state estimate into which to incorporate
// the measurement
if (timestamp < m_pastObserverSnapshots.get(low).getKey()) {
return;
}
// If the first snapshot has same timestamp as the global measurement, use
// that snapshot
indexOfClosestEntry = 0;
} else if (low == maxIdx && m_pastObserverSnapshots.get(low).getKey() < timestamp) {
// If all snapshots are older than the global measurement, use the newest
// snapshot
indexOfClosestEntry = maxIdx;
} else {
// Index of snapshot taken after the global measurement
int nextIdx = low;
// Index of snapshot taken before the global measurement. Since we already
// handled the case where the index points to the first snapshot, this
// computation is guaranteed to be nonnegative.
int prevIdx = nextIdx - 1;
// Find the snapshot closest in time to global measurement
double prevTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(prevIdx).getKey());
double nextTimeDiff = Math.abs(timestamp - m_pastObserverSnapshots.get(nextIdx).getKey());
indexOfClosestEntry = prevTimeDiff <= nextTimeDiff ? prevIdx : nextIdx;
}
double lastTimestamp =
m_pastObserverSnapshots.get(indexOfClosestEntry).getKey() - nominalDtSeconds;

View File

@@ -86,26 +86,46 @@ class KalmanFilterLatencyCompensator {
return;
}
// We will perform a binary search to find the index of the element in the
// vector that has a timestamp that is equal to or greater than the vision
// measurement timestamp.
auto lowerBoundIter = std::lower_bound(
// Perform a binary search to find the index of first snapshot whose
// timestamp is greater than or equal to the global measurement timestamp
auto it = std::lower_bound(
m_pastObserverSnapshots.cbegin(), m_pastObserverSnapshots.cend(),
timestamp,
[](const auto& entry, const auto& ts) { return entry.first < ts; });
int index = std::distance(m_pastObserverSnapshots.cbegin(), lowerBoundIter);
// High and Low should be the same. The sampled timestamp is greater than or
// equal to the vision pose timestamp. We will now find the entry which is
// closest in time to the requested timestamp.
size_t indexOfClosestEntry;
size_t indexOfClosestEntry =
units::math::abs(
timestamp - m_pastObserverSnapshots[std::max(index - 1, 0)].first) <
units::math::abs(timestamp -
m_pastObserverSnapshots[index].first)
? index - 1
: index;
if (it == m_pastObserverSnapshots.cbegin()) {
// If the global measurement is older than any snapshot, throw out the
// measurement because there's no state estimate into which to incorporate
// the measurement
if (timestamp < it->first) {
return;
}
// If the first snapshot has same timestamp as the global measurement, use
// that snapshot
indexOfClosestEntry = 0;
} else if (it == m_pastObserverSnapshots.cend()) {
// If all snapshots are older than the global measurement, use the newest
// snapshot
indexOfClosestEntry = m_pastObserverSnapshots.size() - 1;
} else {
// Index of snapshot taken after the global measurement
int nextIdx = std::distance(m_pastObserverSnapshots.cbegin(), it);
// Index of snapshot taken before the global measurement. Since we already
// handled the case where the index points to the first snapshot, this
// computation is guaranteed to be nonnegative.
int prevIdx = nextIdx - 1;
// Find the snapshot closest in time to global measurement
units::second_t prevTimeDiff =
units::math::abs(timestamp - m_pastObserverSnapshots[prevIdx].first);
units::second_t nextTimeDiff =
units::math::abs(timestamp - m_pastObserverSnapshots[nextIdx].first);
indexOfClosestEntry = prevTimeDiff < nextTimeDiff ? prevIdx : nextIdx;
}
units::second_t lastTimestamp =
m_pastObserverSnapshots[indexOfClosestEntry].first - nominalDt;