Merge branch 'main' into 2027

This commit is contained in:
Peter Johnson
2025-09-20 11:19:40 -07:00
99 changed files with 3296 additions and 2039 deletions

View File

@@ -3,7 +3,10 @@
{
"aql": {
"items.find": {
"repo": "wpilib-mvn-development-local",
"$or":[
{ "repo": "wpilib-mvn-development-local" },
{ "repo": "wpilib-mvn-development-2027-local" }
],
"path": { "$nmatch":"*edu/wpi/first/thirdparty*" },
"$or":[
{

View File

@@ -46,6 +46,12 @@ jobs:
- name: configure
run: cmake --preset with-sccache -DCMAKE_BUILD_TYPE=RelWithDebInfo -DWITH_WPILIB=OFF -DWITH_GUI=OFF -DWITH_CSCORE=OFF -DWITH_TESTS=OFF -DWITH_SIMULATION_MODULES=OFF -DWITH_JAVA=ON -DBUILD_SHARED_LIBS=ON -DCMAKE_TOOLCHAIN_FILE=${{ steps.setup-ndk.outputs.ndk-path }}/build/cmake/android.toolchain.cmake -DANDROID_ABI="${{ matrix.abi }}" -DANDROID_PLATFORM=android-24
env:
SCCACHE_WEBDAV_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }}
SCCACHE_WEBDAV_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }}
- name: build
run: cmake --build build-cmake --parallel $(nproc)
env:
SCCACHE_WEBDAV_USERNAME: ${{ secrets.ARTIFACTORY_USERNAME }}
SCCACHE_WEBDAV_PASSWORD: ${{ secrets.ARTIFACTORY_PASSWORD }}

View File

@@ -102,6 +102,12 @@ jobs:
architecture: aarch64
task: "build"
outputs: "build/allOutputs"
- os: windows-2022
artifact-name: Win64FFI
architecture: x64
task: ":ntcoreffi:build"
build-options: "-Pntcoreffibuild -Pbuildwinarm64"
outputs: "ntcoreffi/build/outputs"
name: "Build - ${{ matrix.artifact-name }}"
runs-on: ${{ matrix.os }}
needs: [validation]

View File

@@ -51,39 +51,6 @@ Have an idea to make WPILib better? Here's some steps to go from idea to impleme
WPILib uses modified Google style guides for both C++ and Java, which can be found in the [styleguide repository](https://github.com/wpilibsuite/styleguide). Autoformatters are available for many popular editors at https://github.com/google/styleguide. Running wpiformat is required for all contributions and is enforced by our continuous integration system.
While the library should be fully formatted according to the styles, additional elements of the style guide were not followed when the library was initially created. All new code should follow the guidelines. If you are looking for some easy ramp-up tasks, finding areas that don't follow the style guide and fixing them is very welcome.
### Math documentation
When writing math expressions in documentation, use https://www.unicodeit.net/ to convert LaTeX to a Unicode equivalent that's easier to read. Not all expressions will translate (e.g., superscripts of superscripts) so focus on making it readable by someone who isn't familiar with LaTeX. If content on multiple lines needs to be aligned in Doxygen/Javadoc comments (e.g., integration/summation limits, matrices packed with square brackets and superscripts for them), put them in @verbatim/@endverbatim blocks in Doxygen or `<pre>` tags in Javadoc so they render with monospace font.
The LaTeX to Unicode conversions can also be done locally via the unicodeit Python package. To install it, execute:
```bash
pip install --user unicodeit
```
Here's example usage:
```bash
$ python -m unicodeit.cli 'x_{k+1} = Ax_k + Bu_k'
xₖ₊₁ = Axₖ + Buₖ
```
On Linux, this process can be streamlined further by adding the following Bash function to your .bashrc (requires `wl-clipboard` on Wayland or `xclip` on X11):
```bash
# Converts LaTeX to Unicode, prints the result, and copies it to the clipboard.
# All arguments are forwarded to the unicodeit CLI.
uc() {
  # "$@" is quoted so each LaTeX argument reaches unicodeit exactly as typed
  # instead of being word-split or glob-expanded by the shell (e.g. a bare `*`).
  if [ -n "$WAYLAND_DISPLAY" ]; then
    # WAYLAND_DISPLAY is set: copy through wl-copy
    python -m unicodeit.cli "$@" | tee >(wl-copy -n)
  else
    # WAYLAND_DISPLAY is unset or empty: copy through xclip
    # NOTE(review): xclip's -sel(ection) flag normally takes a selection name
    # (e.g. `clip`); confirm this invocation targets the intended selection.
    python -m unicodeit.cli "$@" | tee >(xclip -sel)
  fi
}
```
Here's example usage:
```bash
$ uc 'x_{k+1} = Ax_k + Bu_k'
xₖ₊₁ = Axₖ + Buₖ
```
## Submitting Changes
### Pull Request Format

View File

@@ -178,12 +178,12 @@ load("@rules_jvm_external//:defs.bzl", "maven_install")
load("@rules_jvm_external//:specs.bzl", "maven")
maven_artifacts = [
"org.ejml:ejml-simple:0.43.1",
"com.fasterxml.jackson.core:jackson-annotations:2.15.2",
"com.fasterxml.jackson.core:jackson-core:2.15.2",
"com.fasterxml.jackson.core:jackson-databind:2.15.2",
"us.hebi.quickbuf:quickbuf-runtime:1.3.3",
"com.google.code.gson:gson:2.10.1",
"org.ejml:ejml-simple:0.44.0",
"com.fasterxml.jackson.core:jackson-annotations:2.19.2",
"com.fasterxml.jackson.core:jackson-core:2.19.2",
"com.fasterxml.jackson.core:jackson-databind:2.19.2",
"us.hebi.quickbuf:quickbuf-runtime:1.4",
"com.google.code.gson:gson:2.13.1",
"edu.wpi.first.thirdparty.frc2025.opencv:opencv-java:4.10.0-3",
maven.artifact(
"org.junit.jupiter",

View File

@@ -165,6 +165,7 @@ doxygen.sourceSets.main {
tasks.register("zipCppDocs", Zip) {
archiveBaseName = zipBaseNameCpp
archiveVersion = ""
destinationDirectory = outputsFolder
dependsOn doxygenDox
from ("$buildDir/docs/doxygen/html")
@@ -203,6 +204,7 @@ task generateJavaDocs(type: Javadoc) {
"-edu.wpi.first.math.system.plant.proto," +
"-edu.wpi.first.math.system.plant.struct," +
"-edu.wpi.first.math.trajectory.proto," +
"-edu.wpi.first.math.trajectory.struct," +
// The .measure package contains generated source files for which automatic javadoc
// generation is very difficult to do meaningfully.
"-edu.wpi.first.units.measure", true)
@@ -243,6 +245,7 @@ task generateJavaDocs(type: Javadoc) {
tasks.register("zipJavaDocs", Zip) {
archiveBaseName = zipBaseNameJava
archiveVersion = ""
destinationDirectory = outputsFolder
dependsOn generateJavaDocs
from ("$buildDir/docs/javadoc")

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.ArrayType;
import javax.lang.model.type.PrimitiveType;
import javax.lang.model.type.TypeMirror;
@@ -52,7 +53,7 @@ public class ArrayHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
var dataType = dataType(element);
// known to be an array type (assuming isLoggable is checked first); this is a safe cast
@@ -63,13 +64,17 @@ public class ArrayHandler extends ElementHandler {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element)
+ elementAccess(element, loggedClass)
+ ", "
+ m_structHandler.structAccess(componentType)
+ ")";
} else {
// Primitive or string array
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
}
}

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeMirror;
@@ -38,7 +39,7 @@ public class CollectionHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
var dataType = dataType(element);
var componentType = ((DeclaredType) dataType).getTypeArguments().get(0);
@@ -46,12 +47,16 @@ public class CollectionHandler extends ElementHandler {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element)
+ elementAccess(element, loggedClass)
+ ", "
+ m_structHandler.structAccess(componentType)
+ ")";
} else {
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
}
}

View File

@@ -7,6 +7,7 @@ package edu.wpi.first.epilogue.processor;
import java.util.Map;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.DeclaredType;
import javax.lang.model.type.TypeMirror;
@@ -27,7 +28,7 @@ public class ConfiguredLoggerHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
var dataType = dataType(element);
var loggerType =
m_customLoggers.entrySet().stream()
@@ -44,7 +45,7 @@ public class ConfiguredLoggerHandler extends ElementHandler {
+ ".tryUpdate(backend.getNested(\""
+ loggedName(element)
+ "\"), "
+ elementAccess(element)
+ elementAccess(element, loggedClass)
+ ", Epilogue.getConfig().errorHandler)";
}
}

View File

@@ -117,9 +117,9 @@ public abstract class ElementHandler {
* @param element the element to generate the access for
* @return the generated access snippet
*/
public String elementAccess(Element element) {
public String elementAccess(Element element, TypeElement loggedClass) {
if (element instanceof VariableElement field) {
return fieldAccess(field);
return fieldAccess(field, loggedClass);
} else if (element instanceof ExecutableElement method) {
return methodAccess(method);
} else {
@@ -127,8 +127,20 @@ public abstract class ElementHandler {
}
}
private static String fieldAccess(VariableElement field) {
if (!field.getModifiers().contains(Modifier.PUBLIC)) {
private static String fieldAccess(VariableElement field, TypeElement loggedClass) {
var mods = field.getModifiers();
// To be directly accessible, the field needs to be:
// - public; or
// - protected or package-private, and declared by a superclass in the same package
// However, we can't cleanly access package information, so we'll always emit a VarHandle
// for any field declared in a superclass unless it's public and we know we can read it.
boolean isVarHandle =
field.getEnclosingElement().equals(loggedClass)
? mods.contains(Modifier.PRIVATE)
: !mods.contains(Modifier.PUBLIC);
if (isVarHandle) {
// ((com.example.Foo) $fooField.get(object))
// Extra parentheses so cast evaluates before appended methods
// (e.g. when appending .getAsDouble())
@@ -136,7 +148,7 @@ public abstract class ElementHandler {
if (type.getKind() == TypeKind.TYPEVAR) {
type = ((TypeVariable) type).getUpperBound();
}
return "((" + type.toString() + ") $" + field.getSimpleName() + ".get(object))";
return "((" + type.toString() + ") " + LoggerGenerator.varHandleName(field) + ".get(object))";
} else {
// object.fooField
return "object." + field.getSimpleName();
@@ -171,5 +183,5 @@ public abstract class ElementHandler {
* @param element the field or method element to generate the logger call for
* @return the generated log invocation
*/
public abstract String logInvocation(Element element);
public abstract String logInvocation(Element element, TypeElement loggedClass);
}

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
public class EnumHandler extends ElementHandler {
@@ -27,7 +28,11 @@ public class EnumHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
public String logInvocation(Element element, TypeElement loggedClass) {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
}

View File

@@ -39,7 +39,7 @@ public class LoggableHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
TypeMirror dataType = dataType(element);
var declaredType =
m_processingEnv
@@ -61,7 +61,7 @@ public class LoggableHandler extends ElementHandler {
// If there are no known loggable subtypes, return just the single logger call
if (size == 1) {
return generateLoggerCall(element, declaredType, elementAccess(element));
return generateLoggerCall(element, declaredType, elementAccess(element, loggedClass));
}
// Otherwise, generate an if-else chain to compare the element with its known loggable subtypes
@@ -73,7 +73,7 @@ public class LoggableHandler extends ElementHandler {
StringBuilder builder = new StringBuilder();
// Cache the value in a variable so it's only read once
builder.append("var %s = %s;\n".formatted(varName, elementAccess(element)));
builder.append("var %s = %s;\n".formatted(varName, elementAccess(element, loggedClass)));
for (int i = 0; i < size; i++) {
TypeElement type = loggableSubtypes.get(i);

View File

@@ -18,9 +18,11 @@ import java.io.PrintWriter;
import java.lang.annotation.Annotation;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.Deque;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
@@ -185,7 +187,21 @@ public class LoggerGenerator {
var loggerFile = m_processingEnv.getFiler().createSourceFile(loggerClassName);
var varHandleFields =
loggableFields.stream().filter(e -> !e.getModifiers().contains(Modifier.PUBLIC)).toList();
loggableFields.stream()
.filter(
e -> {
if (e.getEnclosingElement().equals(clazz)) {
// The generated logger is in the same package as the logged class, so the
// only fields it can't read are private ones.
return e.getModifiers().contains(Modifier.PRIVATE);
} else {
// Logging from a superclass. Can only read public fields, unless the superclass
// is in the same package, in which case protected and package-private fields
// are also readable.
return !e.getModifiers().contains(Modifier.PUBLIC);
}
})
.toList();
boolean requiresVarHandles = !varHandleFields.isEmpty();
try (var out = new PrintWriter(loggerFile.openWriter())) {
@@ -214,41 +230,67 @@ public class LoggerGenerator {
+ "> {");
if (requiresVarHandles) {
for (var varHandleField : varHandleFields) {
for (var privateField : varHandleFields) {
// This field needs a VarHandle to access.
// Cache it in the class to avoid lookups
out.println(" private static final VarHandle $" + varHandleField.getSimpleName() + ";");
out.printf(
" // Accesses private or superclass field %s.%s%n",
privateField.getEnclosingElement(), privateField.getSimpleName());
out.printf(" private static final VarHandle %s;%n", varHandleName(privateField));
}
out.println();
}
var classReference = simpleClassName + ".class";
// Static initializer block to load VarHandles and reflection fields
if (requiresVarHandles) {
out.println(" static {");
out.println(" try {");
out.println(
" var lookup = MethodHandles.privateLookupIn("
+ classReference
+ ", MethodHandles.lookup());");
for (var varHandleField : varHandleFields) {
var fieldName = varHandleField.getSimpleName();
out.println(
" $"
+ fieldName
+ " = lookup.findVarHandle("
+ classReference
+ ", \""
+ fieldName
+ "\", "
+ m_processingEnv.getTypeUtils().erasure(varHandleField.asType())
+ ".class);");
}
out.println(" try {");
out.println(" var rootLookup = MethodHandles.lookup();");
// Group private fields by class, then generate a private lookup for each class
// and a VarHandle for each field using that lookup. Sorting and then collecting into a
// LinkedHashMap gives deterministic output ordering (mostly for tests, which check exact
// file contents, but also results in less churn when regenerating files for users who like
// to read the generated logger classes).
//
// This lets us read private fields from superclasses.
Map<Element, List<VariableElement>> privateFieldsByClass =
varHandleFields.stream()
.sorted(Comparator.comparing(e -> e.getSimpleName().toString()))
.collect(
Collectors.groupingBy(
VariableElement::getEnclosingElement,
LinkedHashMap::new,
Collectors.toList()));
privateFieldsByClass.forEach(
(enclosingClass, fields) -> {
String className = enclosingClass.toString();
String lookupName = "lookup$$" + className.replace(".", "_");
out.printf(
" var %s = MethodHandles.privateLookupIn(%s.class, rootLookup);%n",
lookupName, className);
for (var field : fields) {
var fieldname = field.getSimpleName();
out.printf(
" %s = %s.findVarHandle(%s.class, \"%s\", %s.class);%n",
varHandleName(field),
lookupName,
className,
fieldname,
m_processingEnv.getTypeUtils().erasure(field.asType()));
}
});
out.println(" } catch (ReflectiveOperationException e) {");
out.println(
" throw new RuntimeException("
+ "\"[EPILOGUE] Could not load private fields for logging!\", e);");
out.println(" }");
out.println(" }");
out.println();
}
@@ -300,7 +342,7 @@ public class LoggerGenerator {
// to be logged. For example, the sendable handler consumes all sendable types
// but does not log commands or subsystems, to prevent excessive warnings about
// unloggable commands.
var logInvocation = h.logInvocation(loggableElement);
var logInvocation = h.logInvocation(loggableElement, clazz);
if (logInvocation != null) {
out.println(logInvocation.indent(6).stripTrailing() + ";");
}
@@ -315,6 +357,18 @@ public class LoggerGenerator {
}
}
/**
 * Builds the name of the VarHandle variable used to access the given field. The name is a
 * {@code $}, followed by the string form of the field's enclosing element with every {@code .}
 * replaced by {@code _}, followed by an underscore and the field's simple name. Including the
 * enclosing element keeps same-named fields declared in different classes apart.
 *
 * @param field The field to generate a VarHandle name for
 * @return The name of the generated VarHandle variable
 */
public static String varHandleName(VariableElement field) {
  String ownerName = field.getEnclosingElement().toString();
  String flattenedOwner = ownerName.replace(".", "_");
  return "$" + flattenedOwner + "_" + field.getSimpleName();
}
private void collectLoggables(
TypeElement clazz, List<VariableElement> fields, List<ExecutableElement> methods) {
var config = clazz.getAnnotation(Logged.class);

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
public class MeasureHandler extends ElementHandler {
@@ -30,8 +31,12 @@ public class MeasureHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
// EpilogueBackend has builtin support for logging measures
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
}

View File

@@ -16,6 +16,7 @@ import static javax.lang.model.type.TypeKind.SHORT;
import java.util.Set;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
public class PrimitiveHandler extends ElementHandler {
@@ -35,7 +36,11 @@ public class PrimitiveHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
public String logInvocation(Element element, TypeElement loggedClass) {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
}

View File

@@ -44,7 +44,7 @@ public class SendableHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
var dataType = dataType(element);
// Do not log commands or subsystems via their sendable implementations
@@ -66,7 +66,7 @@ public class SendableHandler extends ElementHandler {
return "logSendable(backend.getNested(\""
+ loggedName(element)
+ "\"), "
+ elementAccess(element)
+ elementAccess(element, loggedClass)
+ ")";
}
}

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
import javax.lang.model.util.Types;
@@ -38,11 +39,11 @@ public class StructHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
public String logInvocation(Element element, TypeElement loggedClass) {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element)
+ elementAccess(element, loggedClass)
+ ", "
+ structAccess(dataType(element))
+ ")";

View File

@@ -6,6 +6,7 @@ package edu.wpi.first.epilogue.processor;
import javax.annotation.processing.ProcessingEnvironment;
import javax.lang.model.element.Element;
import javax.lang.model.element.TypeElement;
import javax.lang.model.type.TypeMirror;
public class SupplierHandler extends ElementHandler {
@@ -42,15 +43,19 @@ public class SupplierHandler extends ElementHandler {
}
@Override
public String logInvocation(Element element) {
return "backend.log(\"" + loggedName(element) + "\", " + elementAccess(element) + ")";
public String logInvocation(Element element, TypeElement loggedClass) {
return "backend.log(\""
+ loggedName(element)
+ "\", "
+ elementAccess(element, loggedClass)
+ ")";
}
@Override
public String elementAccess(Element element) {
public String elementAccess(Element element, TypeElement loggedClass) {
var typeUtils = m_processingEnv.getTypeUtils();
var dataType = dataType(element);
String base = super.elementAccess(element);
String base = super.elementAccess(element, loggedClass);
if (typeUtils.isAssignable(dataType, m_booleanSupplier)) {
return base + ".getAsBoolean()";

View File

@@ -105,14 +105,13 @@ public abstract class ClassSpecificLogger<T> {
return;
}
var builder =
m_sendables.computeIfAbsent(
sendable,
s -> {
var b = new LogBackedSendableBuilder(backend);
s.initSendable(b);
return b;
});
builder.update();
if (m_sendables.containsKey(sendable)) {
m_sendables.get(sendable).update();
} else {
var builder = new LogBackedSendableBuilder(backend);
sendable.initSendable(builder);
m_sendables.put(sendable, builder);
builder.update();
}
}
}

View File

@@ -23,7 +23,9 @@ import edu.wpi.first.datalog.StructArrayLogEntry;
import edu.wpi.first.datalog.StructLogEntry;
import edu.wpi.first.util.struct.Struct;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
/** A backend implementation that saves information to a WPILib {@link DataLog} file on disk. */
@@ -31,6 +33,7 @@ public class FileBackend implements EpilogueBackend {
private final DataLog m_dataLog;
private final Map<String, DataLogEntry> m_entries = new HashMap<>();
private final Map<String, NestedBackend> m_subLoggers = new HashMap<>();
private final Set<Struct<?>> m_seenSchemas = new HashSet<>();
/**
* Creates a new file-based backend.
@@ -43,7 +46,13 @@ public class FileBackend implements EpilogueBackend {
@Override
public EpilogueBackend getNested(String path) {
return m_subLoggers.computeIfAbsent(path, k -> new NestedBackend(k, this));
if (!m_subLoggers.containsKey(path)) {
var nested = new NestedBackend(path, this);
m_subLoggers.put(path, nested);
return nested;
}
return m_subLoggers.get(path);
}
@SuppressWarnings("unchecked")
@@ -131,14 +140,30 @@ public class FileBackend implements EpilogueBackend {
@Override
@SuppressWarnings("unchecked")
public <S> void log(String identifier, S value, Struct<S> struct) {
m_dataLog.addSchema(struct);
getEntry(identifier, (log, k) -> StructLogEntry.create(log, k, struct)).append(value);
// DataLog.addSchema has checks that we're able to skip, avoiding allocations
if (m_seenSchemas.add(struct)) {
m_dataLog.addSchema(struct);
}
if (!m_entries.containsKey(identifier)) {
m_entries.put(identifier, StructLogEntry.create(m_dataLog, identifier, struct));
}
((StructLogEntry<S>) m_entries.get(identifier)).append(value);
}
@Override
@SuppressWarnings("unchecked")
public <S> void log(String identifier, S[] value, Struct<S> struct) {
m_dataLog.addSchema(struct);
getEntry(identifier, (log, k) -> StructArrayLogEntry.create(log, k, struct)).append(value);
// DataLog.addSchema has checks that we're able to skip, avoiding allocations
if (m_seenSchemas.add(struct)) {
m_dataLog.addSchema(struct);
}
if (!m_entries.containsKey(identifier)) {
m_entries.put(identifier, StructArrayLogEntry.create(m_dataLog, identifier, struct));
}
((StructArrayLogEntry<S>) m_entries.get(identifier)).append(value);
}
}

View File

@@ -40,7 +40,13 @@ public class LazyBackend implements EpilogueBackend {
@Override
public EpilogueBackend getNested(String path) {
return m_subLoggers.computeIfAbsent(path, k -> new NestedBackend(k, this));
if (!m_subLoggers.containsKey(path)) {
var nested = new NestedBackend(path, this);
m_subLoggers.put(path, nested);
return nested;
}
return m_subLoggers.get(path);
}
@Override

View File

@@ -24,7 +24,13 @@ public class MultiBackend implements EpilogueBackend {
@Override
public EpilogueBackend getNested(String path) {
return m_nestedBackends.computeIfAbsent(path, k -> new NestedBackend(k, this));
if (!m_nestedBackends.containsKey(path)) {
var nested = new NestedBackend(path, this);
m_nestedBackends.put(path, nested);
return nested;
}
return m_nestedBackends.get(path);
}
@Override

View File

@@ -21,7 +21,10 @@ import edu.wpi.first.networktables.StructArrayPublisher;
import edu.wpi.first.networktables.StructPublisher;
import edu.wpi.first.util.struct.Struct;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
/**
* A backend implementation that sends data over network tables. Be careful when using this, since
@@ -32,61 +35,81 @@ public class NTEpilogueBackend implements EpilogueBackend {
private final Map<String, Publisher> m_publishers = new HashMap<>();
private final Map<String, NestedBackend> m_nestedBackends = new HashMap<>();
private final Set<Struct<?>> m_seenSchemas = new HashSet<>();
private final Function<String, IntegerPublisher> m_createIntPublisher;
private final Function<String, FloatPublisher> m_createFloatPublisher;
private final Function<String, DoublePublisher> m_createDoublePublisher;
private final Function<String, BooleanPublisher> m_createBooleanPublisher;
private final Function<String, RawPublisher> m_createRawPublisher;
private final Function<String, IntegerArrayPublisher> m_createIntegerArrayPublisher;
private final Function<String, FloatArrayPublisher> m_createFloatArrayPublisher;
private final Function<String, DoubleArrayPublisher> m_createDoubleArrayPublisher;
private final Function<String, BooleanArrayPublisher> m_createBooleanArrayPublisher;
private final Function<String, StringPublisher> m_createStringPublisher;
private final Function<String, StringArrayPublisher> m_createStringArrayPublisher;
/**
 * Creates a logging backend that publishes logged values to NetworkTables.
 *
 * @param nt the NetworkTables instance that data is published to
 */
@SuppressWarnings("unchecked")
public NTEpilogueBackend(NetworkTableInstance nt) {
  m_nt = nt;

  // Build each publisher factory once and keep it in a field. Every factory opens a publisher on
  // the topic named by the string it is given.
  m_createIntPublisher = topic -> m_nt.getIntegerTopic(topic).publish();
  m_createFloatPublisher = topic -> m_nt.getFloatTopic(topic).publish();
  m_createDoublePublisher = topic -> m_nt.getDoubleTopic(topic).publish();
  m_createBooleanPublisher = topic -> m_nt.getBooleanTopic(topic).publish();
  // Raw topics are published with the "raw" type string
  m_createRawPublisher = topic -> m_nt.getRawTopic(topic).publish("raw");
  m_createIntegerArrayPublisher = topic -> m_nt.getIntegerArrayTopic(topic).publish();
  m_createFloatArrayPublisher = topic -> m_nt.getFloatArrayTopic(topic).publish();
  m_createDoubleArrayPublisher = topic -> m_nt.getDoubleArrayTopic(topic).publish();
  m_createBooleanArrayPublisher = topic -> m_nt.getBooleanArrayTopic(topic).publish();
  m_createStringPublisher = topic -> m_nt.getStringTopic(topic).publish();
  m_createStringArrayPublisher = topic -> m_nt.getStringArrayTopic(topic).publish();
}
@Override
public EpilogueBackend getNested(String path) {
return m_nestedBackends.computeIfAbsent(path, k -> new NestedBackend(k, this));
if (!m_nestedBackends.containsKey(path)) {
var nested = new NestedBackend(path, this);
m_nestedBackends.put(path, nested);
return nested;
}
return m_nestedBackends.get(path);
}
@Override
public void log(String identifier, int value) {
((IntegerPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getIntegerTopic(k).publish()))
.set(value);
((IntegerPublisher) m_publishers.computeIfAbsent(identifier, m_createIntPublisher)).set(value);
}
@Override
public void log(String identifier, long value) {
((IntegerPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getIntegerTopic(k).publish()))
.set(value);
((IntegerPublisher) m_publishers.computeIfAbsent(identifier, m_createIntPublisher)).set(value);
}
@Override
public void log(String identifier, float value) {
((FloatPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getFloatTopic(k).publish()))
.set(value);
((FloatPublisher) m_publishers.computeIfAbsent(identifier, m_createFloatPublisher)).set(value);
}
@Override
public void log(String identifier, double value) {
((DoublePublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getDoubleTopic(k).publish()))
((DoublePublisher) m_publishers.computeIfAbsent(identifier, m_createDoublePublisher))
.set(value);
}
@Override
public void log(String identifier, boolean value) {
((BooleanPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getBooleanTopic(k).publish()))
((BooleanPublisher) m_publishers.computeIfAbsent(identifier, m_createBooleanPublisher))
.set(value);
}
@Override
public void log(String identifier, byte[] value) {
((RawPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getRawTopic(k).publish("raw")))
.set(value);
((RawPublisher) m_publishers.computeIfAbsent(identifier, m_createRawPublisher)).set(value);
}
@Override
@@ -100,68 +123,79 @@ public class NTEpilogueBackend implements EpilogueBackend {
}
((IntegerArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getIntegerArrayTopic(k).publish()))
m_publishers.computeIfAbsent(identifier, m_createIntegerArrayPublisher))
.set(widened);
}
@Override
public void log(String identifier, long[] value) {
((IntegerArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getIntegerArrayTopic(k).publish()))
m_publishers.computeIfAbsent(identifier, m_createIntegerArrayPublisher))
.set(value);
}
@Override
public void log(String identifier, float[] value) {
((FloatArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getFloatArrayTopic(k).publish()))
((FloatArrayPublisher) m_publishers.computeIfAbsent(identifier, m_createFloatArrayPublisher))
.set(value);
}
@Override
public void log(String identifier, double[] value) {
((DoubleArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getDoubleArrayTopic(k).publish()))
((DoubleArrayPublisher) m_publishers.computeIfAbsent(identifier, m_createDoubleArrayPublisher))
.set(value);
}
@Override
public void log(String identifier, boolean[] value) {
((BooleanArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getBooleanArrayTopic(k).publish()))
m_publishers.computeIfAbsent(identifier, m_createBooleanArrayPublisher))
.set(value);
}
@Override
public void log(String identifier, String value) {
((StringPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getStringTopic(k).publish()))
((StringPublisher) m_publishers.computeIfAbsent(identifier, m_createStringPublisher))
.set(value);
}
@Override
public void log(String identifier, String[] value) {
((StringArrayPublisher)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getStringArrayTopic(k).publish()))
((StringArrayPublisher) m_publishers.computeIfAbsent(identifier, m_createStringArrayPublisher))
.set(value);
}
@Override
@SuppressWarnings("unchecked")
public <S> void log(String identifier, S value, Struct<S> struct) {
m_nt.addSchema(struct);
((StructPublisher<S>)
m_publishers.computeIfAbsent(identifier, k -> m_nt.getStructTopic(k, struct).publish()))
.set(value);
// NetworkTableInstance.addSchema has checks that we're able to skip, avoiding allocations
if (m_seenSchemas.add(struct)) {
m_nt.addSchema(struct);
}
if (m_publishers.containsKey(identifier)) {
((StructPublisher<S>) m_publishers.get(identifier)).set(value);
} else {
StructPublisher<S> publisher = m_nt.getStructTopic(identifier, struct).publish();
m_publishers.put(identifier, publisher);
publisher.set(value);
}
}
@Override
@SuppressWarnings("unchecked")
public <S> void log(String identifier, S[] value, Struct<S> struct) {
m_nt.addSchema(struct);
((StructArrayPublisher<S>)
m_publishers.computeIfAbsent(
identifier, k -> m_nt.getStructArrayTopic(k, struct).publish()))
.set(value);
// NetworkTableInstance.addSchema has checks that we're able to skip, avoiding allocations
if (m_seenSchemas.add(struct)) {
m_nt.addSchema(struct);
}
if (m_publishers.containsKey(identifier)) {
((StructArrayPublisher<S>) m_publishers.get(identifier)).set(value);
} else {
StructArrayPublisher<S> publisher = m_nt.getStructArrayTopic(identifier, struct).publish();
m_publishers.put(identifier, publisher);
publisher.set(value);
}
}
}

View File

@@ -17,6 +17,15 @@ public class NestedBackend implements EpilogueBackend {
private final EpilogueBackend m_impl;
private final Map<String, NestedBackend> m_nestedBackends = new HashMap<>();
// String concatenation can be expensive, especially for deeply nested hierarchies with many
// logged fields. For example, logging a hypothetical Robot.elevator.io.getHeight() would result
// in "/Robot/" + "elevator/" + "io/" + "getHeight"; three concatenations and string and byte
// array allocations that need to be cleaned up by the GC. Caching the results means those
// allocations only occur once, resulting in no GC (the strings are always referenced in the
// cache), and minimal time costs (the String object caches its own hash code, so all we do is an
// O(1) table lookup per concatenation)
private final Map<String, String> m_prefixedIdentifiers = new HashMap<>();
/**
* Creates a new nested backend underneath another backend.
*
@@ -33,83 +42,109 @@ public class NestedBackend implements EpilogueBackend {
this.m_impl = impl;
}
/**
 * Fast lookup to avoid redundant {@code m_prefix + identifier} concatenations. If the identifier
 * has not been seen before, we compute the concatenation and cache the result for later
 * invocations to read. This avoids redundantly recomputing the same concatenations every loop and
 * significantly cuts down on the CPU and memory overhead of the Epilogue library.
 *
 * @param identifier The identifier to prepend with {@link #m_prefix}.
 * @return The concatenated string.
 */
private String withPrefix(String identifier) {
  // Using computeIfAbsent would result in a new lambda object allocation on every call.
  // A single get() is enough to detect a cache miss: every cached value is the result of a string
  // concatenation and therefore never null, so null can only mean "not cached yet". This halves
  // the number of hash lookups compared to containsKey() followed by get().
  String cached = m_prefixedIdentifiers.get(identifier);
  if (cached != null) {
    return cached;
  }

  String result = m_prefix + identifier;
  m_prefixedIdentifiers.put(identifier, result);
  return result;
}
@Override
public EpilogueBackend getNested(String path) {
return m_nestedBackends.computeIfAbsent(path, k -> new NestedBackend(k, this));
if (!m_nestedBackends.containsKey(path)) {
var nested = new NestedBackend(path, this);
m_nestedBackends.put(path, nested);
return nested;
}
return m_nestedBackends.get(path);
}
@Override
public void log(String identifier, int value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, long value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, float value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, double value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, boolean value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, byte[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, int[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, long[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, float[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, double[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, boolean[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, String value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public void log(String identifier, String[] value) {
m_impl.log(m_prefix + identifier, value);
m_impl.log(withPrefix(identifier), value);
}
@Override
public <S> void log(String identifier, S value, Struct<S> struct) {
m_impl.log(m_prefix + identifier, value, struct);
m_impl.log(withPrefix(identifier), value, struct);
}
@Override
public <S> void log(String identifier, S[] value, Struct<S> struct) {
m_impl.log(m_prefix + identifier, value, struct);
m_impl.log(withPrefix(identifier), value, struct);
}
}

View File

@@ -0,0 +1,180 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.epilogue.logging;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link NestedBackend}: prefix handling, caching of nested backends and prefixed
 * identifiers, and forwarding of every {@code log} overload to the wrapped backend.
 */
class NestedBackendTest {
  /**
   * Asserts that two logged identifiers are the very same String instance, i.e. that the second
   * log call reused the identifier object produced for the first one.
   */
  private static void assertIdentifierReused(String first, String second) {
    assertSame(
        first,
        second,
        "Identifier %s (id: %d) was not reused (new id: %d)"
            .formatted(first, System.identityHashCode(first), System.identityHashCode(second)));
  }

  @Test
  void prefixesAppliedAndNested() {
    var root = new TestBackend();
    var robot = new NestedBackend("/Robot", root);

    // Two fields logged directly on the nested backend...
    robot.log("int", 1);
    robot.log("string", "hello");

    // ...and two more logged one level further down
    var arm = robot.getNested("arm");
    arm.log("position", 2.0);
    arm.log("enabled", true);

    var entries = root.getEntries();
    assertEquals(4, entries.size());

    assertEquals("/Robot/int", entries.get(0).identifier());
    assertEquals(1, entries.get(0).value());

    assertEquals("/Robot/string", entries.get(1).identifier());
    assertEquals("hello", entries.get(1).value());

    assertEquals("/Robot/arm/position", entries.get(2).identifier());
    assertEquals(2.0, entries.get(2).value());

    assertEquals("/Robot/arm/enabled", entries.get(3).identifier());
    assertEquals(true, entries.get(3).value());
  }

  @Test
  void handlesTrailingSlashOnPrefix() {
    var root = new TestBackend();
    var withoutSlash = new NestedBackend("/Robot", root);
    var withSlash = new NestedBackend("/Robot/", root);

    withoutSlash.log("x", 1);
    withSlash.log("y", 2);

    var entries = root.getEntries();
    assertEquals("/Robot/x", entries.get(0).identifier());
    assertEquals("/Robot/y", entries.get(1).identifier());
  }

  @Test
  void getNestedIsCached() {
    var root = new TestBackend();
    var robot = new NestedBackend("/Robot", root);

    // Asking twice for the same path must hand back the same backend object
    assertSame(robot.getNested("arm"), robot.getNested("arm"));
  }

  @Test
  void usesPrefixedIdentifierCacheForSameField() {
    var root = new TestBackend();
    var robot = new NestedBackend("/Robot", root);

    // Same field logged multiple times - identifier object should be the same (cached).
    // Reference identity (not equality) is what is being checked here.
    robot.log("x", 0);
    robot.log("x", 1);
    String firstX = root.getEntries().get(0).identifier();
    String secondX = root.getEntries().get(1).identifier();
    assertIdentifierReused(firstX, secondX);

    // Also verify through a nested backend path
    var arm = robot.getNested("arm");
    arm.log("position", 0.0);
    arm.log("position", 1.0);
    String firstPosition = root.getEntries().get(2).identifier();
    String secondPosition = root.getEntries().get(3).identifier();
    assertIdentifierReused(firstPosition, secondPosition);

    // Sanity check actual full values
    assertEquals("/Robot/x", firstX);
    assertEquals("/Robot/arm/position", firstPosition);
  }

  @Test
  void logsAllOverloads() {
    var root = new TestBackend();
    var robot = new NestedBackend("/Robot", root);

    // Scalars
    robot.log("int", 1);
    robot.log("long", 2L);
    robot.log("float", 3.0f);
    robot.log("double", 4.0);
    robot.log("boolean", true);
    robot.log("string", "hello");

    // Arrays
    robot.log("bytes", new byte[] {1, 2});
    robot.log("ints", new int[] {3, 4});
    robot.log("longs", new long[] {5L, 6L});
    robot.log("floats", new float[] {7.0f, 8.0f});
    robot.log("doubles", new double[] {9.0, 10.0});
    robot.log("booleans", new boolean[] {true, false});
    robot.log("strings", new String[] {"x", "y"});

    // Structs
    robot.log("customStruct", new CustomStruct(7), CustomStruct.struct);
    robot.log(
        "customStructs",
        new CustomStruct[] {new CustomStruct(0), new CustomStruct(1)},
        CustomStruct.struct);

    var entries = root.getEntries();
    // Index of the next entry to verify; entries are expected in the order they were logged
    int idx = 0;

    // Scalars
    assertEquals(new TestBackend.LogEntry<>("/Robot/int", 1), entries.get(idx++));
    assertEquals(new TestBackend.LogEntry<>("/Robot/long", 2L), entries.get(idx++));
    assertEquals(new TestBackend.LogEntry<>("/Robot/float", 3.0f), entries.get(idx++));
    assertEquals(new TestBackend.LogEntry<>("/Robot/double", 4.0), entries.get(idx++));
    assertEquals(new TestBackend.LogEntry<>("/Robot/boolean", true), entries.get(idx++));
    assertEquals(new TestBackend.LogEntry<>("/Robot/string", "hello"), entries.get(idx++));

    // Arrays
    var bytesEntry = entries.get(idx++);
    assertEquals("/Robot/bytes", bytesEntry.identifier());
    assertArrayEquals(new byte[] {1, 2}, (byte[]) bytesEntry.value());

    var intsEntry = entries.get(idx++);
    assertEquals("/Robot/ints", intsEntry.identifier());
    assertArrayEquals(new int[] {3, 4}, (int[]) intsEntry.value());

    var longsEntry = entries.get(idx++);
    assertEquals("/Robot/longs", longsEntry.identifier());
    assertArrayEquals(new long[] {5L, 6L}, (long[]) longsEntry.value());

    var floatsEntry = entries.get(idx++);
    assertEquals("/Robot/floats", floatsEntry.identifier());
    assertArrayEquals(new float[] {7.0f, 8.0f}, (float[]) floatsEntry.value());

    var doublesEntry = entries.get(idx++);
    assertEquals("/Robot/doubles", doublesEntry.identifier());
    assertArrayEquals(new double[] {9.0, 10.0}, (double[]) doublesEntry.value());

    var booleansEntry = entries.get(idx++);
    assertEquals("/Robot/booleans", booleansEntry.identifier());
    assertArrayEquals(new boolean[] {true, false}, (boolean[]) booleansEntry.value());

    var stringsEntry = entries.get(idx++);
    assertEquals("/Robot/strings", stringsEntry.identifier());
    assertArrayEquals(new String[] {"x", "y"}, (String[]) stringsEntry.value());

    // Structs are serialized to bytes
    var structEntry = entries.get(idx++);
    assertEquals("/Robot/customStruct", structEntry.identifier());
    assertArrayEquals(new byte[] {0x07, 0x00, 0x00, 0x00}, (byte[]) structEntry.value());

    var structArrayEntry = entries.get(idx++);
    assertEquals("/Robot/customStructs", structArrayEntry.identifier());
    // two int32 values, little-endian
    assertArrayEquals(
        new byte[] {
          0x00, 0x00, 0x00, 0x00, // 0 (first element)
          0x01, 0x00, 0x00, 0x00, // 1 (second element)
          0x00, 0x00, 0x00, 0x00, // 0 (empty space allocated by StructBuffer)
          0x00, 0x00, 0x00, 0x00 // 0 (empty space allocated by StructBuffer)
        },
        (byte[]) structArrayEntry.value());

    // Ensure we covered all calls
    assertEquals(idx, entries.size());
  }
}

View File

@@ -35,10 +35,10 @@ public final class CANAPITypes {
kGyroSensor(4),
/** Accelerometer. */
kAccelerometer(5),
/** Ultrasonic sensor. */
kUltrasonicSensor(6),
/** Gear tooth sensor. */
kGearToothSensor(7),
/** Distance sensor. */
kDistanceSensor(6),
/** Encoder. */
kEncoder(7),
/** Power distribution. */
kPowerDistribution(8),
/** Pneumatics. */
@@ -49,6 +49,8 @@ public final class CANAPITypes {
kIOBreakout(11),
/** Servo Controller. */
kServoController(12),
/** Color Sensor. */
kColorSensor(13),
/** Firmware update. */
kFirmwareUpdate(31);
@@ -105,7 +107,15 @@ public final class CANAPITypes {
/** AndyMark. */
kAndyMark(15),
/** Vivid-Hosting. */
kVividHosting(16);
kVividHosting(16),
/** Vertos Robotics. */
kVertosRobotics(17),
/** SWYFT Robotics. */
kSWYFTRobotics(18),
/** Lumyn Labs. */
kLumynLabs(19),
/** Brushland Labs. */
kBrushlandLabs(20);
/** The manufacturer ID. */
@SuppressWarnings("MemberName")

View File

@@ -32,10 +32,10 @@ HAL_ENUM(HAL_CANDeviceType) {
HAL_CAN_Dev_kGyroSensor = 4,
/// Accelerometer.
HAL_CAN_Dev_kAccelerometer = 5,
/// Ultrasonic sensor.
HAL_CAN_Dev_kUltrasonicSensor = 6,
/// Gear tooth sensor.
HAL_CAN_Dev_kGearToothSensor = 7,
/// Distance sensor.
HAL_CAN_Dev_kDistanceSensor = 6,
/// Encoder.
HAL_CAN_Dev_kEncoder = 7,
/// Power distribution.
HAL_CAN_Dev_kPowerDistribution = 8,
/// Pneumatics.
@@ -44,8 +44,10 @@ HAL_ENUM(HAL_CANDeviceType) {
HAL_CAN_Dev_kMiscellaneous = 10,
/// IO breakout.
HAL_CAN_Dev_kIOBreakout = 11,
// Servo controller.
/// Servo controller.
HAL_CAN_Dev_kServoController = 12,
/// Color Sensor.
HAL_CAN_Dev_ColorSensor = 13,
/// Firmware update.
HAL_CAN_Dev_kFirmwareUpdate = 31
};
@@ -89,7 +91,15 @@ HAL_ENUM(HAL_CANManufacturer) {
/// AndyMark.
HAL_CAN_Man_kAndyMark = 15,
/// Vivid-Hosting.
HAL_CAN_Man_kVividHosting = 16
HAL_CAN_Man_kVividHosting = 16,
/// Vertos Robotics.
HAL_CAN_Man_kVertosRobotics = 17,
/// SWYFT Robotics.
HAL_CAN_Man_kSWYFTRobotics = 18,
/// Lumyn Labs.
HAL_CAN_Man_kLumynLabs = 19,
/// Brushland Labs.
HAL_CAN_Man_kBrushlandLabs = 20
};
/**

View File

@@ -42,6 +42,7 @@ model {
}
from(applicationPath)
into(nativeUtils.getPlatformPath(binary))
}
task.dependsOn binary.tasks.link

View File

@@ -38,6 +38,10 @@ def eigen_inclusions(dp: Path, f: str):
if "MKL" in f:
return False
# Exclude HIP CUDA support
if "GpuHip" in f:
return False
# Include architectures we care about by filtering for Core/arch
if "Core" in dp.parts and "arch" in dp.parts:
return (
@@ -140,8 +144,8 @@ def copy_upstream_src(wpilib_root: Path):
def main():
name = "eigen"
url = "https://gitlab.com/libeigen/eigen.git"
# master on 2025-05-18
tag = "d81aa18f4dc56264b2cd7e2f230807d776a2d385"
# master on 2025-09-08
tag = "e0a59e5a66e6d16fa93ab4f5e48bf539205e837f"
eigen = Lib(name, url, tag, copy_upstream_src)
eigen.main()

View File

@@ -8,7 +8,7 @@ Subject: [PATCH 2/2] Intellisense fix
1 file changed, 7 insertions(+)
diff --git a/Eigen/src/Core/util/ConfigureVectorization.h b/Eigen/src/Core/util/ConfigureVectorization.h
index 49f307c734e937f013e659e931286a17ef6756f9..a9430716a320327aed81ea0cdffabc051aeb0ce2 100644
index c2546a083898154a1d4bd741722a5544cbdb1d92..8b5cc16b2092a73804af87ed7f59722ae3fdab0c 100644
--- a/Eigen/src/Core/util/ConfigureVectorization.h
+++ b/Eigen/src/Core/util/ConfigureVectorization.h
@@ -178,6 +178,13 @@

View File

@@ -98,6 +98,15 @@ public interface Subsystem {
CommandScheduler.getInstance().registerSubsystem(this);
}
/**
* Constructs a command that does nothing until interrupted. Requires this subsystem.
*
* <p>This is a convenience wrapper that returns the result of {@code Commands.idle(this)}.
*
* @return the command
*/
default Command idle() {
return Commands.idle(this);
}
/**
* Constructs a command that runs an action once and finishes. Requires this subsystem.
*

View File

@@ -46,6 +46,10 @@ void Subsystem::Register() {
return CommandScheduler::GetInstance().RegisterSubsystem(this);
}
// Convenience wrapper: forwards to cmd::Idle, passing this subsystem as the
// only element of the braced argument list.
CommandPtr Subsystem::Idle() {
return cmd::Idle({this});
}
CommandPtr Subsystem::RunOnce(std::function<void()> action) {
return cmd::RunOnce(std::move(action), {this});
}

View File

@@ -121,6 +121,15 @@ class Subsystem {
*/
void Register();
/**
* Constructs a command that does nothing until interrupted. Requires this
* subsystem.
*
* @return the command
*/
[[nodiscard]]
CommandPtr Idle();
/**
* Constructs a command that runs an action once and finishes. Requires this
* subsystem.

View File

@@ -21,39 +21,39 @@ file(
if(WITH_JAVA)
include(UseJava)
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml/ejml-simple-0.43.1.jar")
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml/ejml-simple-0.44.0.jar")
set(BASE_URL "https://search.maven.org/remotecontent?filepath=")
set(JAR_ROOT "${WPILIB_BINARY_DIR}/wpimath/thirdparty/ejml")
message(STATUS "Downloading EJML jarfiles...")
download_and_check(
"${BASE_URL}org/ejml/ejml-cdense/0.43.1/ejml-cdense-0.43.1.jar"
"${JAR_ROOT}/ejml-cdense-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-cdense/0.44.0/ejml-cdense-0.44.0.jar"
"${JAR_ROOT}/ejml-cdense-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-core/0.43.1/ejml-core-0.43.1.jar"
"${JAR_ROOT}/ejml-core-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-core/0.44.0/ejml-core-0.44.0.jar"
"${JAR_ROOT}/ejml-core-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-ddense/0.43.1/ejml-ddense-0.43.1.jar"
"${JAR_ROOT}/ejml-ddense-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-ddense/0.44.0/ejml-ddense-0.44.0.jar"
"${JAR_ROOT}/ejml-ddense-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-dsparse/0.43.1/ejml-dsparse-0.43.1.jar"
"${JAR_ROOT}/ejml-dsparse-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-dsparse/0.44.0/ejml-dsparse-0.44.0.jar"
"${JAR_ROOT}/ejml-dsparse-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-fdense/0.43.1/ejml-fdense-0.43.1.jar"
"${JAR_ROOT}/ejml-fdense-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-fdense/0.44.0/ejml-fdense-0.44.0.jar"
"${JAR_ROOT}/ejml-fdense-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-simple/0.43.1/ejml-simple-0.43.1.jar"
"${JAR_ROOT}/ejml-simple-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-simple/0.44.0/ejml-simple-0.44.0.jar"
"${JAR_ROOT}/ejml-simple-0.44.0.jar"
)
download_and_check(
"${BASE_URL}org/ejml/ejml-zdense/0.43.1/ejml-zdense-0.43.1.jar"
"${JAR_ROOT}/ejml-zdense-0.43.1.jar"
"${BASE_URL}org/ejml/ejml-zdense/0.44.0/ejml-zdense-0.44.0.jar"
"${JAR_ROOT}/ejml-zdense-0.44.0.jar"
)
message(STATUS "All files downloaded.")

66
wpimath/README.md Normal file
View File

@@ -0,0 +1,66 @@
# wpimath
wpimath contains utilities for robot control (feedforward/feedback), state estimation (filters, Kalman and otherwise), 2D/3D geometry, kinematics, trajectory generation, and trajectory optimization.
## Implementation guidelines
A lot of wpimath features directly implement equations from books or papers. The following guidelines make that code easier to maintain and audit for correctness.
### Citations
Cite source books/papers at the top of the function (e.g., `See section 5.6 of "book name".`). If multiple items from a given work are referenced, write a bibliography entry for the work to reference later.
Cite the equation numbers each line of code implements, if applicable. For example, `See equation (#.#) of [1].` where `[1]` is a bibliography reference number.
### Comments
Comment each line of code with its pretty-printed math equivalent.
```cpp
// xₖ₊₁ = Axₖ + Buₖ
x = A * x + B * u;
```
Link to explanatory material where appropriate to explain background knowledge and/or notation choice.
### Variable naming
Follow established mathematical convention where possible (e.g., use A, B, C, D for state-space notation instead of `stateTransitionMatrix`).
Use math symbols in variable names (see [Unicodeit](#Unicodeit)) to match source papers. This usually entails some Greek letters (α), but diacritics, superscripts, and subscripts need to be spelled out (`ẋ` → `x_dot`, `αₖ²` → `α_k_sq`) since compilers reject them, and the small features make variable names difficult to read.
### Derivations
Put small derivations in comments within the function. Put large derivations in algorithms.md and link to them.
## Unicodeit
When writing math expressions in documentation, use https://www.unicodeit.net/ to convert LaTeX to a Unicode equivalent that's easier to read. Not all expressions will translate (e.g., superscripts of superscripts) so focus on making it readable by someone who isn't familiar with LaTeX. If content on multiple lines needs to be aligned in Doxygen/Javadoc comments (e.g., integration/summation limits, matrices packed with square brackets and superscripts for them), put them in @verbatim/@endverbatim blocks in Doxygen or `<pre>` tags in Javadoc so they render with monospace font.
The LaTeX to Unicode conversions can also be done locally via the unicodeit Python package. To install it, execute:
```bash
pip install --user unicodeit
```
Here's example usage:
```bash
$ python -m unicodeit.cli 'x_{k+1} = Ax_k + Bu_k'
xₖ₊₁ = Axₖ + Buₖ
```
On Linux, this process can be streamlined further by adding the following Bash function to your .bashrc (requires `wl-clipboard` on Wayland or `xclip` on X11):
```bash
# Converts LaTeX to Unicode, prints the result, and copies it to the clipboard
uc() {
  # "$@" is quoted so the LaTeX expression reaches unicodeit exactly as typed;
  # unquoted, the shell would split it on spaces and expand glob characters
  # such as * before the converter ever saw it.
  if [ -n "$WAYLAND_DISPLAY" ]; then
    # Wayland session: wl-copy -n drops the trailing newline before copying
    python -m unicodeit.cli "$@" | tee >(wl-copy -n)
  else
    # X11 session: -selection needs an explicit selection name; "clipboard" is
    # the one used by Ctrl+V in most applications
    python -m unicodeit.cli "$@" | tee >(xclip -selection clipboard)
  fi
}
```
Here's example usage:
```bash
$ uc 'x_{k+1} = Ax_k + Bu_k'
xₖ₊₁ = Axₖ + Buₖ
```

View File

@@ -89,11 +89,11 @@ nativeUtils.exportsConfigs {
dependencies {
api project(":wpiunits")
api "org.ejml:ejml-simple:0.43.1"
api "com.fasterxml.jackson.core:jackson-annotations:2.15.2"
api "com.fasterxml.jackson.core:jackson-core:2.15.2"
api "com.fasterxml.jackson.core:jackson-databind:2.15.2"
api "us.hebi.quickbuf:quickbuf-runtime:1.3.3"
api "org.ejml:ejml-simple:0.44.0"
api "com.fasterxml.jackson.core:jackson-annotations:2.19.2"
api "com.fasterxml.jackson.core:jackson-core:2.19.2"
api "com.fasterxml.jackson.core:jackson-databind:2.19.2"
api "us.hebi.quickbuf:quickbuf-runtime:1.4"
}
sourceSets.main.java.srcDir "${projectDir}/src/generated/main/java"

View File

@@ -124,7 +124,8 @@ public final class LinearSystemId {
* @param kA The acceleration gain, in volts/(unit/sec²)
* @return A LinearSystem representing the given characterized constants.
* @throws IllegalArgumentException if kV &lt; 0 or kA &lt;= 0.
* @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* @see <a
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
public static LinearSystem<N2, N1, N2> createDCMotorSystem(double kV, double kA) {
if (kV < 0.0) {
@@ -235,7 +236,8 @@ public final class LinearSystemId {
* @param kA The acceleration gain, in volts/(unit/sec²)
* @return A LinearSystem representing the given characterized constants.
* @throws IllegalArgumentException if kV &lt; 0 or kA &lt;= 0.
* @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* @see <a
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
public static LinearSystem<N1, N1, N1> identifyVelocitySystem(double kV, double kA) {
if (kV < 0.0) {
@@ -268,7 +270,8 @@ public final class LinearSystemId {
* @param kA The acceleration gain, in volts/(unit/sec²)
* @return A LinearSystem representing the given characterized constants.
* @throws IllegalArgumentException if kV &lt; 0 or kA &lt;= 0.
* @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* @see <a
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
public static LinearSystem<N2, N1, N2> identifyPositionSystem(double kV, double kA) {
if (kV < 0.0) {
@@ -301,7 +304,8 @@ public final class LinearSystemId {
* @return A LinearSystem representing the given characterized constants.
* @throws IllegalArgumentException if kVLinear &lt;= 0, kALinear &lt;= 0, kVAngular &lt;= 0, or
* kAAngular &lt;= 0.
* @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* @see <a
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
double kVLinear, double kALinear, double kVAngular, double kAAngular) {
@@ -348,7 +352,8 @@ public final class LinearSystemId {
* @return A LinearSystem representing the given characterized constants.
* @throws IllegalArgumentException if kVLinear &lt;= 0, kALinear &lt;= 0, kVAngular &lt;= 0,
* kAAngular &lt;= 0, or trackwidth &lt;= 0.
* @see <a href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* @see <a
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
public static LinearSystem<N2, N2, N2> identifyDrivetrainSystem(
double kVLinear, double kALinear, double kVAngular, double kAAngular, double trackwidth) {

View File

@@ -4,6 +4,8 @@
package edu.wpi.first.math.trajectory;
import edu.wpi.first.math.trajectory.struct.ExponentialProfileStateStruct;
import edu.wpi.first.util.struct.StructSerializable;
import java.util.Objects;
/**
@@ -128,7 +130,10 @@ public class ExponentialProfile {
}
/** Profile state. */
public static class State {
public static class State implements StructSerializable {
/** The struct that serializes this class. */
public static final ExponentialProfileStateStruct struct = new ExponentialProfileStateStruct();
/** The position at this state. */
public double position;

View File

@@ -5,6 +5,7 @@
package edu.wpi.first.math.trajectory;
import edu.wpi.first.math.MathSharedStore;
import edu.wpi.first.math.trajectory.struct.TrapezoidProfileStateStruct;
import java.util.Objects;
/**
@@ -75,6 +76,9 @@ public class TrapezoidProfile {
/** Profile state. */
public static class State {
/** The struct used to serialize this class. */
public static final TrapezoidProfileStateStruct struct = new TrapezoidProfileStateStruct();
/** The position at this state. */
public double position;

View File

@@ -0,0 +1,42 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.trajectory.struct;
import edu.wpi.first.math.trajectory.ExponentialProfile;
import edu.wpi.first.util.struct.Struct;
import java.nio.ByteBuffer;
/**
 * {@link Struct} implementation for {@link ExponentialProfile.State}.
 *
 * <p>A state is laid out as two consecutive doubles: position first, then velocity.
 */
public class ExponentialProfileStateStruct implements Struct<ExponentialProfile.State> {
  /** Number of double fields in the serialized form (position and velocity). */
  private static final int kFieldCount = 2;

  @Override
  public Class<ExponentialProfile.State> getTypeClass() {
    return ExponentialProfile.State.class;
  }

  @Override
  public String getTypeName() {
    return "ExponentialProfileState";
  }

  @Override
  public int getSize() {
    return kFieldCount * kSizeDouble;
  }

  @Override
  public String getSchema() {
    return "double position;double velocity";
  }

  @Override
  public ExponentialProfile.State unpack(ByteBuffer bb) {
    // Read in the same order pack() writes: position, then velocity
    final double position = bb.getDouble();
    final double velocity = bb.getDouble();
    return new ExponentialProfile.State(position, velocity);
  }

  @Override
  public void pack(ByteBuffer bb, ExponentialProfile.State value) {
    bb.putDouble(value.position);
    bb.putDouble(value.velocity);
  }
}

View File

@@ -0,0 +1,42 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.trajectory.struct;
import edu.wpi.first.math.trajectory.TrapezoidProfile;
import edu.wpi.first.util.struct.Struct;
import java.nio.ByteBuffer;
/**
 * {@link Struct} implementation for {@link TrapezoidProfile.State}.
 *
 * <p>A state is laid out as two consecutive doubles: position first, then velocity.
 */
public class TrapezoidProfileStateStruct implements Struct<TrapezoidProfile.State> {
  /** Number of double fields in the serialized form (position and velocity). */
  private static final int kFieldCount = 2;

  @Override
  public Class<TrapezoidProfile.State> getTypeClass() {
    return TrapezoidProfile.State.class;
  }

  @Override
  public String getTypeName() {
    return "TrapezoidProfileState";
  }

  @Override
  public int getSize() {
    return kFieldCount * kSizeDouble;
  }

  @Override
  public String getSchema() {
    return "double position;double velocity";
  }

  @Override
  public TrapezoidProfile.State unpack(ByteBuffer bb) {
    // Read in the same order pack() writes: position, then velocity
    final double position = bb.getDouble();
    final double velocity = bb.getDouble();
    return new TrapezoidProfile.State(position, velocity);
  }

  @Override
  public void pack(ByteBuffer bb, TrapezoidProfile.State value) {
    bb.putDouble(value.position);
    bb.putDouble(value.velocity);
  }
}

View File

@@ -122,7 +122,7 @@ class WPILIB_DLLEXPORT LinearSystemId {
* @param kA The acceleration gain, in volts/(unit/sec²).
* @throws std::domain_error if kV < 0 or kA <= 0.
* @see <a
* href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
template <typename Distance>
requires std::same_as<units::meter, Distance> ||
@@ -165,7 +165,7 @@ class WPILIB_DLLEXPORT LinearSystemId {
*
* @throws std::domain_error if kV < 0 or kA <= 0.
* @see <a
* href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
template <typename Distance>
requires std::same_as<units::meter, Distance> ||
@@ -208,7 +208,7 @@ class WPILIB_DLLEXPORT LinearSystemId {
* @throws domain_error if kVLinear <= 0, kALinear <= 0, kVAngular <= 0,
* or kAAngular <= 0.
* @see <a
* href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
static constexpr LinearSystem<2, 2, 2> IdentifyDrivetrainSystem(
decltype(1_V / 1_mps) kVLinear, decltype(1_V / 1_mps_sq) kALinear,
@@ -269,7 +269,7 @@ class WPILIB_DLLEXPORT LinearSystemId {
* @throws domain_error if kVLinear <= 0, kALinear <= 0, kVAngular <= 0,
* kAAngular <= 0, or trackwidth <= 0.
* @see <a
* href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
static constexpr LinearSystem<2, 2, 2> IdentifyDrivetrainSystem(
decltype(1_V / 1_mps) kVLinear, decltype(1_V / 1_mps_sq) kALinear,
@@ -346,7 +346,7 @@ class WPILIB_DLLEXPORT LinearSystemId {
* @param gearing Gear ratio from motor to output.
* @throws std::domain_error if J <= 0 or gearing <= 0.
* @see <a
* href="https://github.com/wpilibsuite/sysid">https://github.com/wpilibsuite/sysid</a>
* href="https://github.com/wpilibsuite/allwpilib/tree/main/sysid">https://github.com/wpilibsuite/allwpilib/tree/main/sysid</a>
*/
static constexpr LinearSystem<2, 1, 2> DCMotorSystem(
DCMotor motor, units::kilogram_square_meter_t J, double gearing) {

View File

@@ -92,6 +92,7 @@
#include <algorithm>
#include <array>
#include <memory>
#include <vector>
// for std::is_nothrow_move_assignable
@@ -102,6 +103,11 @@
#include <thread>
#endif
// for std::bit_cast()
#if defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L
#include <bit>
#endif
// for outputting debug info
#ifdef EIGEN_DEBUG_ASSIGN
#include <iostream>
@@ -121,7 +127,6 @@
#undef isfinite
#include <CL/sycl.hpp>
#include <map>
#include <memory>
#include <thread>
#include <utility>
#ifndef EIGEN_SYCL_LOCAL_THREAD_DIM0
@@ -194,8 +199,11 @@ using std::ptrdiff_t;
#if defined EIGEN_VECTORIZE_AVX512
#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/SSE/Reductions.h"
#include "src/Core/arch/AVX/PacketMath.h"
#include "src/Core/arch/AVX/Reductions.h"
// #include "src/Core/arch/AVX512/PacketMath.h"
// #include "src/Core/arch/AVX512/Reductions.h"
#if defined EIGEN_VECTORIZE_AVX512FP16
// #include "src/Core/arch/AVX512/PacketMathFP16.h"
#endif
@@ -216,21 +224,26 @@ using std::ptrdiff_t;
#endif
// #include "src/Core/arch/AVX512/TrsmKernel.h"
#elif defined EIGEN_VECTORIZE_AVX
// Use AVX for floats and doubles, SSE for integers
// Use AVX for floats and doubles, SSE for integers
#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/SSE/Reductions.h"
#include "src/Core/arch/SSE/TypeCasting.h"
#include "src/Core/arch/SSE/Complex.h"
#include "src/Core/arch/AVX/PacketMath.h"
#include "src/Core/arch/AVX/Reductions.h"
#include "src/Core/arch/AVX/TypeCasting.h"
#include "src/Core/arch/AVX/Complex.h"
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/AVX/MathFunctions.h"
#elif defined EIGEN_VECTORIZE_SSE
#include "src/Core/arch/SSE/PacketMath.h"
#include "src/Core/arch/SSE/Reductions.h"
#include "src/Core/arch/SSE/TypeCasting.h"
#include "src/Core/arch/SSE/MathFunctions.h"
#include "src/Core/arch/SSE/Complex.h"
#elif defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
#endif
#if defined(EIGEN_VECTORIZE_ALTIVEC) || defined(EIGEN_VECTORIZE_VSX)
// #include "src/Core/arch/AltiVec/PacketMath.h"
// #include "src/Core/arch/AltiVec/TypeCasting.h"
// #include "src/Core/arch/AltiVec/MathFunctions.h"
@@ -311,6 +324,7 @@ using std::ptrdiff_t;
#include "src/Core/Product.h"
#include "src/Core/CoreEvaluators.h"
#include "src/Core/AssignEvaluator.h"
#include "src/Core/RealView.h"
#include "src/Core/Assign.h"
#include "src/Core/ArrayBase.h"
@@ -350,6 +364,7 @@ using std::ptrdiff_t;
#include "src/Core/SkewSymmetricMatrix3.h"
#include "src/Core/Redux.h"
#include "src/Core/Visitor.h"
#include "src/Core/FindCoeff.h"
#include "src/Core/Fuzzy.h"
#include "src/Core/Swap.h"
#include "src/Core/CommaInitializer.h"

View File

@@ -707,7 +707,7 @@ struct unary_evaluator<CwiseUnaryOp<core_cast_op<SrcType, DstType>, ArgType>, In
Index packetOffset = offset * PacketSize;
Index actualRow = IsRowMajor ? row : row + packetOffset;
Index actualCol = IsRowMajor ? col + packetOffset : col;
eigen_assert(check_array_bounds(actualRow, actualCol, 0, count) && "Array index out of bounds");
eigen_assert(check_array_bounds(actualRow, actualCol, begin, count) && "Array index out of bounds");
return m_argImpl.template packetSegment<LoadMode, PacketType>(actualRow, actualCol, begin, count);
}
template <int LoadMode, typename PacketType = SrcPacketType>
@@ -715,8 +715,8 @@ struct unary_evaluator<CwiseUnaryOp<core_cast_op<SrcType, DstType>, ArgType>, In
Index offset) const {
constexpr int PacketSize = unpacket_traits<PacketType>::size;
Index packetOffset = offset * PacketSize;
Index actualIndex = index + packetOffset + begin;
eigen_assert(check_array_bounds(actualIndex, 0, count) && "Array index out of bounds");
Index actualIndex = index + packetOffset;
eigen_assert(check_array_bounds(actualIndex, begin, count) && "Array index out of bounds");
return m_argImpl.template packetSegment<LoadMode, PacketType>(actualIndex, begin, count);
}
@@ -1565,50 +1565,6 @@ struct block_evaluator<ArgType, BlockRows, BlockCols, InnerPanel, /* HasDirectAc
}
};
// -------------------- Select --------------------
// NOTE shall we introduce a ternary_evaluator?
// TODO enable vectorization for Select
// Coefficient-wise evaluator for Select expressions: for every position it reads the condition coefficient and
// forwards the coefficient of either the "then" or the "else" operand. Only coeff() access is provided here.
template <typename ConditionMatrixType, typename ThenMatrixType, typename ElseMatrixType>
struct evaluator<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType>>
: evaluator_base<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType>> {
typedef Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> XprType;
enum {
// Reading one coefficient costs the condition read plus the more expensive of the two branch reads.
CoeffReadCost = evaluator<ConditionMatrixType>::CoeffReadCost +
plain_enum_max(evaluator<ThenMatrixType>::CoeffReadCost, evaluator<ElseMatrixType>::CoeffReadCost),
// Keep only the flags shared by both branches, restricted to HereditaryBits.
Flags = (unsigned int)evaluator<ThenMatrixType>::Flags & evaluator<ElseMatrixType>::Flags & HereditaryBits,
// The weaker (smaller) alignment of the two branches.
Alignment = plain_enum_min(evaluator<ThenMatrixType>::Alignment, evaluator<ElseMatrixType>::Alignment)
};
// Builds one nested evaluator per operand of the Select expression.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE explicit evaluator(const XprType& select)
: m_conditionImpl(select.conditionMatrix()), m_thenImpl(select.thenMatrix()), m_elseImpl(select.elseMatrix()) {
EIGEN_INTERNAL_CHECK_COST_VALUE(CoeffReadCost);
}
typedef typename XprType::CoeffReturnType CoeffReturnType;
// 2D access: returns the "then" coefficient at (row, col) when the condition coefficient is truthy, otherwise the
// "else" coefficient. Only the selected branch is evaluated.
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index row, Index col) const {
if (m_conditionImpl.coeff(row, col))
return m_thenImpl.coeff(row, col);
else
return m_elseImpl.coeff(row, col);
}
// Linear access counterpart of coeff(row, col).
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE CoeffReturnType coeff(Index index) const {
if (m_conditionImpl.coeff(index))
return m_thenImpl.coeff(index);
else
return m_elseImpl.coeff(index);
}
protected:
// Nested evaluators for the condition and the two candidate operands.
evaluator<ConditionMatrixType> m_conditionImpl;
evaluator<ThenMatrixType> m_thenImpl;
evaluator<ElseMatrixType> m_elseImpl;
};
// -------------------- Replicate --------------------
template <typename ArgType, int RowFactor, int ColFactor>

View File

@@ -367,7 +367,12 @@ class DenseBase
EIGEN_DEVICE_FUNC inline bool allFinite() const;
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const Scalar& other);
template <bool Enable = !internal::is_same<Scalar, RealScalar>::value, typename = std::enable_if_t<Enable>>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator*=(const RealScalar& other);
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const Scalar& other);
template <bool Enable = !internal::is_same<Scalar, RealScalar>::value, typename = std::enable_if_t<Enable>>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& operator/=(const RealScalar& other);
typedef internal::add_const_on_value_type_t<typename internal::eval<Derived>::type> EvalReturnType;
/** \returns the matrix or vector obtained by evaluating this expression.
@@ -597,6 +602,13 @@ class DenseBase
inline const_iterator end() const;
inline const_iterator cend() const;
using RealViewReturnType = std::conditional_t<NumTraits<Scalar>::IsComplex, RealView<Derived>, Derived&>;
using ConstRealViewReturnType =
std::conditional_t<NumTraits<Scalar>::IsComplex, RealView<const Derived>, const Derived&>;
EIGEN_DEVICE_FUNC RealViewReturnType realView();
EIGEN_DEVICE_FUNC ConstRealViewReturnType realView() const;
#define EIGEN_CURRENT_STORAGE_BASE_CLASS Eigen::DenseBase
#define EIGEN_DOC_BLOCK_ADDONS_NOT_INNER_PANEL
#define EIGEN_DOC_BLOCK_ADDONS_INNER_PANEL_IF(COND)

View File

@@ -45,10 +45,16 @@ class DenseCoeffsBase<Derived, ReadOnlyAccessors> : public EigenBase<Derived> {
// - This is the return type of the coeff() method.
// - The LvalueBit means exactly that we can offer a coeffRef() method, which means exactly that we can get references
// to coeffs, which means exactly that we can have coeff() return a const reference (as opposed to returning a value).
// - The DirectAccessBit means exactly that the underlying data of coefficients can be directly accessed as a plain
// strided array, which means exactly that the underlying data of coefficients does exist in memory, which means
// exactly that the coefficients are const-referenceable, which means exactly that we can have coeff() return a const
// reference. For example, Map<const Matrix> has DirectAccessBit but not LvalueBit, so Map<const Matrix>.coeff()
// returns a const Scalar& that exists in memory, while coeffRef() is not allowed because it would not provide an
// lvalue. Notice that DirectAccessBit and LvalueBit are mutually orthogonal.
// - The is_arithmetic check is required since "const int", "const double", etc. will cause warnings on some systems
// while the declaration of "const T", where T is a non arithmetic type does not. Always returning "const Scalar&" is
// not possible, since the underlying expressions might not offer a valid address the reference could be referring to.
typedef std::conditional_t<bool(internal::traits<Derived>::Flags& LvalueBit), const Scalar&,
typedef std::conditional_t<bool(internal::traits<Derived>::Flags&(LvalueBit | DirectAccessBit)), const Scalar&,
std::conditional_t<internal::is_arithmetic<Scalar>::value, Scalar, const Scalar>>
CoeffReturnType;

View File

@@ -78,8 +78,9 @@ template <typename Xpr>
struct eigen_fill_impl<Xpr, /*use_fill*/ true> {
using Scalar = typename Xpr::Scalar;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const Scalar& val) {
const Scalar val_copy = val;
using std::fill_n;
fill_n(dst.data(), dst.size(), val);
fill_n(dst.data(), dst.size(), val_copy);
}
template <typename SrcXpr>
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(Xpr& dst, const SrcXpr& src) {

View File

@@ -0,0 +1,464 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_FIND_COEFF_H
#define EIGEN_FIND_COEFF_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
/** \internal
 * Comparison policy used by the findCoeff() search loops to locate the largest coefficient.
 *
 * This primary template performs plain comparisons without any NaN handling. It is used whenever no
 * NaN-aware specialization applies (e.g. for integer scalars).
 */
template <typename Scalar, int NaNPropagation, bool IsInteger = NumTraits<Scalar>::IsInteger>
struct max_coeff_functor {
  /** \returns true if \a candidate is strictly greater than \a incumbent and should therefore replace it. */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    const bool candidateIsGreater = candidate > incumbent;
    return candidateIsGreater;
  }
  /** \returns a per-lane mask of the lanes in which \a candidate beats \a incumbent. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline PacketType comparePacket(const PacketType& incumbent, const PacketType& candidate) const {
    const PacketType improvedLanes = pcmp_lt(incumbent, candidate);
    return improvedLanes;
  }
  /** \returns the largest coefficient stored in packet \a a. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline Scalar predux(const PacketType& a) const {
    const Scalar largest = predux_max(a);
    return largest;
  }
};
/** \internal
 * NaN-propagating variant of max_coeff_functor for non-integer scalars: a NaN candidate replaces any
 * non-NaN incumbent, and once the incumbent is NaN nothing replaces it anymore.
 */
template <typename Scalar>
struct max_coeff_functor<Scalar, PropagateNaN, false> {
  /** \returns true if \a candidate should replace \a incumbent.
   *
   * The result is a predicate, so it is returned as \c bool (not \c Scalar) like in the other functors; this also
   * avoids a needless bool -> Scalar -> bool round trip for scalar types without such conversions.
   */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    // (x != x) is true only for NaN: accept a NaN candidate as long as the incumbent is still a number.
    return (candidate > incumbent) || ((candidate != candidate) && (incumbent == incumbent));
  }
  /** \returns a per-lane mask of the lanes in which \a candidate should replace \a incumbent. */
  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Packet comparePacket(const Packet& incumbent, const Packet& candidate) const {
    // Lanes whose incumbent is already NaN are masked out so that the first NaN found is kept.
    return pandnot(pcmp_lt_or_nan(incumbent, candidate), pisnan(incumbent));
  }
  /** \returns the maximum of the coefficients of \a a, using PropagateNaN semantics. */
  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& a) const {
    return predux_max<PropagateNaN>(a);
  }
};
/** \internal
 * Number-propagating variant of max_coeff_functor for non-integer scalars: NaN candidates never win, and any
 * non-NaN candidate replaces a NaN incumbent.
 */
template <typename Scalar>
struct max_coeff_functor<Scalar, PropagateNumbers, false> {
  /** \returns true if \a candidate should replace \a incumbent. */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    if (candidate > incumbent) return true;
    // (x != x) is true only for NaN: a number always replaces a NaN incumbent.
    return (candidate == candidate) && (incumbent != incumbent);
  }
  /** \returns a per-lane mask of the lanes in which \a candidate should replace \a incumbent. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline PacketType comparePacket(const PacketType& incumbent, const PacketType& candidate) const {
    const PacketType greaterOrUnordered = pcmp_lt_or_nan(incumbent, candidate);
    const PacketType candidateIsNaN = pisnan(candidate);
    // Drop the lanes in which the candidate itself is NaN.
    return pandnot(greaterOrUnordered, candidateIsNaN);
  }
  /** \returns the maximum of the coefficients of \a a, using PropagateNumbers semantics. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline Scalar predux(const PacketType& a) const {
    const Scalar largest = predux_max<PropagateNumbers>(a);
    return largest;
  }
};
/** \internal
 * Comparison policy used by the findCoeff() search loops to locate the smallest coefficient.
 *
 * This primary template performs plain comparisons without any NaN handling. It is used whenever no
 * NaN-aware specialization applies (e.g. for integer scalars).
 */
template <typename Scalar, int NaNPropagation, bool IsInteger = NumTraits<Scalar>::IsInteger>
struct min_coeff_functor {
  /** \returns true if \a candidate is strictly smaller than \a incumbent and should therefore replace it. */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    const bool candidateIsSmaller = candidate < incumbent;
    return candidateIsSmaller;
  }
  /** \returns a per-lane mask of the lanes in which \a candidate beats \a incumbent. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline PacketType comparePacket(const PacketType& incumbent, const PacketType& candidate) const {
    const PacketType improvedLanes = pcmp_lt(candidate, incumbent);
    return improvedLanes;
  }
  /** \returns the smallest coefficient stored in packet \a a. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline Scalar predux(const PacketType& a) const {
    const Scalar smallest = predux_min(a);
    return smallest;
  }
};
/** \internal
 * NaN-propagating variant of min_coeff_functor for non-integer scalars: a NaN candidate replaces any
 * non-NaN incumbent, and once the incumbent is NaN nothing replaces it anymore.
 */
template <typename Scalar>
struct min_coeff_functor<Scalar, PropagateNaN, false> {
  /** \returns true if \a candidate should replace \a incumbent.
   *
   * The result is a predicate, so it is returned as \c bool (not \c Scalar) like in the other functors; this also
   * avoids a needless bool -> Scalar -> bool round trip for scalar types without such conversions.
   */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    // (x != x) is true only for NaN: accept a NaN candidate as long as the incumbent is still a number.
    return (candidate < incumbent) || ((candidate != candidate) && (incumbent == incumbent));
  }
  /** \returns a per-lane mask of the lanes in which \a candidate should replace \a incumbent. */
  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Packet comparePacket(const Packet& incumbent, const Packet& candidate) const {
    // Lanes whose incumbent is already NaN are masked out so that the first NaN found is kept.
    return pandnot(pcmp_lt_or_nan(candidate, incumbent), pisnan(incumbent));
  }
  /** \returns the minimum of the coefficients of \a a, using PropagateNaN semantics. */
  template <typename Packet>
  EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& a) const {
    return predux_min<PropagateNaN>(a);
  }
};
/** \internal
 * Number-propagating variant of min_coeff_functor for non-integer scalars: NaN candidates never win, and any
 * non-NaN candidate replaces a NaN incumbent.
 */
template <typename Scalar>
struct min_coeff_functor<Scalar, PropagateNumbers, false> {
  /** \returns true if \a candidate should replace \a incumbent. */
  EIGEN_DEVICE_FUNC inline bool compareCoeff(const Scalar& incumbent, const Scalar& candidate) const {
    if (candidate < incumbent) return true;
    // (x != x) is true only for NaN: a number always replaces a NaN incumbent.
    return (candidate == candidate) && (incumbent != incumbent);
  }
  /** \returns a per-lane mask of the lanes in which \a candidate should replace \a incumbent. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline PacketType comparePacket(const PacketType& incumbent, const PacketType& candidate) const {
    const PacketType smallerOrUnordered = pcmp_lt_or_nan(candidate, incumbent);
    const PacketType candidateIsNaN = pisnan(candidate);
    // Drop the lanes in which the candidate itself is NaN.
    return pandnot(smallerOrUnordered, candidateIsNaN);
  }
  /** \returns the minimum of the coefficients of \a a, using PropagateNumbers semantics. */
  template <typename PacketType>
  EIGEN_DEVICE_FUNC inline Scalar predux(const PacketType& a) const {
    const Scalar smallest = predux_min<PropagateNumbers>(a);
    return smallest;
  }
};
// Shared traits for the min/max search functors: packet access is available exactly when the scalar type is
// reported as vectorizable by packet_traits.
template <typename Scalar>
struct min_max_traits {
static constexpr bool PacketAccess = packet_traits<Scalar>::Vectorizable;
};
// functor_traits for both search functors simply inherit the shared min_max_traits; the functor's third template
// argument is left at its default (NumTraits<Scalar>::IsInteger).
template <typename Scalar, int NaNPropagation>
struct functor_traits<max_coeff_functor<Scalar, NaNPropagation>> : min_max_traits<Scalar> {};
template <typename Scalar, int NaNPropagation>
struct functor_traits<min_coeff_functor<Scalar, NaNPropagation>> : min_max_traits<Scalar> {};
// Search loop dispatched on two compile-time switches:
// - Linear: traverse with a single linear index (true) or with outer/inner indices (false);
// - Vectorize: compare whole packets (true) or one coefficient at a time (false).
// Only the four specializations are defined.
template <typename Evaluator, typename Func, bool Linear, bool Vectorize>
struct find_coeff_loop;
/** \internal
 * Scalar search over an outer/inner traversal.
 *
 * \a res, \a outer and \a inner are in/out parameters: the caller seeds them with the first coefficient and its
 * location, and they are overwritten every time \a func reports that a visited coefficient beats the current result.
 */
template <typename Evaluator, typename Func>
struct find_coeff_loop<Evaluator, Func, /*Linear*/ false, /*Vectorize*/ false> {
  using Scalar = typename Evaluator::Scalar;
  static EIGEN_DEVICE_FUNC inline void run(const Evaluator& eval, Func& func, Scalar& res, Index& outer, Index& inner) {
    const Index numOuter = eval.outerSize();
    const Index numInner = eval.innerSize();
    // No initialization here: the calling function has already set
    //   res = eval.coeff(0, 0), outer = 0, inner = 0.
    for (Index outerIdx = 0; outerIdx < numOuter; ++outerIdx) {
      for (Index innerIdx = 0; innerIdx < numInner; ++innerIdx) {
        const Scalar candidate = eval.coeffByOuterInner(outerIdx, innerIdx);
        const bool candidateWins = func.compareCoeff(res, candidate);
        if (!candidateWins) continue;
        outer = outerIdx;
        inner = innerIdx;
        res = candidate;
      }
    }
  }
};
/** \internal
 * Scalar search over a linear traversal.
 *
 * \a res and \a index are in/out parameters: the caller seeds them with the first coefficient and index 0, and they
 * are overwritten every time \a func reports that a visited coefficient beats the current result.
 */
template <typename Evaluator, typename Func>
struct find_coeff_loop<Evaluator, Func, /*Linear*/ true, /*Vectorize*/ false> {
  using Scalar = typename Evaluator::Scalar;
  static EIGEN_DEVICE_FUNC inline void run(const Evaluator& eval, Func& func, Scalar& res, Index& index) {
    const Index numCoeffs = eval.size();
    // No initialization here: the calling function has already set
    //   res = eval.coeff(0), index = 0.
    for (Index pos = 0; pos < numCoeffs; ++pos) {
      const Scalar candidate = eval.coeff(pos);
      const bool candidateWins = func.compareCoeff(res, candidate);
      if (!candidateWins) continue;
      index = pos;
      res = candidate;
    }
  }
};
// Vectorized search over an outer/inner traversal. Every inner run is scanned one packet at a time, followed by a
// scalar loop over the trailing coefficients that do not fill a whole packet. result/outer/inner are in/out
// parameters seeded by the caller.
template <typename Evaluator, typename Func>
struct find_coeff_loop<Evaluator, Func, /*Linear*/ false, /*Vectorize*/ true> {
using ScalarImpl = find_coeff_loop<Evaluator, Func, false, false>;
using Scalar = typename Evaluator::Scalar;
using Packet = typename Evaluator::Packet;
static constexpr int PacketSize = unpacket_traits<Packet>::size;
static EIGEN_DEVICE_FUNC inline void run(const Evaluator& eval, Func& func, Scalar& result, Index& outer,
Index& inner) {
Index outerSize = eval.outerSize();
Index innerSize = eval.innerSize();
// Largest multiple of PacketSize not exceeding innerSize: end of the packet-wise part of each inner run.
Index packetEnd = numext::round_down(innerSize, PacketSize);
/* initialization performed in calling function */
/* result = eval.coeff(0, 0); */
/* outer = 0; */
/* inner = 0; */
// True while the current result was produced by a packet comparison. In that state (outer, inner) only identifies
// the first coefficient of the winning packet, not the winning coefficient itself.
bool checkPacket = false;
for (Index j = 0; j < outerSize; j++) {
// Broadcast the current result so that it can be compared against whole packets.
Packet resultPacket = pset1<Packet>(result);
for (Index i = 0; i < packetEnd; i += PacketSize) {
Packet xprPacket = eval.template packetByOuterInner<Unaligned, Packet>(j, i);
// Only do the (more expensive) reduction when at least one lane beats the current result.
if (predux_any(func.comparePacket(resultPacket, xprPacket))) {
outer = j;
inner = i;
result = func.predux(xprPacket);
resultPacket = pset1<Packet>(result);
checkPacket = true;
}
}
// Scalar tail of the current inner run.
for (Index i = packetEnd; i < innerSize; i++) {
Scalar xprCoeff = eval.coeffByOuterInner(j, i);
if (func.compareCoeff(result, xprCoeff)) {
outer = j;
inner = i;
result = xprCoeff;
// The location is now exact, so the final packet rescan is not needed (unless a later packet wins again).
checkPacket = false;
}
}
}
// The winner came from a packet: rescan that packet coefficient by coefficient, starting from its first element,
// to pin down the exact inner index of the winning coefficient.
if (checkPacket) {
result = eval.coeffByOuterInner(outer, inner);
Index i_end = inner + PacketSize;
for (Index i = inner; i < i_end; i++) {
Scalar xprCoeff = eval.coeffByOuterInner(outer, i);
if (func.compareCoeff(result, xprCoeff)) {
inner = i;
result = xprCoeff;
}
}
}
}
};
// Vectorized search over a linear traversal. The expression is scanned one packet at a time, followed by a scalar
// loop over the trailing coefficients that do not fill a whole packet. result/index are in/out parameters seeded by
// the caller.
template <typename Evaluator, typename Func>
struct find_coeff_loop<Evaluator, Func, /*Linear*/ true, /*Vectorize*/ true> {
using ScalarImpl = find_coeff_loop<Evaluator, Func, true, false>;
using Scalar = typename Evaluator::Scalar;
using Packet = typename Evaluator::Packet;
static constexpr int PacketSize = unpacket_traits<Packet>::size;
// Load mode used for the packet reads below, taken from the evaluator.
static constexpr int Alignment = Evaluator::Alignment;
static EIGEN_DEVICE_FUNC inline void run(const Evaluator& eval, Func& func, Scalar& result, Index& index) {
Index size = eval.size();
// Largest multiple of PacketSize not exceeding size: end of the packet-wise part of the traversal.
Index packetEnd = numext::round_down(size, PacketSize);
/* initialization performed in calling function */
/* result = eval.coeff(0); */
/* index = 0; */
// Broadcast the current result so that it can be compared against whole packets.
Packet resultPacket = pset1<Packet>(result);
// True while the current result was produced by a packet comparison. In that state index only identifies the
// first coefficient of the winning packet, not the winning coefficient itself.
bool checkPacket = false;
for (Index k = 0; k < packetEnd; k += PacketSize) {
Packet xprPacket = eval.template packet<Alignment, Packet>(k);
// Only do the (more expensive) reduction when at least one lane beats the current result.
if (predux_any(func.comparePacket(resultPacket, xprPacket))) {
index = k;
result = func.predux(xprPacket);
resultPacket = pset1<Packet>(result);
checkPacket = true;
}
}
// Scalar tail.
for (Index k = packetEnd; k < size; k++) {
Scalar xprCoeff = eval.coeff(k);
if (func.compareCoeff(result, xprCoeff)) {
index = k;
result = xprCoeff;
// The location is now exact, so the final packet rescan is not needed.
checkPacket = false;
}
}
// The winner came from a packet: rescan that packet coefficient by coefficient, starting from its first element,
// to pin down the exact index of the winning coefficient.
if (checkPacket) {
result = eval.coeff(index);
Index k_end = index + PacketSize;
for (Index k = index; k < k_end; k++) {
Scalar xprCoeff = eval.coeff(k);
if (func.compareCoeff(result, xprCoeff)) {
index = k;
result = xprCoeff;
}
}
}
}
};
/** \internal
 * Thin wrapper around evaluator<Derived> that adds the accessors required by the find_coeff_loop
 * implementations: outer/inner addressed coefficient and packet reads, plus the sizes of the wrapped expression.
 */
template <typename Derived>
struct find_coeff_evaluator : public evaluator<Derived> {
  using Base = evaluator<Derived>;
  using Scalar = typename Derived::Scalar;
  using Packet = typename packet_traits<Scalar>::type;
  static constexpr int Flags = Base::Flags;
  static constexpr bool IsRowMajor = bool(Flags & RowMajorBit);

  /** Keeps a reference to \a xpr, which is queried later for its sizes. */
  EIGEN_DEVICE_FUNC inline find_coeff_evaluator(const Derived& xpr) : Base(xpr), m_xpr(xpr) {}

  /** \returns the coefficient at the given outer/inner position, translated to (row, col) using the storage order. */
  EIGEN_DEVICE_FUNC inline Scalar coeffByOuterInner(Index outer, Index inner) const {
    return Base::coeff(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer);
  }

  /** \returns the packet starting at the given outer/inner position, translated to (row, col) using the storage
   * order. */
  template <int LoadMode, typename PacketType>
  EIGEN_DEVICE_FUNC inline PacketType packetByOuterInner(Index outer, Index inner) const {
    return Base::template packet<LoadMode, PacketType>(IsRowMajor ? outer : inner, IsRowMajor ? inner : outer);
  }

  // Size queries are forwarded to the wrapped expression.
  EIGEN_DEVICE_FUNC inline Index innerSize() const { return m_xpr.innerSize(); }
  EIGEN_DEVICE_FUNC inline Index outerSize() const { return m_xpr.outerSize(); }
  EIGEN_DEVICE_FUNC inline Index size() const { return m_xpr.size(); }

  const Derived& m_xpr;
};
/** \internal
 * Selects the find_coeff_loop variant (linear or outer/inner traversal, scalar or vectorized) for the expression
 * type \a Derived and the functor \a Func, and exposes it through the run() overloads.
 */
template <typename Derived, typename Func>
struct find_coeff_impl {
  using Evaluator = find_coeff_evaluator<Derived>;
  using Scalar = typename Derived::Scalar;
  using Packet = typename Evaluator::Packet;

  static constexpr int Flags = Evaluator::Flags;
  static constexpr int Alignment = Evaluator::Alignment;
  static constexpr bool IsRowMajor = Derived::IsRowMajor;
  static constexpr int MaxInnerSizeAtCompileTime =
      IsRowMajor ? Derived::MaxColsAtCompileTime : Derived::MaxRowsAtCompileTime;
  static constexpr int MaxSizeAtCompileTime = Derived::MaxSizeAtCompileTime;
  static constexpr int PacketSize = unpacket_traits<Packet>::size;

  // Traverse with a single linear index whenever the evaluator supports it.
  static constexpr bool Linearize = bool(Flags & LinearAccessBit);
  // Do not vectorize when the traversed dimension is known at compile time to be shorter than one packet.
  static constexpr bool DontVectorize =
      enum_lt_not_dynamic(Linearize ? MaxSizeAtCompileTime : MaxInnerSizeAtCompileTime, PacketSize);
  // Vectorization additionally needs packet access from both the evaluator and the functor.
  static constexpr bool Vectorize =
      !DontVectorize && bool(Flags & PacketAccessBit) && functor_traits<Func>::PacketAccess;

  using Loop = find_coeff_loop<Evaluator, Func, Linearize, Vectorize>;

  /** Outer/inner search, used when linear access is not available. */
  template <bool UseLinear = Linearize, std::enable_if_t<!UseLinear, bool> = true>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& xpr, Func& func, Scalar& res, Index& outer,
                                                        Index& inner) {
    Evaluator evaluatorImpl(xpr);
    Loop::run(evaluatorImpl, func, res, outer, inner);
  }

  /** Outer/inner search implemented on top of the linear search when linear access is available. */
  template <bool UseLinear = Linearize, std::enable_if_t<UseLinear, bool> = true>
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& xpr, Func& func, Scalar& res, Index& outer,
                                                        Index& inner) {
    // where possible, use the linear loop and back-calculate the outer and inner indices
    Index linearIndex = 0;
    run(xpr, func, res, linearIndex);
    const Index innerLength = xpr.innerSize();
    outer = linearIndex / innerLength;
    inner = linearIndex % innerLength;
  }

  /** Linear-index search. */
  static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void run(const Derived& xpr, Func& func, Scalar& res, Index& index) {
    Evaluator evaluatorImpl(xpr);
    Loop::run(evaluatorImpl, func, res, index);
  }
};
/** \internal
 * Searches \a mat for the coefficient preferred by \a func and reports its location as a row/column pair.
 *
 * \param mat     non-empty expression to search (asserted)
 * \param func    comparison functor (e.g. min_coeff_functor or max_coeff_functor)
 * \param rowPtr  receives the row of the selected coefficient; always written, so it must not be null
 * \param colPtr  receives the column of the selected coefficient; may be null, in which case it is skipped
 * \returns the selected coefficient
 */
template <typename Derived, typename IndexType, typename Func>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar findCoeff(const DenseBase<Derived>& mat, Func& func,
                                                                       IndexType* rowPtr, IndexType* colPtr) {
  eigen_assert(mat.rows() > 0 && mat.cols() > 0 && "you are using an empty matrix");
  using Scalar = typename DenseBase<Derived>::Scalar;
  using Impl = internal::find_coeff_impl<Derived, Func>;

  // Seed the search with the first coefficient and its location.
  Scalar best = mat.coeff(0, 0);
  Index outerIdx = 0;
  Index innerIdx = 0;
  Impl::run(mat.derived(), func, best, outerIdx, innerIdx);

  // Translate outer/inner back to row/column according to the storage order.
  const Index foundRow = Derived::IsRowMajor ? outerIdx : innerIdx;
  const Index foundCol = Derived::IsRowMajor ? innerIdx : outerIdx;
  *rowPtr = internal::convert_index<IndexType>(foundRow);
  if (colPtr) *colPtr = internal::convert_index<IndexType>(foundCol);
  return best;
}
/** \internal
 * Searches the vector \a mat for the coefficient preferred by \a func and reports its linear index.
 *
 * \param mat       non-empty vector expression to search (asserted; vectors only, checked at compile time)
 * \param func      comparison functor (e.g. min_coeff_functor or max_coeff_functor)
 * \param indexPtr  receives the index of the selected coefficient; always written, so it must not be null
 * \returns the selected coefficient
 */
template <typename Derived, typename IndexType, typename Func>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar findCoeff(const DenseBase<Derived>& mat, Func& func,
                                                                       IndexType* indexPtr) {
  eigen_assert(mat.size() > 0 && "you are using an empty matrix");
  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
  using Scalar = typename DenseBase<Derived>::Scalar;
  using Impl = internal::find_coeff_impl<Derived, Func>;

  // Seed the search with the first coefficient and its location.
  Scalar best = mat.coeff(0);
  Index foundIndex = 0;
  Impl::run(mat.derived(), func, best, foundIndex);

  *indexPtr = internal::convert_index<IndexType>(foundIndex);
  return best;
}
} // namespace internal
/** \fn DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
 * \returns the minimum of all coefficients of *this and stores its location in \c *rowPtr and \c *colPtr.
 *
 * \c rowPtr is always written; \c colPtr is only written if it is not null.
 *
 * If several coefficients share the minimum value, the location of the first instance is returned.
 *
 * In case \c *this contains NaN, NaNPropagation determines the behavior:
 *   NaNPropagation == PropagateFast : undefined
 *   NaNPropagation == PropagateNaN : result is NaN
 *   NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
 * \warning the matrix must be not empty, otherwise an assertion is triggered.
 *
 * \sa DenseBase::minCoeff(Index*), DenseBase::maxCoeff(Index*,Index*), DenseBase::visit(), DenseBase::minCoeff()
 */
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* rowPtr,
                                                                                          IndexType* colPtr) const {
  internal::min_coeff_functor<Scalar, NaNPropagation> comparator;
  return internal::findCoeff(derived(), comparator, rowPtr, colPtr);
}
/** \returns the minimum of all coefficients of *this and stores its location in \c *indexPtr.
 *
 * Only available for vectors. \c indexPtr is always written.
 *
 * If several coefficients share the minimum value, the location of the first instance is returned.
 *
 * In case \c *this contains NaN, NaNPropagation determines the behavior:
 *   NaNPropagation == PropagateFast : undefined
 *   NaNPropagation == PropagateNaN : result is NaN
 *   NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
 * \warning the matrix must be not empty, otherwise an assertion is triggered.
 *
 * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::visit(),
 * DenseBase::minCoeff()
 */
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* indexPtr) const {
  internal::min_coeff_functor<Scalar, NaNPropagation> comparator;
  return internal::findCoeff(derived(), comparator, indexPtr);
}
/** \fn DenseBase<Derived>::maxCoeff(IndexType* rowId, IndexType* colId) const
 * \returns the maximum of all coefficients of *this and stores its location in \c *rowPtr and \c *colPtr.
 *
 * \c rowPtr is always written; \c colPtr is only written if it is not null.
 *
 * If several coefficients share the maximum value, the location of the first instance is returned.
 *
 * In case \c *this contains NaN, NaNPropagation determines the behavior:
 *   NaNPropagation == PropagateFast : undefined
 *   NaNPropagation == PropagateNaN : result is NaN
 *   NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
 * \warning the matrix must be not empty, otherwise an assertion is triggered.
 *
 * \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::maxCoeff()
 */
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* rowPtr,
                                                                                          IndexType* colPtr) const {
  internal::max_coeff_functor<Scalar, NaNPropagation> comparator;
  return internal::findCoeff(derived(), comparator, rowPtr, colPtr);
}
/** \returns the maximum of all coefficients of *this and stores its location in \c *indexPtr.
 *
 * Only available for vectors. \c indexPtr is always written.
 *
 * If several coefficients share the maximum value, the location of the first instance is returned.
 *
 * In case \c *this contains NaN, NaNPropagation determines the behavior:
 *   NaNPropagation == PropagateFast : undefined
 *   NaNPropagation == PropagateNaN : result is NaN
 *   NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
 * \warning the matrix must be not empty, otherwise an assertion is triggered.
 *
 * \sa DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(),
 * DenseBase::maxCoeff()
 */
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* indexPtr) const {
  internal::max_coeff_functor<Scalar, NaNPropagation> comparator;
  return internal::findCoeff(derived(), comparator, indexPtr);
}
} // namespace Eigen
#endif // EIGEN_FIND_COEFF_H

View File

@@ -253,6 +253,12 @@ struct preinterpret_generic<Packet, Packet, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet& a) { return a; }
};
// Reinterprets a complex packet as the real packet type named by unpacket_traits<ComplexPacket>::as_real by
// returning the packet's `v` member.
// NOTE(review): presumably `v` is the underlying real-valued storage of the complex packet wrapper - confirm for
// every architecture that defines as_real.
template <typename ComplexPacket>
struct preinterpret_generic<typename unpacket_traits<ComplexPacket>::as_real, ComplexPacket, false> {
using RealPacket = typename unpacket_traits<ComplexPacket>::as_real;
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE RealPacket run(const ComplexPacket& a) { return a.v; }
};
/** \internal \returns reinterpret_cast<Target>(a) */
template <typename Target, typename Packet>
EIGEN_DEVICE_FUNC inline Target preinterpret(const Packet& a) {
@@ -375,7 +381,7 @@ EIGEN_DEVICE_FUNC inline bool pdiv(const bool& a, const bool& b) {
return a && b;
}
// In the generic case, memset to all one bits.
// In the generic packet case, memset to all one bits.
template <typename Packet, typename EnableIf = void>
struct ptrue_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/) {
@@ -385,19 +391,16 @@ struct ptrue_impl {
}
};
// Use a value of one for scalars.
// Scalar overload of ptrue_impl, selected through is_scalar: ignores its argument and returns Scalar(1) instead of
// filling the value with one-bits as the generic packet version does.
template <typename Scalar>
struct ptrue_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value>> {
static EIGEN_DEVICE_FUNC inline Scalar run(const Scalar&) { return Scalar(1); }
};
// For booleans, we can only directly set a valid `bool` value to avoid UB.
template <>
struct ptrue_impl<bool, void> {
static EIGEN_DEVICE_FUNC inline bool run(const bool& /*a*/) { return true; }
};
// For non-trivial scalars, set to Scalar(1) (i.e. a non-zero value).
// Although this is technically not a valid bitmask, the scalar path for pselect
// uses a comparison to zero, so this should still work in most cases. We don't
// have another option, since the scalar type requires initialization.
template <typename T>
struct ptrue_impl<T, std::enable_if_t<is_scalar<T>::value && NumTraits<T>::RequireInitialization>> {
static EIGEN_DEVICE_FUNC inline T run(const T& /*a*/) { return T(1); }
static EIGEN_DEVICE_FUNC inline bool run(const bool&) { return true; }
};
/** \internal \returns one bits. */
@@ -406,7 +409,7 @@ EIGEN_DEVICE_FUNC inline Packet ptrue(const Packet& a) {
return ptrue_impl<Packet>::run(a);
}
// In the general case, memset to zero.
// In the general packet case, memset to zero.
template <typename Packet, typename EnableIf = void>
struct pzero_impl {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& /*a*/) {
@@ -608,7 +611,7 @@ EIGEN_DEVICE_FUNC inline bool pselect<bool>(const bool& cond, const bool& a, con
/** \internal \returns the min or max of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, the result is implementation defined. */
template <int NaNPropagation>
template <int NaNPropagation, bool IsInteger>
struct pminmax_impl {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
@@ -619,7 +622,7 @@ struct pminmax_impl {
/** \internal \returns the min or max of \a a and \a b (coeff-wise)
If either \a a or \a b are NaN, NaN is returned. */
template <>
struct pminmax_impl<PropagateNaN> {
struct pminmax_impl<PropagateNaN, false> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
@@ -632,7 +635,7 @@ struct pminmax_impl<PropagateNaN> {
If both \a a and \a b are NaN, NaN is returned.
Equivalent to std::fmin(a, b). */
template <>
struct pminmax_impl<PropagateNumbers> {
struct pminmax_impl<PropagateNumbers, false> {
template <typename Packet, typename Op>
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a, const Packet& b, Op op) {
Packet not_nan_mask_a = pcmp_eq(a, a);
@@ -641,7 +644,7 @@ struct pminmax_impl<PropagateNumbers> {
}
};
#define EIGEN_BINARY_OP_NAN_PROPAGATION(Type, Func) [](const Type& a, const Type& b) { return Func(a, b); }
#define EIGEN_BINARY_OP_NAN_PROPAGATION(Type, Func) [](const Type& aa, const Type& bb) { return Func(aa, bb); }
/** \internal \returns the min of \a a and \a b (coeff-wise).
If \a a or \a b is NaN, the return value is implementation defined. */
@@ -654,7 +657,8 @@ EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) {
NaNPropagation determines the NaN propagation semantics. */
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmin(const Packet& a, const Packet& b) {
return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin<Packet>)));
constexpr bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger;
return pminmax_impl<NaNPropagation, IsInteger>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmin<Packet>)));
}
/** \internal \returns the max of \a a and \a b (coeff-wise)
@@ -668,7 +672,8 @@ EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) {
NaNPropagation determines the NaN propagation semantics. */
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmax(const Packet& a, const Packet& b) {
return pminmax_impl<NaNPropagation>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax<Packet>)));
constexpr bool IsInteger = NumTraits<typename unpacket_traits<Packet>::type>::IsInteger;
return pminmax_impl<NaNPropagation, IsInteger>::run(a, b, EIGEN_BINARY_OP_NAN_PROPAGATION(Packet, (pmax<Packet>)));
}
/** \internal \returns the absolute value of \a a */
@@ -873,17 +878,29 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet plset(const typename unpacket_trait
return a;
}
// Generic packet implementation of peven_mask: builds a packet whose even-indexed coefficients have all bits set and
// whose odd-indexed coefficients are all zero bits. The argument is only used for type deduction.
template <typename Packet, typename EnableIf = void>
struct peven_mask_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet run(const Packet&) {
typedef typename unpacket_traits<Packet>::type Scalar;
const size_t n = unpacket_traits<Packet>::size;
// Scratch buffer holding one scalar per packet lane.
EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n];
for (size_t i = 0; i < n; ++i) {
// Write the bit pattern byte-wise: 0xff bytes for even lanes, zero bytes for odd lanes.
memset(elements + i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar));
}
return ploadu<Packet>(elements);
}
};
// Scalar overload of peven_mask_impl, selected through is_scalar: ignores its argument and returns Scalar(1) rather
// than building a bit pattern with memset.
template <typename Scalar>
struct peven_mask_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value>> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar&) { return Scalar(1); }
};
/** \internal \returns a packet with constant coefficients \a a, e.g.: (x, 0, x, 0),
where x is the value of all 1-bits. */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet peven_mask(const Packet& /*a*/) {
typedef typename unpacket_traits<Packet>::type Scalar;
const size_t n = unpacket_traits<Packet>::size;
EIGEN_ALIGN_TO_BOUNDARY(sizeof(Packet)) Scalar elements[n];
for (size_t i = 0; i < n; ++i) {
memset(elements + i, ((i & 1) == 0 ? 0xff : 0), sizeof(Scalar));
}
return ploadu<Packet>(elements);
EIGEN_DEVICE_FUNC inline Packet peven_mask(const Packet& a) {
return peven_mask_impl<Packet>::run(a);
}
/** \internal copy the packet \a from to \a *to, \a to must be properly aligned */
@@ -1244,26 +1261,46 @@ EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_mul(const
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<PropagateFast, Scalar>)));
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<Scalar>)));
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
}
/** \internal \returns the min of the elements of \a a */
/** \internal \returns the max of the elements of \a a */
template <typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<PropagateFast, Scalar>)));
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<Scalar>)));
}
// Dispatch helper for the NaN-propagation-aware predux_min / predux_max.
//
// UsePredux_ is true when NaNPropagation == PropagateFast or when the packet's
// scalar type is an integer type. In that case run_min / run_max forward to the
// single-template-argument predux_min(a) / predux_max(a). Otherwise they reduce
// with predux_helper using pmin / pmax instantiated with the requested
// NaNPropagation mode.
//
// Each of run_min / run_max is declared twice; the std::enable_if_t default
// template argument leaves exactly one of the two overloads viable.
template <int NaNPropagation, typename Packet>
struct predux_min_max_helper_impl {
using Scalar = typename unpacket_traits<Packet>::type;
static constexpr bool UsePredux_ = NaNPropagation == PropagateFast || NumTraits<Scalar>::IsInteger;
// !UsePredux: reduce with the NaNPropagation-specific scalar pmin.
template <bool UsePredux = UsePredux_, std::enable_if_t<!UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) {
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmin<NaNPropagation, Scalar>)));
}
// !UsePredux: reduce with the NaNPropagation-specific scalar pmax.
template <bool UsePredux = UsePredux_, std::enable_if_t<!UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) {
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
}
// UsePredux: forward to the predux_min overload that takes no NaNPropagation argument.
template <bool UsePredux = UsePredux_, std::enable_if_t<UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_min(const Packet& a) {
return predux_min(a);
}
// UsePredux: forward to the predux_max overload that takes no NaNPropagation argument.
template <bool UsePredux = UsePredux_, std::enable_if_t<UsePredux, bool> = true>
static EIGEN_DEVICE_FUNC inline Scalar run_max(const Packet& a) {
return predux_max(a);
}
};
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_min(const Packet& a) {
return predux_min_max_helper_impl<NaNPropagation, Packet>::run_min(a);
}
template <int NaNPropagation, typename Packet>
EIGEN_DEVICE_FUNC inline typename unpacket_traits<Packet>::type predux_max(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
return predux_helper(a, EIGEN_BINARY_OP_NAN_PROPAGATION(Scalar, (pmax<NaNPropagation, Scalar>)));
return predux_min_max_helper_impl<NaNPropagation, Packet>::run_max(a);
}
#undef EIGEN_BINARY_OP_NAN_PROPAGATION
@@ -1313,20 +1350,20 @@ struct pmadd_impl {
template <typename Scalar>
struct pmadd_impl<Scalar, std::enable_if_t<is_scalar<Scalar>::value && NumTraits<Scalar>::IsSigned>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, c);
return numext::madd<Scalar>(a, b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(a, b, Scalar(-c));
return numext::madd<Scalar>(a, b, Scalar(-c));
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmadd(const Scalar& a, const Scalar& b, const Scalar& c) {
return numext::fma(Scalar(-a), b, c);
return numext::madd<Scalar>(Scalar(-a), b, c);
}
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Scalar pnmsub(const Scalar& a, const Scalar& b, const Scalar& c) {
return -Scalar(numext::fma(a, b, c));
return -Scalar(numext::madd<Scalar>(a, b, c));
}
};
// FMA instructions.
// Multiply-add instructions.
/** \internal \returns a * b + c (coeff-wise) */
template <typename Packet>
EIGEN_DEVICE_FUNC inline Packet pmadd(const Packet& a, const Packet& b, const Packet& c) {
@@ -1565,9 +1602,10 @@ EIGEN_DEVICE_FUNC inline Packet ploaduSegment(const typename unpacket_traits<Pac
using Scalar = typename unpacket_traits<Packet>::type;
constexpr Index PacketSize = unpacket_traits<Packet>::size;
eigen_assert((begin >= 0 && count >= 0 && begin + count <= PacketSize) && "invalid range");
Scalar aux[PacketSize];
memset(static_cast<void*>(aux), 0x00, sizeof(Scalar) * PacketSize);
smart_copy(from + begin, from + begin + count, aux + begin);
Scalar aux[PacketSize] = {};
for (Index k = begin; k < begin + count; k++) {
aux[k] = from[k];
}
return ploadu<Packet>(aux);
}
@@ -1588,7 +1626,9 @@ EIGEN_DEVICE_FUNC inline void pstoreuSegment(Scalar* to, const Packet& from, Ind
eigen_assert((begin >= 0 && count >= 0 && begin + count <= PacketSize) && "invalid range");
Scalar aux[PacketSize];
pstoreu<Scalar, Packet>(aux, from);
smart_copy(aux + begin, aux + begin + count, to + begin);
for (Index k = begin; k < begin + count; k++) {
to[k] = aux[k];
}
}
/** \internal copy the packet \a from in the range [begin, begin + count) to \a *to.

View File

@@ -308,6 +308,12 @@ struct unary_evaluator<IndexedView<ArgType, RowIndices, ColIndices>, IndexBased>
const XprType& m_xpr;
};
// Catch assignments to an IndexedView.
// Sets evaluator_assume_aliasing<...>::value to true for every IndexedView
// instantiation, regardless of the argument or index types.
// NOTE(review): presumably this makes assignments involving an IndexedView take
// the "may alias" path (evaluation through a temporary) - confirm against the
// primary template and its users.
template <typename ArgType, typename RowIndices, typename ColIndices>
struct evaluator_assume_aliasing<IndexedView<ArgType, RowIndices, ColIndices>> {
static const bool value = true;
};
} // end namespace internal
} // end namespace Eigen

View File

@@ -182,10 +182,6 @@ struct imag_ref_retval {
typedef typename NumTraits<Scalar>::Real& type;
};
// implementation in MathFunctionsImpl.h
template <typename Mask, bool is_built_in_float = std::is_floating_point<Mask>::value>
struct scalar_select_mask;
} // namespace internal
namespace numext {
@@ -211,9 +207,9 @@ EIGEN_DEVICE_FUNC inline EIGEN_MATHFUNC_RETVAL(imag, Scalar) imag(const Scalar&
return EIGEN_MATHFUNC_IMPL(imag, Scalar)::run(x);
}
template <typename Scalar, typename Mask>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Mask& mask, const Scalar& a, const Scalar& b) {
return internal::scalar_select_mask<Mask>::run(mask) ? b : a;
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar select(const Scalar& mask, const Scalar& a, const Scalar& b) {
return numext::is_exactly_zero(mask) ? b : a;
}
} // namespace numext
@@ -945,23 +941,43 @@ struct nearest_integer_impl<Scalar, true> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run_trunc(const Scalar& x) { return x; }
};
// Extra namespace to prevent leaking std::fma into Eigen::internal.
namespace has_fma_detail {
// Fallback: no suitable fma(T, T, T) was found, so the trait is false.
template <typename T, typename EnableIf = void>
struct has_fma_impl : public std::false_type {};
// Bring std::fma into scope so the unqualified call below can resolve to it
// as well as to any overload found through argument-dependent lookup.
using std::fma;
// Selected only when the unqualified call fma(T, T, T) is well-formed and its
// return type is exactly T.
template <typename T>
struct has_fma_impl<
T, std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>>
: public std::true_type {};
} // namespace has_fma_detail
// Public trait: true when an fma(T, T, T) returning T is available for T.
template <typename T>
struct has_fma : public has_fma_detail::has_fma_impl<T> {};
// Default implementation.
template <typename Scalar, typename Enable = void>
template <typename T, typename Enable = void>
struct fma_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& a, const Scalar& b, const Scalar& c) {
return a * b + c;
static_assert(has_fma<T>::value, "No function fma(...) for type. Please provide an implementation.");
};
// STD or ADL version if it exists.
template <typename T>
struct fma_impl<T, std::enable_if_t<has_fma<T>::value>> {
static EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T run(const T& a, const T& b, const T& c) {
using std::fma;
return fma(a, b, c);
}
};
// ADL version if it exists.
template <typename T>
struct fma_impl<
T,
std::enable_if_t<std::is_same<T, decltype(fma(std::declval<T>(), std::declval<T>(), std::declval<T>()))>::value>> {
static T run(const T& a, const T& b, const T& c) { return fma(a, b, c); }
};
#if defined(EIGEN_GPUCC)
template <>
struct has_fma<float> : public true_type {};
template <>
struct fma_impl<float, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float run(const float& a, const float& b, const float& c) {
@@ -969,6 +985,9 @@ struct fma_impl<float, void> {
}
};
template <>
struct has_fma<double> : public true_type {};
template <>
struct fma_impl<double, void> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double run(const double& a, const double& b, const double& c) {
@@ -977,6 +996,24 @@ struct fma_impl<double, void> {
};
#endif
// Basic multiply-add.
// Generic implementation: evaluates x * y + z as a separate multiply and add
// using the scalar type's own operators.
template <typename Scalar, typename EnableIf = void>
struct madd_impl {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
return x * y + z;
}
};
// Use FMA if there is a single CPU instruction.
// This specialization only exists when EIGEN_VECTORIZE_FMA is defined, and is
// only selected for scalar types for which has_fma<Scalar>::value is true; it
// forwards to fma_impl<Scalar>::run.
#ifdef EIGEN_VECTORIZE_FMA
template <typename Scalar>
struct madd_impl<Scalar, std::enable_if_t<has_fma<Scalar>::value>> {
static EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar run(const Scalar& x, const Scalar& y, const Scalar& z) {
return fma_impl<Scalar>::run(x, y, z);
}
};
#endif
} // end namespace internal
/****************************************************************************
@@ -1890,15 +1927,18 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar arithmetic_shift_right(const Scalar
return bit_cast<Scalar, SignedScalar>(bit_cast<SignedScalar, Scalar>(a) >> n);
}
// Use std::fma if available.
using std::fma;
// Otherwise, rely on template implementation.
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar fma(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::fma_impl<Scalar>::run(x, y, z);
}
// Multiply-add.
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar madd(const Scalar& x, const Scalar& y, const Scalar& z) {
return internal::madd_impl<Scalar>::run(x, y, z);
}
} // end namespace numext
namespace internal {

View File

@@ -28,7 +28,7 @@ namespace internal {
2. If a is zero, approx_a_recip must be infinite with the same sign as a.
3. If a is infinite, approx_a_recip must be zero with the same sign as a.
If the preconditions are satisfied, which they are for for the _*_rcp_ps
If the preconditions are satisfied, which they are for the _*_rcp_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles reciprocals of zero, infinity, and NaN.
*/
@@ -66,7 +66,7 @@ struct generic_reciprocal_newton_step<Packet, 0> {
2. If a is zero, approx_a_recip must be infinite with the same sign as a.
3. If a is infinite, approx_a_recip must be zero with the same sign as a.
If the preconditions are satisfied, which they are for for the _*_rcp_ps
If the preconditions are satisfied, which they are for the _*_rcp_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles zero, infinity, and NaN. Positive denormals are
treated as zero.
@@ -116,7 +116,7 @@ struct generic_rsqrt_newton_step<Packet, 0> {
2. If a is zero, approx_rsqrt must be infinite.
3. If a is infinite, approx_rsqrt must be zero.
If the preconditions are satisfied, which they are for for the _*_rsqrt_ps
If the preconditions are satisfied, which they are for the _*_rsqrt_ps
instructions on x86, the result has a maximum relative error of 2 ulps,
and correctly handles zero and infinity, and NaN. Positive denormal inputs
are treated as zero.
@@ -256,48 +256,6 @@ EIGEN_DEVICE_FUNC ComplexT complex_log(const ComplexT& z) {
return ComplexT(numext::log(a), b);
}
// For generic scalars, use ternary select.
template <typename Mask>
struct scalar_select_mask<Mask, /*is_built_in_float*/ false> {
static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) { return numext::is_exactly_zero(mask); }
};
// For built-in float mask, bitcast the mask to its integer counterpart and use ternary select.
template <typename Mask>
struct scalar_select_mask<Mask, /*is_built_in_float*/ true> {
using IntegerType = typename numext::get_integer_by_size<sizeof(Mask)>::unsigned_type;
static EIGEN_DEVICE_FUNC inline bool run(const Mask& mask) {
return numext::is_exactly_zero(numext::bit_cast<IntegerType>(std::abs(mask)));
}
};
template <int Size = sizeof(long double)>
struct ldbl_select_mask {
static constexpr int MantissaDigits = std::numeric_limits<long double>::digits;
static constexpr int NumBytes = (MantissaDigits == 64 ? 80 : 128) / CHAR_BIT;
static EIGEN_DEVICE_FUNC inline bool run(const long double& mask) {
const uint8_t* mask_bytes = reinterpret_cast<const uint8_t*>(&mask);
for (Index i = 0; i < NumBytes; i++) {
if (mask_bytes[i] != 0) return false;
}
return true;
}
};
template <>
struct ldbl_select_mask<sizeof(double)> : scalar_select_mask<double> {};
template <>
struct scalar_select_mask<long double, true> : ldbl_select_mask<> {};
template <typename RealMask>
struct scalar_select_mask<std::complex<RealMask>, false> {
using impl = scalar_select_mask<RealMask>;
static EIGEN_DEVICE_FUNC inline bool run(const std::complex<RealMask>& mask) {
return impl::run(numext::real(mask)) && impl::run(numext::imag(mask));
}
};
} // end namespace internal
} // end namespace Eigen

View File

@@ -95,9 +95,22 @@ struct default_max_digits10_impl<T, false, true> // Integer
} // end namespace internal
namespace numext {
/** \internal bit-wise cast without changing the underlying bit representation. */
// TODO: Replace by std::bit_cast (available in C++20)
/** \internal bit-wise cast without changing the underlying bit representation. */
#if defined(__cpp_lib_bit_cast) && __cpp_lib_bit_cast >= 201806L
template <typename Tgt, typename Src>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr Tgt bit_cast(const Src& src) {
return std::bit_cast<Tgt>(src);
}
#elif EIGEN_HAS_BUILTIN(__builtin_bit_cast)
template <typename Tgt, typename Src>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC constexpr Tgt bit_cast(const Src& src) {
EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Src>::value, THIS_TYPE_IS_NOT_SUPPORTED)
EIGEN_STATIC_ASSERT(std::is_trivially_copyable<Tgt>::value, THIS_TYPE_IS_NOT_SUPPORTED)
EIGEN_STATIC_ASSERT(sizeof(Src) == sizeof(Tgt), THIS_TYPE_IS_NOT_SUPPORTED)
return __builtin_bit_cast(Tgt, src);
}
#else
template <typename Tgt, typename Src>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) {
// The behaviour of memcpy is not specified for non-trivially copyable types
@@ -113,6 +126,7 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Tgt bit_cast(const Src& src) {
memcpy(static_cast<void*>(&tgt), static_cast<const void*>(&staged), sizeof(Tgt));
return tgt;
}
#endif
} // namespace numext
// clang-format off

View File

@@ -468,17 +468,17 @@ class PermutationWrapper : public PermutationBase<PermutationWrapper<IndicesType
/** \returns the matrix with the permutation applied to the columns.
*/
template <typename MatrixDerived, typename PermutationDerived>
EIGEN_DEVICE_FUNC const Product<MatrixDerived, PermutationDerived, AliasFreeProduct> operator*(
EIGEN_DEVICE_FUNC const Product<MatrixDerived, PermutationDerived, DefaultProduct> operator*(
const MatrixBase<MatrixDerived>& matrix, const PermutationBase<PermutationDerived>& permutation) {
return Product<MatrixDerived, PermutationDerived, AliasFreeProduct>(matrix.derived(), permutation.derived());
return Product<MatrixDerived, PermutationDerived, DefaultProduct>(matrix.derived(), permutation.derived());
}
/** \returns the matrix with the permutation applied to the rows.
*/
template <typename PermutationDerived, typename MatrixDerived>
EIGEN_DEVICE_FUNC const Product<PermutationDerived, MatrixDerived, AliasFreeProduct> operator*(
EIGEN_DEVICE_FUNC const Product<PermutationDerived, MatrixDerived, DefaultProduct> operator*(
const PermutationBase<PermutationDerived>& permutation, const MatrixBase<MatrixDerived>& matrix) {
return Product<PermutationDerived, MatrixDerived, AliasFreeProduct>(permutation.derived(), matrix.derived());
return Product<PermutationDerived, MatrixDerived, DefaultProduct>(permutation.derived(), matrix.derived());
}
template <typename PermutationType>
@@ -520,16 +520,16 @@ class InverseImpl<PermutationType, PermutationStorage> : public EigenBase<Invers
/** \returns the matrix with the inverse permutation applied to the columns.
*/
template <typename OtherDerived>
friend const Product<OtherDerived, InverseType, AliasFreeProduct> operator*(const MatrixBase<OtherDerived>& matrix,
const InverseType& trPerm) {
return Product<OtherDerived, InverseType, AliasFreeProduct>(matrix.derived(), trPerm.derived());
friend const Product<OtherDerived, InverseType, DefaultProduct> operator*(const MatrixBase<OtherDerived>& matrix,
const InverseType& trPerm) {
return Product<OtherDerived, InverseType, DefaultProduct>(matrix.derived(), trPerm.derived());
}
/** \returns the matrix with the inverse permutation applied to the rows.
*/
template <typename OtherDerived>
const Product<InverseType, OtherDerived, AliasFreeProduct> operator*(const MatrixBase<OtherDerived>& matrix) const {
return Product<InverseType, OtherDerived, AliasFreeProduct>(derived(), matrix.derived());
const Product<InverseType, OtherDerived, DefaultProduct> operator*(const MatrixBase<OtherDerived>& matrix) const {
return Product<InverseType, OtherDerived, DefaultProduct>(derived(), matrix.derived());
}
};

View File

@@ -846,7 +846,7 @@ struct generic_product_impl<Lhs, Rhs, SelfAdjointShape, DenseShape, ProductTag>
template <typename Dest>
static EIGEN_DEVICE_FUNC void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::IsVectorAtCompileTime>::run(
selfadjoint_product_impl<typename Lhs::MatrixType, Lhs::Mode, false, Rhs, 0, Rhs::ColsAtCompileTime == 1>::run(
dst, lhs.nestedExpression(), rhs, alpha);
}
};
@@ -858,7 +858,7 @@ struct generic_product_impl<Lhs, Rhs, DenseShape, SelfAdjointShape, ProductTag>
template <typename Dest>
static void scaleAndAddTo(Dest& dst, const Lhs& lhs, const Rhs& rhs, const Scalar& alpha) {
selfadjoint_product_impl<Lhs, 0, Lhs::IsVectorAtCompileTime, typename Rhs::MatrixType, Rhs::Mode, false>::run(
selfadjoint_product_impl<Lhs, 0, Lhs::RowsAtCompileTime == 1, typename Rhs::MatrixType, Rhs::Mode, false>::run(
dst, lhs, rhs.nestedExpression(), alpha);
}
};

View File

@@ -131,8 +131,15 @@ struct random_longdouble_impl {
uint64_t randomBits[2];
long double result = 2.0L;
memcpy(&randomBits, &result, Size);
#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
randomBits[0] |= getRandomBits<uint64_t>(numLowBits);
randomBits[1] |= getRandomBits<uint64_t>(numHighBits);
#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
randomBits[0] |= getRandomBits<uint64_t>(numHighBits);
randomBits[1] |= getRandomBits<uint64_t>(numLowBits);
#else
#error Unexpected or undefined __BYTE_ORDER__
#endif
memcpy(&result, &randomBits, Size);
result -= 3.0L;
return result;

View File

@@ -0,0 +1,250 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_REALVIEW_H
#define EIGEN_REALVIEW_H
// IWYU pragma: private
#include "./InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
// Vectorized assignment to RealView requires array-oriented access to the real and imaginary components.
// From https://en.cppreference.com/w/cpp/numeric/complex.html:
// For any pointer to an element of an array of std::complex<T> named p and any valid array index i,
// reinterpret_cast<T*>(p)[2 * i] is the real part of the complex number p[i], and
// reinterpret_cast<T*>(p)[2 * i + 1] is the imaginary part of the complex number p[i].
// Trait: true when ComplexScalar may be accessed as an array of two real
// values. Defaults to false for every type ...
template <typename ComplexScalar>
struct complex_array_access : std::false_type {};
// ... and is enabled for the three std::complex specializations over the
// built-in floating-point types.
template <>
struct complex_array_access<std::complex<float>> : std::true_type {};
template <>
struct complex_array_access<std::complex<double>> : std::true_type {};
template <>
struct complex_array_access<std::complex<long double>> : std::true_type {};
// Compile-time traits of RealView<Xpr>. Starts from traits<Xpr> and overrides
// the scalar type, the flags, and the sizes/strides that change when each
// complex coefficient is exposed as two real coefficients.
template <typename Xpr>
struct traits<RealView<Xpr>> : public traits<Xpr> {
// Returns Dynamic unchanged; otherwise returns 2 * size when times_two is
// true and size when it is false.
template <typename T>
static constexpr int double_size(T size, bool times_two) {
int size_as_int = int(size);
if (size_as_int == Dynamic) return Dynamic;
return times_two ? (2 * size_as_int) : size_as_int;
}
using Base = traits<Xpr>;
using ComplexScalar = typename Base::Scalar;
// The view's scalar is the real type associated with the complex scalar.
using Scalar = typename NumTraits<ComplexScalar>::Real;
// Direct access is only advertised for complex types with array-oriented access.
static constexpr int ActualDirectAccessBit = complex_array_access<ComplexScalar>::value ? DirectAccessBit : 0;
// Packet access is only advertised when the real scalar is vectorizable.
static constexpr int ActualPacketAccessBit = packet_traits<Scalar>::Vectorizable ? PacketAccessBit : 0;
static constexpr int FlagMask =
ActualDirectAccessBit | ActualPacketAccessBit | HereditaryBits | LinearAccessBit | LvalueBit;
// Union of the nested expression's evaluator flags and traits flags, then
// restricted to the bits allowed by FlagMask.
static constexpr int BaseFlags = int(evaluator<Xpr>::Flags) | int(Base::Flags);
static constexpr int Flags = BaseFlags & FlagMask;
static constexpr bool IsRowMajor = Flags & RowMajorBit;
// Rows are doubled when the view is not row-major, columns when it is.
static constexpr int RowsAtCompileTime = double_size(Base::RowsAtCompileTime, !IsRowMajor);
static constexpr int ColsAtCompileTime = double_size(Base::ColsAtCompileTime, IsRowMajor);
static constexpr int SizeAtCompileTime = size_at_compile_time(RowsAtCompileTime, ColsAtCompileTime);
static constexpr int MaxRowsAtCompileTime = double_size(Base::MaxRowsAtCompileTime, !IsRowMajor);
static constexpr int MaxColsAtCompileTime = double_size(Base::MaxColsAtCompileTime, IsRowMajor);
static constexpr int MaxSizeAtCompileTime = size_at_compile_time(MaxRowsAtCompileTime, MaxColsAtCompileTime);
// The outer stride is always doubled (unless Dynamic); the inner stride is
// taken from Xpr unchanged.
static constexpr int OuterStrideAtCompileTime = double_size(outer_stride_at_compile_time<Xpr>::ret, true);
static constexpr int InnerStrideAtCompileTime = inner_stride_at_compile_time<Xpr>::ret;
};
// Evaluator for RealView<Xpr>.
//
// Wraps the evaluator of the nested complex expression and exposes every complex
// coefficient as two consecutive real coefficients along the inner dimension:
// the even position is the real part, the odd position is the imaginary part.
// Coefficient indices along the inner dimension are therefore halved before
// being forwarded to the nested evaluator, and the parity selects the part.
template <typename Xpr>
struct evaluator<RealView<Xpr>> : private evaluator<Xpr> {
  using BaseEvaluator = evaluator<Xpr>;
  using XprType = RealView<Xpr>;
  using ExpressionTraits = traits<XprType>;
  using ComplexScalar = typename ExpressionTraits::ComplexScalar;
  using ComplexCoeffReturnType = typename BaseEvaluator::CoeffReturnType;
  using Scalar = typename ExpressionTraits::Scalar;
  static constexpr bool IsRowMajor = ExpressionTraits::IsRowMajor;
  static constexpr int Flags = ExpressionTraits::Flags;
  static constexpr int CoeffReadCost = BaseEvaluator::CoeffReadCost;
  static constexpr int Alignment = BaseEvaluator::Alignment;

  EIGEN_DEVICE_FUNC explicit evaluator(XprType realView) : BaseEvaluator(realView.m_xpr) {}

  // Read access by (row, col) when the nested evaluator returns coefficients by
  // value: the complex coefficient is fetched and the requested part extracted.
  template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<!Enable>>
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index row, Index col) const {
    ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
    Index p = (IsRowMajor ? col : row) & 1;
    // p == 0 -> real part, p == 1 -> imaginary part. This must agree with the
    // array-style indexing used by the reference-returning overload and coeffRef.
    return p ? numext::imag(cscalar) : numext::real(cscalar);
  }

  // Read access by (row, col) when the nested evaluator returns a reference:
  // the complex coefficient is viewed as Scalar[2] and indexed by the parity.
  template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<Enable>>
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index row, Index col) const {
    ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
    Index p = (IsRowMajor ? col : row) & 1;
    return reinterpret_cast<const Scalar(&)[2]>(cscalar)[p];
  }

  // Write access by (row, col): reference to the real (p == 0) or imaginary
  // (p == 1) part of the underlying complex coefficient.
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index row, Index col) {
    ComplexScalar& cscalar = BaseEvaluator::coeffRef(IsRowMajor ? row : row / 2, IsRowMajor ? col / 2 : col);
    Index p = (IsRowMajor ? col : row) & 1;
    return reinterpret_cast<Scalar(&)[2]>(cscalar)[p];
  }

  // Linear-index read access, by-value variant (see the two-index overload).
  template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<!Enable>>
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar coeff(Index index) const {
    ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2);
    Index p = index & 1;
    // Even linear index -> real part, odd linear index -> imaginary part.
    return p ? numext::imag(cscalar) : numext::real(cscalar);
  }

  // Linear-index read access, by-reference variant.
  template <bool Enable = std::is_reference<ComplexCoeffReturnType>::value, typename = std::enable_if_t<Enable>>
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE const Scalar& coeff(Index index) const {
    ComplexCoeffReturnType cscalar = BaseEvaluator::coeff(index / 2);
    Index p = index & 1;
    return reinterpret_cast<const Scalar(&)[2]>(cscalar)[p];
  }

  // Linear-index write access.
  constexpr EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar& coeffRef(Index index) {
    ComplexScalar& cscalar = BaseEvaluator::coeffRef(index / 2);
    Index p = index & 1;
    return reinterpret_cast<Scalar(&)[2]>(cscalar)[p];
  }

  // Loads a real packet starting at (row, col) by loading the complex packet
  // with half as many elements and reinterpreting it. The inner index must be
  // even so that the real packet starts on a complex coefficient boundary.
  template <int LoadMode, typename PacketType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index row, Index col) const {
    constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
    using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
    EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
                        MISSING COMPATIBLE COMPLEX PACKET TYPE)
    eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even");

    Index crow = IsRowMajor ? row : row / 2;
    Index ccol = IsRowMajor ? col / 2 : col;
    ComplexPacket cpacket = BaseEvaluator::template packet<LoadMode, ComplexPacket>(crow, ccol);
    return preinterpret<PacketType, ComplexPacket>(cpacket);
  }

  // Linear-index variant of packet(); the index must be even.
  template <int LoadMode, typename PacketType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packet(Index index) const {
    constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
    using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
    EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
                        MISSING COMPATIBLE COMPLEX PACKET TYPE)
    eigen_assert((index % 2 == 0) && "the index must be even");

    Index cindex = index / 2;
    ComplexPacket cpacket = BaseEvaluator::template packet<LoadMode, ComplexPacket>(cindex);
    return preinterpret<PacketType, ComplexPacket>(cpacket);
  }

  // Partial packet load at (row, col). begin and count are expressed in real
  // coefficients and must both be even; they are halved before being forwarded
  // to the nested evaluator's packetSegment.
  template <int LoadMode, typename PacketType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index row, Index col, Index begin, Index count) const {
    constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
    using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
    EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
                        MISSING COMPATIBLE COMPLEX PACKET TYPE)
    eigen_assert(((IsRowMajor ? col : row) % 2 == 0) && "the inner index must be even");
    eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even");

    Index crow = IsRowMajor ? row : row / 2;
    Index ccol = IsRowMajor ? col / 2 : col;
    Index cbegin = begin / 2;
    Index ccount = count / 2;
    ComplexPacket cpacket = BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(crow, ccol, cbegin, ccount);
    return preinterpret<PacketType, ComplexPacket>(cpacket);
  }

  // Linear-index variant of packetSegment(); index, begin and count must be even.
  template <int LoadMode, typename PacketType>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE PacketType packetSegment(Index index, Index begin, Index count) const {
    constexpr int RealPacketSize = unpacket_traits<PacketType>::size;
    using ComplexPacket = typename find_packet_by_size<ComplexScalar, RealPacketSize / 2>::type;
    EIGEN_STATIC_ASSERT((find_packet_by_size<ComplexScalar, RealPacketSize / 2>::value),
                        MISSING COMPATIBLE COMPLEX PACKET TYPE)
    eigen_assert((index % 2 == 0) && "the index must be even");
    eigen_assert((begin % 2 == 0) && (count % 2 == 0) && "begin and count must be even");

    Index cindex = index / 2;
    Index cbegin = begin / 2;
    Index ccount = count / 2;
    ComplexPacket cpacket = BaseEvaluator::template packetSegment<LoadMode, ComplexPacket>(cindex, cbegin, ccount);
    return preinterpret<PacketType, ComplexPacket>(cpacket);
  }
};
} // namespace internal
// Expression that presents a complex-valued expression Xpr as a real-valued one
// with twice as many coefficients. Instantiation is rejected at compile time
// unless Xpr's scalar type is complex. The view stores a reference to the
// wrapped expression, so the expression must outlive the view.
template <typename Xpr>
class RealView : public internal::dense_xpr_base<RealView<Xpr>>::type {
using ExpressionTraits = internal::traits<RealView>;
EIGEN_STATIC_ASSERT(NumTraits<typename Xpr::Scalar>::IsComplex, SCALAR MUST BE COMPLEX)
public:
using Scalar = typename ExpressionTraits::Scalar;
using Nested = RealView;
EIGEN_DEVICE_FUNC explicit RealView(Xpr& xpr) : m_xpr(xpr) {}
// Rows are doubled when Xpr is not row-major, columns when it is.
EIGEN_DEVICE_FUNC constexpr Index rows() const noexcept { return Xpr::IsRowMajor ? m_xpr.rows() : 2 * m_xpr.rows(); }
EIGEN_DEVICE_FUNC constexpr Index cols() const noexcept { return Xpr::IsRowMajor ? 2 * m_xpr.cols() : m_xpr.cols(); }
EIGEN_DEVICE_FUNC constexpr Index size() const noexcept { return 2 * m_xpr.size(); }
// The inner stride is forwarded unchanged; the outer stride is doubled.
EIGEN_DEVICE_FUNC constexpr Index innerStride() const noexcept { return m_xpr.innerStride(); }
EIGEN_DEVICE_FUNC constexpr Index outerStride() const noexcept { return 2 * m_xpr.outerStride(); }
// Resizes the wrapped expression. The requested real dimension that was
// doubled by the view is halved with integer division before forwarding, so
// an odd value is rounded down.
EIGEN_DEVICE_FUNC void resize(Index rows, Index cols) {
m_xpr.resize(Xpr::IsRowMajor ? rows : rows / 2, Xpr::IsRowMajor ? cols / 2 : cols);
}
EIGEN_DEVICE_FUNC void resize(Index size) { m_xpr.resize(size / 2); }
// Pointer to the wrapped expression's storage, reinterpreted as real scalars.
EIGEN_DEVICE_FUNC Scalar* data() { return reinterpret_cast<Scalar*>(m_xpr.data()); }
EIGEN_DEVICE_FUNC const Scalar* data() const { return reinterpret_cast<const Scalar*>(m_xpr.data()); }
EIGEN_DEVICE_FUNC RealView(const RealView&) = default;
// Assignment operators are defined out of line after the class.
EIGEN_DEVICE_FUNC RealView& operator=(const RealView& other);
template <typename OtherDerived>
EIGEN_DEVICE_FUNC RealView& operator=(const RealView<OtherDerived>& other);
template <typename OtherDerived>
EIGEN_DEVICE_FUNC RealView& operator=(const DenseBase<OtherDerived>& other);
protected:
// The evaluator reads m_xpr directly to construct the nested evaluator.
friend struct internal::evaluator<RealView<Xpr>>;
Xpr& m_xpr;
};
// Assignment from a RealView of the same type: forwards to
// internal::call_assignment and returns *this.
template <typename Xpr>
EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const RealView& other) {
internal::call_assignment(*this, other);
return *this;
}
// Assignment from a RealView over a different expression type: forwards to
// internal::call_assignment and returns *this.
template <typename Xpr>
template <typename OtherDerived>
EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const RealView<OtherDerived>& other) {
internal::call_assignment(*this, other);
return *this;
}
// Assignment from an arbitrary dense expression: forwards the derived
// expression to internal::call_assignment and returns *this.
template <typename Xpr>
template <typename OtherDerived>
EIGEN_DEVICE_FUNC RealView<Xpr>& RealView<Xpr>::operator=(const DenseBase<OtherDerived>& other) {
internal::call_assignment(*this, other.derived());
return *this;
}
// Returns a RealViewReturnType constructed from the derived expression
// (non-const overload).
template <typename Derived>
EIGEN_DEVICE_FUNC typename DenseBase<Derived>::RealViewReturnType DenseBase<Derived>::realView() {
return RealViewReturnType(derived());
}
// Returns a ConstRealViewReturnType constructed from the derived expression
// (const overload).
template <typename Derived>
EIGEN_DEVICE_FUNC typename DenseBase<Derived>::ConstRealViewReturnType DenseBase<Derived>::realView() const {
return ConstRealViewReturnType(derived());
}
} // namespace Eigen
#endif // EIGEN_REALVIEW_H

View File

@@ -15,7 +15,7 @@
namespace Eigen {
/** \class Select
/** \typedef Select
* \ingroup Core_Module
*
* \brief Expression of a coefficient wise version of the C++ ternary operator ?:
@@ -24,73 +24,16 @@ namespace Eigen {
* \tparam ThenMatrixType the type of the \em then expression
* \tparam ElseMatrixType the type of the \em else expression
*
* This class represents an expression of a coefficient wise version of the C++ ternary operator ?:.
* This type represents an expression of a coefficient wise version of the C++ ternary operator ?:.
* It is the return type of DenseBase::select() and most of the time this is the only way it is used.
*
* \sa DenseBase::select(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const
*/
namespace internal {
template <typename ConditionMatrixType, typename ThenMatrixType, typename ElseMatrixType>
struct traits<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> > : traits<ThenMatrixType> {
typedef typename traits<ThenMatrixType>::Scalar Scalar;
typedef Dense StorageKind;
typedef typename traits<ThenMatrixType>::XprKind XprKind;
typedef typename ConditionMatrixType::Nested ConditionMatrixNested;
typedef typename ThenMatrixType::Nested ThenMatrixNested;
typedef typename ElseMatrixType::Nested ElseMatrixNested;
enum {
RowsAtCompileTime = ConditionMatrixType::RowsAtCompileTime,
ColsAtCompileTime = ConditionMatrixType::ColsAtCompileTime,
MaxRowsAtCompileTime = ConditionMatrixType::MaxRowsAtCompileTime,
MaxColsAtCompileTime = ConditionMatrixType::MaxColsAtCompileTime,
Flags = (unsigned int)ThenMatrixType::Flags & ElseMatrixType::Flags & RowMajorBit
};
};
} // namespace internal
template <typename ConditionMatrixType, typename ThenMatrixType, typename ElseMatrixType>
class Select : public internal::dense_xpr_base<Select<ConditionMatrixType, ThenMatrixType, ElseMatrixType> >::type,
               internal::no_assignment_operator {
 public:
  using Base = typename internal::dense_xpr_base<Select>::type;
  EIGEN_DENSE_PUBLIC_INTERFACE(Select)

  /** Stores the three operands. All of them are asserted to have the same number of rows and columns. */
  EIGEN_DEVICE_FUNC Select(const ConditionMatrixType& a_conditionMatrix, const ThenMatrixType& a_thenMatrix,
                           const ElseMatrixType& a_elseMatrix)
      : m_condition(a_conditionMatrix), m_then(a_thenMatrix), m_else(a_elseMatrix) {
    eigen_assert(m_condition.rows() == m_then.rows() && m_condition.rows() == m_else.rows());
    eigen_assert(m_condition.cols() == m_then.cols() && m_condition.cols() == m_else.cols());
  }

  // The shape of the expression is the shape of the condition operand.
  EIGEN_DEVICE_FUNC constexpr Index rows() const noexcept { return m_condition.rows(); }
  EIGEN_DEVICE_FUNC constexpr Index cols() const noexcept { return m_condition.cols(); }

  /** \returns the "then" coefficient at (i, j) when the condition coefficient is true, the "else" one otherwise. */
  EIGEN_DEVICE_FUNC const Scalar coeff(Index i, Index j) const {
    if (m_condition.coeff(i, j)) return m_then.coeff(i, j);
    return m_else.coeff(i, j);
  }

  /** Linear-index counterpart of coeff(Index, Index). */
  EIGEN_DEVICE_FUNC const Scalar coeff(Index i) const {
    if (m_condition.coeff(i)) return m_then.coeff(i);
    return m_else.coeff(i);
  }

  // Read-only access to the stored operands.
  EIGEN_DEVICE_FUNC const ConditionMatrixType& conditionMatrix() const { return m_condition; }
  EIGEN_DEVICE_FUNC const ThenMatrixType& thenMatrix() const { return m_then; }
  EIGEN_DEVICE_FUNC const ElseMatrixType& elseMatrix() const { return m_else; }

 protected:
  typename ConditionMatrixType::Nested m_condition;
  typename ThenMatrixType::Nested m_then;
  typename ElseMatrixType::Nested m_else;
};
using Select = CwiseTernaryOp<internal::scalar_boolean_select_op<typename DenseBase<ThenMatrixType>::Scalar,
typename DenseBase<ElseMatrixType>::Scalar,
typename DenseBase<ConditionMatrixType>::Scalar>,
ThenMatrixType, ElseMatrixType, ConditionMatrixType>;
/** \returns a matrix where each coefficient (i,j) is equal to \a thenMatrix(i,j)
* if \c *this(i,j) != Scalar(0), and \a elseMatrix(i,j) otherwise.
@@ -98,7 +41,7 @@ class Select : public internal::dense_xpr_base<Select<ConditionMatrixType, ThenM
* Example: \include MatrixBase_select.cpp
* Output: \verbinclude MatrixBase_select.out
*
* \sa DenseBase::bitwiseSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&)
* \sa typedef Select
*/
template <typename Derived>
template <typename ThenDerived, typename ElseDerived>
@@ -107,15 +50,12 @@ inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
typename DenseBase<Derived>::Scalar>,
ThenDerived, ElseDerived, Derived>
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix, const DenseBase<ElseDerived>& elseMatrix) const {
using Op = internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenDerived, ElseDerived, Derived>(thenMatrix.derived(), elseMatrix.derived(), derived(),
Op());
return Select<Derived, ThenDerived, ElseDerived>(thenMatrix.derived(), elseMatrix.derived(), derived());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em else expression being a scalar value.
*
* \sa DenseBase::booleanSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
* \sa typedef Select
*/
template <typename Derived>
template <typename ThenDerived>
@@ -126,15 +66,13 @@ inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
DenseBase<Derived>::select(const DenseBase<ThenDerived>& thenMatrix,
const typename DenseBase<ThenDerived>::Scalar& elseScalar) const {
using ElseConstantType = typename DenseBase<ThenDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<typename DenseBase<ThenDerived>::Scalar,
typename DenseBase<ThenDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenDerived, ElseConstantType, Derived>(
thenMatrix.derived(), ElseConstantType(rows(), cols(), elseScalar), derived(), Op());
return Select<Derived, ThenDerived, ElseConstantType>(thenMatrix.derived(),
ElseConstantType(rows(), cols(), elseScalar), derived());
}
/** Version of DenseBase::select(const DenseBase&, const DenseBase&) with
* the \em then expression being a scalar value.
*
* \sa DenseBase::booleanSelect(const DenseBase<ThenDerived>&, const DenseBase<ElseDerived>&) const, class Select
* \sa typedef Select
*/
template <typename Derived>
template <typename ElseDerived>
@@ -145,10 +83,8 @@ inline EIGEN_DEVICE_FUNC CwiseTernaryOp<
DenseBase<Derived>::select(const typename DenseBase<ElseDerived>::Scalar& thenScalar,
const DenseBase<ElseDerived>& elseMatrix) const {
using ThenConstantType = typename DenseBase<ElseDerived>::ConstantReturnType;
using Op = internal::scalar_boolean_select_op<typename DenseBase<ElseDerived>::Scalar,
typename DenseBase<ElseDerived>::Scalar, Scalar>;
return CwiseTernaryOp<Op, ThenConstantType, ElseDerived, Derived>(ThenConstantType(rows(), cols(), thenScalar),
elseMatrix.derived(), derived(), Op());
return Select<Derived, ThenConstantType, ElseDerived>(ThenConstantType(rows(), cols(), thenScalar),
elseMatrix.derived(), derived());
}
} // end namespace Eigen

View File

@@ -15,19 +15,33 @@
namespace Eigen {
// TODO generalize the scalar type of 'other'
// In-place multiplication of every coefficient by `other`: the update is routed through
// internal::call_assignment with a mul_assign_op functor and a rows() x cols() constant expression
// holding `other`. Returns derived().
// NOTE(review): two call_assignment statements are listed here (one built from
// PlainObject::Constant, one from plain_constant_type). This looks like the removed and the added
// lines of a diff hunk shown together — confirm that only one of them is meant to remain.
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::operator*=(const Scalar& other) {
internal::call_assignment(this->derived(), PlainObject::Constant(rows(), cols(), other),
internal::mul_assign_op<Scalar, Scalar>());
using ConstantExpr = typename internal::plain_constant_type<Derived, Scalar>::type;
using Op = internal::mul_assign_op<Scalar>;
internal::call_assignment(derived(), ConstantExpr(rows(), cols(), other), Op());
return derived();
}
// Overload taking a RealScalar factor: applies the multiplication to realView() and returns derived().
// NOTE(review): the unnamed `bool Enable, typename` template parameters are presumably an
// enable_if-style guard declared alongside the class — confirm in the DenseBase declaration.
template <typename Derived>
template <bool Enable, typename>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::operator*=(const RealScalar& other) {
realView() *= other;
return derived();
}
// In-place division of every coefficient by `other`: the update is routed through
// internal::call_assignment with a div_assign_op functor and a rows() x cols() constant expression
// holding `other`. Returns derived().
// NOTE(review): two call_assignment statements are listed here (one built from
// PlainObject::Constant, one from plain_constant_type). This looks like the removed and the added
// lines of a diff hunk shown together — confirm that only one of them is meant to remain.
template <typename Derived>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::operator/=(const Scalar& other) {
internal::call_assignment(this->derived(), PlainObject::Constant(rows(), cols(), other),
internal::div_assign_op<Scalar, Scalar>());
using ConstantExpr = typename internal::plain_constant_type<Derived, Scalar>::type;
using Op = internal::div_assign_op<Scalar>;
internal::call_assignment(derived(), ConstantExpr(rows(), cols(), other), Op());
return derived();
}
// Overload taking a RealScalar divisor: applies the division to realView() and returns derived().
// NOTE(review): the unnamed `bool Enable, typename` template parameters are presumably an
// enable_if-style guard declared alongside the class — confirm in the DenseBase declaration.
template <typename Derived>
template <bool Enable, typename>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Derived& DenseBase<Derived>::operator/=(const RealScalar& other) {
realView() /= other;
return derived();
}

View File

@@ -146,6 +146,22 @@ struct member_redux {
const BinaryOp& binaryFunc() const { return m_functor; }
const BinaryOp m_functor;
};
/** \internal Unary functor mapping an exactly-zero input to one; any other input is returned unchanged. */
template <typename Scalar>
struct scalar_replace_zero_with_one_op {
  // Scalar path.
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar operator()(const Scalar& x) const {
    if (numext::is_exactly_zero(x)) {
      return Scalar(1);
    }
    return x;
  }

  // Packet path: build a "lane equals zero" mask, then pick 1 where the mask is set and x elsewhere.
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
    const Packet is_zero = pcmp_eq(x, pzero(x));
    const Packet ones = pset1<Packet>(Scalar(1));
    return pselect(is_zero, ones, x);
  }
};
// Cost model for scalar_replace_zero_with_one_op: unit cost, and the packet path is enabled only
// when the scalar's packet type advertises comparison support (HasCmp).
template <typename Scalar>
struct functor_traits<scalar_replace_zero_with_one_op<Scalar>> {
enum { Cost = 1, PacketAccess = packet_traits<Scalar>::HasCmp };
};
} // namespace internal
/** \class VectorwiseOp
@@ -190,9 +206,7 @@ class VectorwiseOp {
public:
typedef typename ExpressionType::Scalar Scalar;
typedef typename ExpressionType::RealScalar RealScalar;
typedef Eigen::Index Index; ///< \deprecated since Eigen 3.3
typedef typename internal::ref_selector<ExpressionType>::non_const_type ExpressionTypeNested;
typedef internal::remove_all_t<ExpressionTypeNested> ExpressionTypeNestedCleaned;
typedef internal::remove_all_t<ExpressionType> ExpressionTypeCleaned;
template <template <typename OutScalar, typename InputScalar> class Functor, typename ReturnScalar = Scalar>
struct ReturnType {
@@ -331,7 +345,7 @@ class VectorwiseOp {
typedef typename ReturnType<internal::member_minCoeff>::Type MinCoeffReturnType;
typedef typename ReturnType<internal::member_maxCoeff>::Type MaxCoeffReturnType;
typedef PartialReduxExpr<const CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const ExpressionTypeNestedCleaned>,
typedef PartialReduxExpr<const CwiseUnaryOp<internal::scalar_abs2_op<Scalar>, const ExpressionTypeCleaned>,
internal::member_sum<RealScalar, RealScalar>, Direction>
SquaredNormReturnType;
typedef CwiseUnaryOp<internal::scalar_sqrt_op<RealScalar>, const SquaredNormReturnType> NormReturnType;
@@ -582,7 +596,7 @@ class VectorwiseOp {
/** Returns the expression of the sum of the vector \a other to each subvector of \c *this */
template <typename OtherDerived>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC
CwiseBinaryOp<internal::scalar_sum_op<Scalar, typename OtherDerived::Scalar>, const ExpressionTypeNestedCleaned,
CwiseBinaryOp<internal::scalar_sum_op<Scalar, typename OtherDerived::Scalar>, const ExpressionTypeCleaned,
const typename ExtendedType<OtherDerived>::Type>
operator+(const DenseBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
@@ -593,7 +607,7 @@ class VectorwiseOp {
/** Returns the expression of the difference between each subvector of \c *this and the vector \a other */
template <typename OtherDerived>
EIGEN_DEVICE_FUNC CwiseBinaryOp<internal::scalar_difference_op<Scalar, typename OtherDerived::Scalar>,
const ExpressionTypeNestedCleaned, const typename ExtendedType<OtherDerived>::Type>
const ExpressionTypeCleaned, const typename ExtendedType<OtherDerived>::Type>
operator-(const DenseBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_SAME_XPR_KIND(ExpressionType, OtherDerived)
@@ -604,7 +618,7 @@ class VectorwiseOp {
* by the corresponding subvector of \c *this */
template <typename OtherDerived>
EIGEN_DEVICE_FUNC CwiseBinaryOp<internal::scalar_product_op<Scalar, typename OtherDerived::Scalar>,
const ExpressionTypeNestedCleaned, const typename ExtendedType<OtherDerived>::Type>
const ExpressionTypeCleaned, const typename ExtendedType<OtherDerived>::Type>
operator*(const DenseBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_ARRAYXPR(ExpressionType)
@@ -616,7 +630,7 @@ class VectorwiseOp {
* subvector of \c *this by the vector \a other */
template <typename OtherDerived>
EIGEN_DEVICE_FUNC CwiseBinaryOp<internal::scalar_quotient_op<Scalar, typename OtherDerived::Scalar>,
const ExpressionTypeNestedCleaned, const typename ExtendedType<OtherDerived>::Type>
const ExpressionTypeCleaned, const typename ExtendedType<OtherDerived>::Type>
operator/(const DenseBase<OtherDerived>& other) const {
EIGEN_STATIC_ASSERT_VECTOR_ONLY(OtherDerived)
EIGEN_STATIC_ASSERT_ARRAYXPR(ExpressionType)
@@ -624,18 +638,28 @@ class VectorwiseOp {
return m_matrix / extendedTo(other.derived());
}
using Normalized_NonzeroNormType =
CwiseUnaryOp<internal::scalar_replace_zero_with_one_op<Scalar>, const NormReturnType>;
using NormalizedReturnType = CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const ExpressionTypeCleaned,
const typename OppositeExtendedType<Normalized_NonzeroNormType>::Type>;
/** \returns an expression where each column (or row) of the referenced matrix are normalized.
* The referenced matrix is \b not modified.
*
* \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they remain unchanged in the
* resulting expression.
*
* \sa MatrixBase::normalized(), normalize()
*/
EIGEN_DEVICE_FUNC CwiseBinaryOp<internal::scalar_quotient_op<Scalar>, const ExpressionTypeNestedCleaned,
const typename OppositeExtendedType<NormReturnType>::Type>
normalized() const {
return m_matrix.cwiseQuotient(extendedToOpposite(this->norm()));
EIGEN_DEVICE_FUNC NormalizedReturnType normalized() const {
return m_matrix.cwiseQuotient(extendedToOpposite(Normalized_NonzeroNormType(this->norm())));
}
/** Normalize in-place each row or columns of the referenced matrix.
* \sa MatrixBase::normalize(), normalized()
*
* \warning If the input columns (or rows) are too small (i.e., their norm equals to 0), they are left unchanged.
*
* \sa MatrixBase::normalized(), normalize()
*/
EIGEN_DEVICE_FUNC void normalize() { m_matrix = this->normalized(); }
@@ -679,7 +703,7 @@ class VectorwiseOp {
protected:
// Number of coefficients in one reduced subvector: rows() for a Vertical direction, cols() otherwise.
EIGEN_DEVICE_FUNC Index redux_length() const {
  if (Direction == Vertical) return m_matrix.rows();
  return m_matrix.cols();
}
ExpressionTypeNested m_matrix;
ExpressionType& m_matrix;
};
// const colwise moved to DenseBase.h due to CUDA compiler bug

View File

@@ -384,173 +384,6 @@ EIGEN_DEVICE_FUNC void DenseBase<Derived>::visit(Visitor& visitor) const {
namespace internal {
/** \internal
 * \brief Base class to implement min and max visitors
 *
 * Holds the running result (\c res) together with the position (\c row, \c col) it was found at.
 */
template <typename Derived>
struct coeff_visitor {
  using Scalar = typename Derived::Scalar;

  // default initialization to avoid countless invalid maybe-uninitialized warnings by gcc
  EIGEN_DEVICE_FUNC coeff_visitor() : row(-1), col(-1), res(0) {}

  Index row, col;
  Scalar res;

  // Seeds the visitor with the first visited coefficient and its position.
  EIGEN_DEVICE_FUNC inline void init(const Scalar& value, Index i, Index j) {
    row = i;
    col = j;
    res = value;
  }
};
// Comparison policy used by the min/max visitors. Primary template (is_min == true): a candidate
// wins when it is strictly smaller, and packets are reduced with predux_min.
template <typename Scalar, int NaNPropagation, bool is_min = true>
struct minmax_compare {
  using Packet = typename packet_traits<Scalar>::type;
  static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_min<NaNPropagation>(p); }
  static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a < b; }
};
// Specialization for is_min == false: a candidate wins when it is strictly larger, and packets are
// reduced with predux_max.
template <typename Scalar, int NaNPropagation>
struct minmax_compare<Scalar, NaNPropagation, false> {
  using Packet = typename packet_traits<Scalar>::type;
  static EIGEN_DEVICE_FUNC inline Scalar predux(const Packet& p) { return predux_max<NaNPropagation>(p); }
  static EIGEN_DEVICE_FUNC inline bool compare(Scalar a, Scalar b) { return a > b; }
};
// Default implementation used by non-floating types, where we do not
// need special logic for NaN handling.
// The running extremum and its position live in the inherited res/row/col fields; Comparator
// supplies the ordering test and the matching packet reduction.
template <typename Derived, bool is_min, int NaNPropagation,
bool isInt = NumTraits<typename Derived::Scalar>::IsInteger>
struct minmax_coeff_visitor : coeff_visitor<Derived> {
using Scalar = typename Derived::Scalar;
using Packet = typename packet_traits<Scalar>::type;
using Comparator = minmax_compare<Scalar, NaNPropagation, is_min>;
static constexpr Index PacketSize = packet_traits<Scalar>::size;
// Scalar visit: adopt `value` and its (i, j) position when it beats the current result.
EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
if (Comparator::compare(value, this->res)) {
this->res = value;
this->row = i;
this->col = j;
}
}
// Packet visit: reduce the packet to a single extremum first, and only locate the matching lane
// when that extremum beats the current result.
EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
Scalar value = Comparator::predux(p);
if (Comparator::compare(value, this->res)) {
// NOTE(review): presumably `range` holds PacketSize, PacketSize-1, ..., 1, so the largest entry
// left after masking identifies the first lane equal to `value` — confirm against plset/preverse.
const Packet range = preverse(plset<Packet>(Scalar(1)));
Packet mask = pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
// The lane offset is applied to the column for row-major expressions and to the row otherwise.
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
}
// Same lane search as packet(), but unconditional: used to seed the visitor from a first packet.
EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
Scalar value = Comparator::predux(p);
const Packet range = preverse(plset<Packet>(Scalar(1)));
Packet mask = pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
};
// Suppress NaN. The only case in which we return NaN is if the matrix is all NaN,
// in which case, row=0, col=0 is returned for the location.
template <typename Derived, bool is_min>
struct minmax_coeff_visitor<Derived, is_min, PropagateNumbers, false> : coeff_visitor<Derived> {
typedef typename Derived::Scalar Scalar;
using Packet = typename packet_traits<Scalar>::type;
using Comparator = minmax_compare<Scalar, PropagateNumbers, is_min>;
// Scalar visit: accept `value` when it is a number and the current result is NaN, or when it
// beats the current result.
EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
this->res = value;
this->row = i;
this->col = j;
}
}
// Packet visit: same acceptance rule as the scalar visit, applied to the packet's reduced value;
// the matching lane is located only when the value is accepted.
EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
const Index PacketSize = packet_traits<Scalar>::size;
Scalar value = Comparator::predux(p);
if ((!(numext::isnan)(value) && (numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// NOTE(review): presumably `range` holds PacketSize, PacketSize-1, ..., 1, so the largest entry
// left after masking identifies the first lane equal to `value` — confirm against plset/preverse.
const Packet range = preverse(plset<Packet>(Scalar(1)));
/* mask will be zero for NaNs, so they will be ignored. */
Packet mask = pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
// The lane offset is applied to the column for row-major expressions and to the row otherwise.
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
}
// Seeds the visitor from a first packet. A NaN reduction is recorded at position (0, 0) and the
// lane search is skipped; otherwise the lane holding the reduced value is located as in packet().
EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
const Index PacketSize = packet_traits<Scalar>::size;
Scalar value = Comparator::predux(p);
if ((numext::isnan)(value)) {
this->res = value;
this->row = 0;
this->col = 0;
return;
}
const Packet range = preverse(plset<Packet>(Scalar(1)));
/* mask will be zero for NaNs, so they will be ignored. */
Packet mask = pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
};
// Propagate NaNs. If the matrix contains NaN, the location of the first NaN
// will be returned in row and col.
// Note that Comparator is instantiated with PropagateNaN regardless of the NaNPropagation
// template argument this specialization was selected with.
template <typename Derived, bool is_min, int NaNPropagation>
struct minmax_coeff_visitor<Derived, is_min, NaNPropagation, false> : coeff_visitor<Derived> {
typedef typename Derived::Scalar Scalar;
using Packet = typename packet_traits<Scalar>::type;
using Comparator = minmax_compare<Scalar, PropagateNaN, is_min>;
// Scalar visit: accept `value` when it is NaN and the current result is not, or when it beats the
// current result.
EIGEN_DEVICE_FUNC inline void operator()(const Scalar& value, Index i, Index j) {
const bool value_is_nan = (numext::isnan)(value);
if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
this->res = value;
this->row = i;
this->col = j;
}
}
// Packet visit: same acceptance rule as the scalar visit, applied to the packet's reduced value.
EIGEN_DEVICE_FUNC inline void packet(const Packet& p, Index i, Index j) {
const Index PacketSize = packet_traits<Scalar>::size;
Scalar value = Comparator::predux(p);
const bool value_is_nan = (numext::isnan)(value);
if ((value_is_nan && !(numext::isnan)(this->res)) || Comparator::compare(value, this->res)) {
// NOTE(review): presumably `range` holds PacketSize, PacketSize-1, ..., 1, so the largest entry
// left after masking identifies the first selected lane — confirm against plset/preverse.
const Packet range = preverse(plset<Packet>(Scalar(1)));
// If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value.
// (p == p is false exactly in NaN lanes, so its negation selects them.)
Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
// The lane offset is applied to the column for row-major expressions and to the row otherwise.
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
}
// Same lane search as packet(), but unconditional: used to seed the visitor from a first packet.
EIGEN_DEVICE_FUNC inline void initpacket(const Packet& p, Index i, Index j) {
const Index PacketSize = packet_traits<Scalar>::size;
Scalar value = Comparator::predux(p);
const bool value_is_nan = (numext::isnan)(value);
const Packet range = preverse(plset<Packet>(Scalar(1)));
// If the value is NaN, pick the first position of a NaN, otherwise pick the first extremal value.
Packet mask = value_is_nan ? pnot(pcmp_eq(p, p)) : pcmp_eq(pset1<Packet>(value), p);
Index max_idx = PacketSize - static_cast<Index>(predux_max(pand(range, mask)));
this->res = value;
this->row = Derived::IsRowMajor ? i : i + max_idx;
this->col = Derived::IsRowMajor ? j + max_idx : j;
}
};
// Cost model for the min/max visitors: one addition per coefficient, no linear access, and packet
// traversal only when the scalar's packet type advertises comparison support (HasCmp).
template <typename Derived, bool is_min, int NaNPropagation>
struct functor_traits<minmax_coeff_visitor<Derived, is_min, NaNPropagation>> {
using Scalar = typename Derived::Scalar;
enum { Cost = NumTraits<Scalar>::AddCost, LinearAccess = false, PacketAccess = packet_traits<Scalar>::HasCmp };
};
template <typename Scalar>
struct all_visitor {
using result_type = bool;
@@ -643,100 +476,6 @@ struct all_finite_impl<Derived, false> {
} // end namespace internal
/** \fn DenseBase<Derived>::minCoeff(IndexType* rowId, IndexType* colId) const
* \returns the minimum of all coefficients of *this and puts in *rowId and *colId its location.
*
* In case \c *this contains NaN, NaNPropagation determines the behavior:
* NaNPropagation == PropagateFast : undefined
* NaNPropagation == PropagateNaN : result is NaN
* NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
* \warning the matrix must be not empty, otherwise an assertion is triggered.
*
* \note \a rowId is always written through; \a colId is written only when it is non-null.
*
* \sa DenseBase::minCoeff(Index*), DenseBase::maxCoeff(Index*,Index*), DenseBase::visit(), DenseBase::minCoeff()
*/
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* rowId,
IndexType* colId) const {
eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
this->visit(minVisitor);
*rowId = minVisitor.row;
if (colId) *colId = minVisitor.col;
return minVisitor.res;
}
/** \returns the minimum of all coefficients of *this and puts in *index its location.
*
* In case \c *this contains NaN, NaNPropagation determines the behavior:
* NaNPropagation == PropagateFast : undefined
* NaNPropagation == PropagateNaN : result is NaN
* NaNPropagation == PropagateNumbers : result is minimum of elements that are not NaN
* \warning the matrix must be not empty, otherwise an assertion is triggered.
*
* Only available for vectors (enforced by a static assertion).
*
* \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::visit(),
* DenseBase::minCoeff()
*/
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::minCoeff(IndexType* index) const {
eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
internal::minmax_coeff_visitor<Derived, true, NaNPropagation> minVisitor;
this->visit(minVisitor);
// For a row vector the position is the column index; otherwise it is the row index.
*index = IndexType((RowsAtCompileTime == 1) ? minVisitor.col : minVisitor.row);
return minVisitor.res;
}
/** \fn DenseBase<Derived>::maxCoeff(IndexType* rowId, IndexType* colId) const
* \returns the maximum of all coefficients of *this and writes its location through the two pointer arguments.
*
* In case \c *this contains NaN, NaNPropagation determines the behavior:
* NaNPropagation == PropagateFast : undefined
* NaNPropagation == PropagateNaN : result is NaN
* NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
* \warning the matrix must be not empty, otherwise an assertion is triggered.
*
* \note The row pointer is always written through; the column pointer is written only when it is non-null.
*
* \sa DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(), DenseBase::maxCoeff()
*/
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* rowPtr,
IndexType* colPtr) const {
eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
this->visit(maxVisitor);
*rowPtr = maxVisitor.row;
if (colPtr) *colPtr = maxVisitor.col;
return maxVisitor.res;
}
/** \returns the maximum of all coefficients of *this and puts in *index its location.
 *
 * In case \c *this contains NaN, NaNPropagation determines the behavior:
 * NaNPropagation == PropagateFast : undefined
 * NaNPropagation == PropagateNaN : result is NaN
 * NaNPropagation == PropagateNumbers : result is maximum of elements that are not NaN
 * \warning the matrix must be not empty, otherwise an assertion is triggered.
 *
 * Only available for vectors (enforced by a static assertion).
 *
 * \sa DenseBase::maxCoeff(IndexType*,IndexType*), DenseBase::minCoeff(IndexType*,IndexType*), DenseBase::visit(),
 * DenseBase::maxCoeff()
 */
template <typename Derived>
template <int NaNPropagation, typename IndexType>
EIGEN_DEVICE_FUNC typename internal::traits<Derived>::Scalar DenseBase<Derived>::maxCoeff(IndexType* index) const {
  eigen_assert(this->rows() > 0 && this->cols() > 0 && "you are using an empty matrix");
  EIGEN_STATIC_ASSERT_VECTOR_ONLY(Derived)
  internal::minmax_coeff_visitor<Derived, false, NaNPropagation> maxVisitor;
  this->visit(maxVisitor);
  // For a row vector the position is the column index; otherwise it is the row index.
  // The visitor stores Index values, so convert explicitly to the caller's IndexType, exactly as
  // minCoeff(IndexType*) does, instead of relying on an implicit (possibly narrowing) conversion.
  *index = IndexType((RowsAtCompileTime == 1) ? maxVisitor.col : maxVisitor.row);
  return maxVisitor.res;
}
/** \returns true if all coefficients are true
*
* Example: \include MatrixBase_all.cpp

View File

@@ -118,6 +118,7 @@ struct packet_traits<float> : default_packet_traits {
HasLog1p = 1,
HasExpm1 = 1,
HasExp = 1,
HasPow = 1,
HasNdtri = 1,
HasBessel = 1,
HasSqrt = 1,
@@ -149,6 +150,7 @@ struct packet_traits<double> : default_packet_traits {
HasErf = 1,
HasErfc = 1,
HasExp = 1,
HasPow = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
@@ -654,25 +656,6 @@ template <>
EIGEN_STRONG_INLINE uint64_t pfirst<Packet4ul>(const Packet4ul& a) {
return _mm_extract_epi64_0(_mm256_castsi256_si128(a));
}
// Horizontal sum of the four 64-bit lanes: add the upper 128-bit half onto the lower one, then add
// the two remaining lanes.
template <>
EIGEN_STRONG_INLINE int64_t predux<Packet4l>(const Packet4l& a) {
  const __m128i lower = _mm256_castsi256_si128(a);
  const __m128i upper = _mm256_extractf128_si256(a, 1);
  const __m128i pair_sums = _mm_add_epi64(lower, upper);
  return _mm_extract_epi64_0(pair_sums) + _mm_extract_epi64_1(pair_sums);
}
// Unsigned counterpart of the 64-bit horizontal sum: the lanes are added as signed integers and
// the result is bit-cast to uint64_t.
template <>
EIGEN_STRONG_INLINE uint64_t predux<Packet4ul>(const Packet4ul& a) {
  const __m128i lower = _mm256_castsi256_si128(a);
  const __m128i upper = _mm256_extractf128_si256(a, 1);
  const __m128i pair_sums = _mm_add_epi64(lower, upper);
  return numext::bit_cast<uint64_t>(_mm_extract_epi64_0(pair_sums) + _mm_extract_epi64_1(pair_sums));
}
// True when the movemask of the packet, reinterpreted as four doubles, is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4l& a) {
  const int lane_bits = _mm256_movemask_pd(_mm256_castsi256_pd(a));
  return lane_bits != 0;
}
// True when the movemask of the packet, reinterpreted as four doubles, is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4ul& a) {
  const int lane_bits = _mm256_movemask_pd(_mm256_castsi256_pd(a));
  return lane_bits != 0;
}
#define MM256_SHUFFLE_EPI64(A, B, M) _mm256_shuffle_pd(_mm256_castsi256_pd(A), _mm256_castsi256_pd(B), M)
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet4l, 4>& kernel) {
@@ -1955,23 +1938,6 @@ EIGEN_STRONG_INLINE Packet4d pldexp_fast<Packet4d>(const Packet4d& a, const Pack
return pmul(a, c); // a * 2^e
}
// Horizontal sum of eight floats: add the upper 128-bit half onto the lower one and defer to the
// four-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux<Packet8f>(const Packet8f& a) {
  const __m128 lower = _mm256_castps256_ps128(a);
  const __m128 upper = _mm256_extractf128_ps(a, 1);
  return predux(Packet4f(_mm_add_ps(lower, upper)));
}
// Horizontal sum of four doubles: add the upper 128-bit half onto the lower one and defer to the
// two-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux<Packet4d>(const Packet4d& a) {
  const __m128d lower = _mm256_castpd256_pd128(a);
  const __m128d upper = _mm256_extractf128_pd(a, 1);
  return predux(Packet2d(_mm_add_pd(lower, upper)));
}
// Horizontal sum of eight 32-bit integers: add the upper 128-bit half onto the lower one and defer
// to the four-lane reduction.
template <>
EIGEN_STRONG_INLINE int predux<Packet8i>(const Packet8i& a) {
  const __m128i lower = _mm256_castsi256_si128(a);
  const __m128i upper = _mm256_extractf128_si256(a, 1);
  return predux(Packet4i(_mm_add_epi32(lower, upper)));
}
// Horizontal sum of eight unsigned 32-bit integers: add the upper 128-bit half onto the lower one
// and defer to the four-lane reduction.
template <>
EIGEN_STRONG_INLINE uint32_t predux<Packet8ui>(const Packet8ui& a) {
  const __m128i lower = _mm256_castsi256_si128(a);
  const __m128i upper = _mm256_extractf128_si256(a, 1);
  return predux(Packet4ui(_mm_add_epi32(lower, upper)));
}
template <>
EIGEN_STRONG_INLINE Packet4f predux_half_dowto4<Packet8f>(const Packet8f& a) {
return _mm_add_ps(_mm256_castps256_ps128(a), _mm256_extractf128_ps(a, 1));
@@ -1985,82 +1951,6 @@ EIGEN_STRONG_INLINE Packet4ui predux_half_dowto4<Packet8ui>(const Packet8ui& a)
return _mm_add_epi32(_mm256_castsi256_si128(a), _mm256_extractf128_si256(a, 1));
}
// Horizontal product of eight floats in three multiply steps, each combining the packet with a
// permuted copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE float predux_mul<Packet8f>(const Packet8f& a) {
  const Packet8f halves = _mm256_mul_ps(a, _mm256_permute2f128_ps(a, a, 1));
  const Packet8f pairs = _mm256_mul_ps(halves, _mm256_shuffle_ps(halves, halves, _MM_SHUFFLE(1, 0, 3, 2)));
  return pfirst(_mm256_mul_ps(pairs, _mm256_shuffle_ps(pairs, pairs, 1)));
}
// Horizontal product of four doubles in two multiply steps, each combining the packet with a
// permuted copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet4d>(const Packet4d& a) {
  const Packet4d halves = _mm256_mul_pd(a, _mm256_permute2f128_pd(a, a, 1));
  return pfirst(_mm256_mul_pd(halves, _mm256_shuffle_pd(halves, halves, 1)));
}
// Horizontal minimum of eight floats in three min steps, each combining the packet with a permuted
// copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE float predux_min<Packet8f>(const Packet8f& a) {
  const Packet8f halves = _mm256_min_ps(a, _mm256_permute2f128_ps(a, a, 1));
  const Packet8f pairs = _mm256_min_ps(halves, _mm256_shuffle_ps(halves, halves, _MM_SHUFFLE(1, 0, 3, 2)));
  return pfirst(_mm256_min_ps(pairs, _mm256_shuffle_ps(pairs, pairs, 1)));
}
// Horizontal minimum of four doubles in two min steps, each combining the packet with a permuted
// copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE double predux_min<Packet4d>(const Packet4d& a) {
  const Packet4d halves = _mm256_min_pd(a, _mm256_permute2f128_pd(a, a, 1));
  return pfirst(_mm256_min_pd(halves, _mm256_shuffle_pd(halves, halves, 1)));
}
// Horizontal maximum of eight floats in three max steps, each combining the packet with a permuted
// copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE float predux_max<Packet8f>(const Packet8f& a) {
  const Packet8f halves = _mm256_max_ps(a, _mm256_permute2f128_ps(a, a, 1));
  const Packet8f pairs = _mm256_max_ps(halves, _mm256_shuffle_ps(halves, halves, _MM_SHUFFLE(1, 0, 3, 2)));
  return pfirst(_mm256_max_ps(pairs, _mm256_shuffle_ps(pairs, pairs, 1)));
}
// Horizontal maximum of four doubles in two max steps, each combining the packet with a permuted
// copy of itself; the result is read from the first lane.
template <>
EIGEN_STRONG_INLINE double predux_max<Packet4d>(const Packet4d& a) {
  const Packet4d halves = _mm256_max_pd(a, _mm256_permute2f128_pd(a, a, 1));
  return pfirst(_mm256_max_pd(halves, _mm256_shuffle_pd(halves, halves, 1)));
}
// not needed yet
// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet8f& x)
// {
// return _mm256_movemask_ps(x)==0xFF;
// }
// True when the movemask of the eight float lanes is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8f& x) {
  const int lane_bits = _mm256_movemask_ps(x);
  return lane_bits != 0;
}
// True when the movemask of the four double lanes is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4d& x) {
  const int lane_bits = _mm256_movemask_pd(x);
  return lane_bits != 0;
}
// True when the movemask of the packet, reinterpreted as eight floats, is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8i& x) {
  const int lane_bits = _mm256_movemask_ps(_mm256_castsi256_ps(x));
  return lane_bits != 0;
}
// True when the movemask of the packet, reinterpreted as eight floats, is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& x) {
  const int lane_bits = _mm256_movemask_ps(_mm256_castsi256_ps(x));
  return lane_bits != 0;
}
#ifndef EIGEN_VECTORIZE_AVX512FP16
// True when the byte-wise movemask of the 128-bit half packet is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8h& x) {
  const int byte_bits = _mm_movemask_epi8(x);
  return byte_bits != 0;
}
#endif  // EIGEN_VECTORIZE_AVX512FP16
// True when the byte-wise movemask of the 128-bit bfloat16 packet is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& x) {
  const int byte_bits = _mm_movemask_epi8(x);
  return byte_bits != 0;
}
EIGEN_DEVICE_FUNC inline void ptranspose(PacketBlock<Packet8f, 8>& kernel) {
__m256 T0 = _mm256_unpacklo_ps(kernel.packet[0], kernel.packet[1]);
__m256 T1 = _mm256_unpackhi_ps(kernel.packet[0], kernel.packet[1]);
@@ -2361,24 +2251,64 @@ EIGEN_STRONG_INLINE Packet8h ptrunc<Packet8h>(const Packet8h& a) {
return float2half(ptrunc<Packet8f>(half2float(a)));
}
// Lane-wise infinity test on the raw 16-bit patterns: a lane is infinite when, with the top bit
// (mask 0x7FFF) cleared, it equals kInf (0x7C00).
template <>
EIGEN_STRONG_INLINE Packet8h pisinf<Packet8h>(const Packet8h& a) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
  const __m128i magnitude = _mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask));
  return _mm_cmpeq_epi16(magnitude, _mm_set1_epi16(kInf));
}
// Lane-wise NaN test on the raw 16-bit patterns: a lane is NaN when, with the top bit (mask
// 0x7FFF) cleared, it compares greater than kInf (0x7C00).
template <>
EIGEN_STRONG_INLINE Packet8h pisnan<Packet8h>(const Packet8h& a) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
  const __m128i magnitude = _mm_and_si128(a.m_val, _mm_set1_epi16(kAbsMask));
  return _mm_cmpgt_epi16(magnitude, _mm_set1_epi16(kInf));
}
// convert the sign-magnitude representation to two's complement
EIGEN_STRONG_INLINE __m128i pmaptosigned(const __m128i& a) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  // strip the sign bit first; the sign of 'a' then decides whether the magnitude is negated as an integer
  const __m128i magnitude = _mm_and_si128(a, _mm_set1_epi16(kAbsMask));
  return _mm_sign_epi16(magnitude, a);
}
// return true if both `a` and `b` are not NaN
EIGEN_STRONG_INLINE Packet8h pisordered(const Packet8h& a, const Packet8h& b) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
  const __m128i abs_mask = _mm_set1_epi16(kAbsMask);
  const __m128i magnitude_a = _mm_and_si128(a.m_val, abs_mask);
  const __m128i magnitude_b = _mm_and_si128(b.m_val, abs_mask);
  // both magnitudes are <= kInf exactly when the larger of the two is <= kInf
  const __m128i largest_magnitude = _mm_max_epu16(magnitude_a, magnitude_b);
  // SSE has no integer `lesser or equal` comparison, so test `< kInf + 1` instead
  return _mm_cmplt_epi16(largest_magnitude, _mm_set1_epi16(kInf + 1));
}
// Lane-wise equality for half-precision packets.
// NOTE(review): the body lists an early `return Pack16To8(...)` (comparison after widening to
// float) followed by an integer-only implementation (ordered AND equal after pmaptosigned). As
// listed, the statements after the first return are unreachable; this looks like the removed and
// the added lines of a diff hunk shown together — confirm which version is intended.
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_eq(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_eq(half2float(a), half2float(b)));
__m128i isOrdered = pisordered(a, b);
__m128i isEqual = _mm_cmpeq_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_and_si128(isOrdered, isEqual);
}
// Lane-wise "less than or equal" for half-precision packets.
// NOTE(review): the body lists an early `return Pack16To8(...)` (comparison after widening to
// float) followed by an integer-only implementation (ordered AND NOT greater after pmaptosigned).
// As listed, the statements after the first return are unreachable; this looks like the removed
// and the added lines of a diff hunk shown together — confirm which version is intended.
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_le(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_le(half2float(a), half2float(b)));
__m128i isOrdered = pisordered(a, b);
__m128i isGreater = _mm_cmpgt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_andnot_si128(isGreater, isOrdered);
}
// Lane-wise `a < b` for eight fp16 values.
// NOTE(review): two implementations appear back to back in this body. The first return (float round-trip) makes the
// integer path below it unreachable as written; confirm which one is meant to stay.
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_lt(half2float(a), half2float(b)));
// Integer path: signed compare of the two's-complement images, masked by "neither operand is NaN".
__m128i isOrdered = pisordered(a, b);
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_and_si128(isOrdered, isLess);
}
// Lane-wise `a < b`, also set when either operand is NaN.
// NOTE(review): two implementations appear back to back in this body. The first return (float round-trip) makes the
// integer path below it unreachable as written; confirm which one is meant to stay.
template <>
EIGEN_STRONG_INLINE Packet8h pcmp_lt_or_nan(const Packet8h& a, const Packet8h& b) {
return Pack16To8(pcmp_lt_or_nan(half2float(a), half2float(b)));
// Integer path: OR the NaN lanes of either operand into the signed less-than mask.
__m128i isUnordered = por(pisnan(a), pisnan(b));
__m128i isLess = _mm_cmplt_epi16(pmaptosigned(a.m_val), pmaptosigned(b.m_val));
return _mm_or_si128(isUnordered, isLess);
}
template <>
@@ -2473,34 +2403,6 @@ EIGEN_STRONG_INLINE void pscatter<Eigen::half, Packet8h>(Eigen::half* to, const
to[stride * 7] = aux[7];
}
// Sum of the eight half lanes, accumulated in single precision and converted back to half at the end.
template <>
EIGEN_STRONG_INLINE Eigen::half predux<Packet8h>(const Packet8h& a) {
  const float total = predux<Packet8f>(half2float(a));
  return Eigen::half(total);
}
// Largest of the eight half lanes, found on the float-widened packet and converted back to half.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_max<Packet8h>(const Packet8h& a) {
  const float largest = predux_max<Packet8f>(half2float(a));
  return Eigen::half(largest);
}
// Smallest of the eight half lanes, found on the float-widened packet and converted back to half.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_min<Packet8h>(const Packet8h& a) {
  const float smallest = predux_min<Packet8f>(half2float(a));
  return Eigen::half(smallest);
}
// Product of the eight half lanes, accumulated in single precision and converted back to half at the end.
template <>
EIGEN_STRONG_INLINE Eigen::half predux_mul<Packet8h>(const Packet8h& a) {
  const float product = predux_mul<Packet8f>(half2float(a));
  return Eigen::half(product);
}
template <>
EIGEN_STRONG_INLINE Packet8h preverse(const Packet8h& a) {
__m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);
@@ -2859,26 +2761,6 @@ EIGEN_STRONG_INLINE void pscatter<bfloat16, Packet8bf>(bfloat16* to, const Packe
to[stride * 7] = aux[7];
}
// Sum of the eight bfloat16 lanes, accumulated on the float-widened packet.
template <>
EIGEN_STRONG_INLINE bfloat16 predux<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  const float total = predux<Packet8f>(widened);
  return static_cast<bfloat16>(total);
}
// Largest of the eight bfloat16 lanes, found on the float-widened packet.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  const float largest = predux_max<Packet8f>(widened);
  return static_cast<bfloat16>(largest);
}
// Smallest of the eight bfloat16 lanes, found on the float-widened packet.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  const float smallest = predux_min<Packet8f>(widened);
  return static_cast<bfloat16>(smallest);
}
// Product of the eight bfloat16 lanes, accumulated on the float-widened packet.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_mul<Packet8bf>(const Packet8bf& a) {
  const Packet8f widened = Bf16ToF32(a);
  const float product = predux_mul<Packet8f>(widened);
  return static_cast<bfloat16>(product);
}
template <>
EIGEN_STRONG_INLINE Packet8bf preverse(const Packet8bf& a) {
__m128i m = _mm_setr_epi8(14, 15, 12, 13, 10, 11, 8, 9, 6, 7, 4, 5, 2, 3, 0, 1);

View File

@@ -0,0 +1,353 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_REDUCTIONS_AVX_H
#define EIGEN_REDUCTIONS_AVX_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8i -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the eight int32 lanes: add the upper 128-bit half onto the lower one, then finish with the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE int predux(const Packet8i& a) {
  const Packet4i upper = _mm256_extractf128_si256(a, 1);
  const Packet4i lower = _mm256_castsi256_si128(a);
  return predux(padd(lower, upper));
}
// Product of the eight int32 lanes: multiply the two 128-bit halves lane-wise, then reduce the 4-lane result.
template <>
EIGEN_STRONG_INLINE int predux_mul(const Packet8i& a) {
  const Packet4i upper = _mm256_extractf128_si256(a, 1);
  const Packet4i lower = _mm256_castsi256_si128(a);
  return predux_mul(pmul(lower, upper));
}
// Minimum of the eight int32 lanes: lane-wise min of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE int predux_min(const Packet8i& a) {
  const Packet4i upper = _mm256_extractf128_si256(a, 1);
  const Packet4i lower = _mm256_castsi256_si128(a);
  return predux_min(pmin(lower, upper));
}
// Maximum of the eight int32 lanes: lane-wise max of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE int predux_max(const Packet8i& a) {
  const Packet4i upper = _mm256_extractf128_si256(a, 1);
  const Packet4i lower = _mm256_castsi256_si128(a);
  return predux_max(pmax(lower, upper));
}
// True when the movemask of `a` is non-zero: any byte's top bit with AVX2, any 32-bit lane's top bit otherwise.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8i& a) {
#ifdef EIGEN_VECTORIZE_AVX2
  const int top_bits = _mm256_movemask_epi8(a);
#else
  // Without AVX2 there is no 256-bit integer movemask; reinterpret as floats and use the float one.
  const int top_bits = _mm256_movemask_ps(_mm256_castsi256_ps(a));
#endif
  return top_bits != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8ui -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the eight uint32 lanes: add the upper 128-bit half onto the lower one, then finish with the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE uint32_t predux(const Packet8ui& a) {
  const Packet4ui upper = _mm256_extractf128_si256(a, 1);
  const Packet4ui lower = _mm256_castsi256_si128(a);
  return predux(padd(lower, upper));
}
// Product of the eight uint32 lanes: multiply the two 128-bit halves lane-wise, then reduce the 4-lane result.
template <>
EIGEN_STRONG_INLINE uint32_t predux_mul(const Packet8ui& a) {
  const Packet4ui upper = _mm256_extractf128_si256(a, 1);
  const Packet4ui lower = _mm256_castsi256_si128(a);
  return predux_mul(pmul(lower, upper));
}
// Minimum of the eight uint32 lanes: lane-wise min of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE uint32_t predux_min(const Packet8ui& a) {
  const Packet4ui upper = _mm256_extractf128_si256(a, 1);
  const Packet4ui lower = _mm256_castsi256_si128(a);
  return predux_min(pmin(lower, upper));
}
// Maximum of the eight uint32 lanes: lane-wise max of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE uint32_t predux_max(const Packet8ui& a) {
  const Packet4ui upper = _mm256_extractf128_si256(a, 1);
  const Packet4ui lower = _mm256_castsi256_si128(a);
  return predux_max(pmax(lower, upper));
}
// True when the movemask of `a` is non-zero: any byte's top bit with AVX2, any 32-bit lane's top bit otherwise.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8ui& a) {
#ifdef EIGEN_VECTORIZE_AVX2
  const int top_bits = _mm256_movemask_epi8(a);
#else
  // Without AVX2 there is no 256-bit integer movemask; reinterpret as floats and use the float one.
  const int top_bits = _mm256_movemask_ps(_mm256_castsi256_ps(a));
#endif
  return top_bits != 0x0;
}
#ifdef EIGEN_VECTORIZE_AVX2
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4l -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the four int64 lanes: add the upper 128-bit half onto the lower one, then finish with the 2-lane reduction.
template <>
EIGEN_STRONG_INLINE int64_t predux(const Packet4l& a) {
  const Packet2l upper = _mm256_extractf128_si256(a, 1);
  const Packet2l lower = _mm256_castsi256_si128(a);
  return predux(padd(lower, upper));
}
// True when any of the four 64-bit lanes has its top bit set (checked through the double-precision movemask).
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4l& a) {
  const int top_bits = _mm256_movemask_pd(_mm256_castsi256_pd(a));
  return top_bits != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4ul -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the four uint64 lanes, computed by the signed Packet4l reduction and cast back to unsigned.
template <>
EIGEN_STRONG_INLINE uint64_t predux(const Packet4ul& a) {
  const int64_t signed_sum = predux(Packet4l(a));
  return static_cast<uint64_t>(signed_sum);
}
// True when any of the four 64-bit lanes has its top bit set (checked through the double-precision movemask).
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4ul& a) {
  const int top_bits = _mm256_movemask_pd(_mm256_castsi256_pd(a));
  return top_bits != 0x0;
}
#endif
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8f -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the eight float lanes: add the upper 128-bit half onto the lower one, then finish with the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux(padd(lower, upper));
}
// Product of the eight float lanes: multiply the two 128-bit halves lane-wise, then reduce the 4-lane result.
template <>
EIGEN_STRONG_INLINE float predux_mul(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_mul(pmul(lower, upper));
}
// Minimum of the eight float lanes: lane-wise pmin of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_min(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_min(pmin(lower, upper));
}
// Minimum of the eight float lanes, passing the PropagateNumbers policy to both the pairwise pmin and the final
// 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lower, upper));
}
// Minimum of the eight float lanes, passing the PropagateNaN policy to both the pairwise pmin and the final
// 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNaN>(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_min<PropagateNaN>(pmin<PropagateNaN>(lower, upper));
}
// Maximum of the eight float lanes: lane-wise pmax of the two 128-bit halves, then the 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_max(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_max(pmax(lower, upper));
}
// Maximum of the eight float lanes, passing the PropagateNumbers policy to both the pairwise pmax and the final
// 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lower, upper));
}
// Maximum of the eight float lanes, passing the PropagateNaN policy to both the pairwise pmax and the final
// 4-lane reduction.
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet8f& a) {
  const Packet4f upper = _mm256_extractf128_ps(a, 1);
  const Packet4f lower = _mm256_castps256_ps128(a);
  return predux_max<PropagateNaN>(pmax<PropagateNaN>(lower, upper));
}
// True when any of the eight float lanes has its top (sign) bit set.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8f& a) {
  const int top_bits = _mm256_movemask_ps(a);
  return top_bits != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4d -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the four double lanes: add the upper 128-bit half onto the lower one, then finish with the 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux(padd(lower, upper));
}
// Product of the four double lanes: multiply the two 128-bit halves lane-wise, then reduce the 2-lane result.
template <>
EIGEN_STRONG_INLINE double predux_mul(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_mul(pmul(lower, upper));
}
// Minimum of the four double lanes: lane-wise pmin of the two 128-bit halves, then the 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_min(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_min(pmin(lower, upper));
}
// Minimum of the four double lanes, passing the PropagateNumbers policy to both the pairwise pmin and the final
// 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_min<PropagateNumbers>(pmin<PropagateNumbers>(lower, upper));
}
// Minimum of the four double lanes, passing the PropagateNaN policy to both the pairwise pmin and the final
// 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNaN>(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_min<PropagateNaN>(pmin<PropagateNaN>(lower, upper));
}
// Maximum of the four double lanes: lane-wise pmax of the two 128-bit halves, then the 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_max(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_max(pmax(lower, upper));
}
// Maximum of the four double lanes, passing the PropagateNumbers policy to both the pairwise pmax and the final
// 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_max<PropagateNumbers>(pmax<PropagateNumbers>(lower, upper));
}
// Maximum of the four double lanes, passing the PropagateNaN policy to both the pairwise pmax and the final
// 2-lane reduction.
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet4d& a) {
  const Packet2d upper = _mm256_extractf128_pd(a, 1);
  const Packet2d lower = _mm256_castpd256_pd128(a);
  return predux_max<PropagateNaN>(pmax<PropagateNaN>(lower, upper));
}
// True when any of the four double lanes has its top (sign) bit set.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4d& a) {
  const int top_bits = _mm256_movemask_pd(a);
  return top_bits != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8h -- -- -- -- -- -- -- -- -- -- -- -- */
#ifndef EIGEN_VECTORIZE_AVX512FP16
// Sum of the eight half lanes, computed on the float-widened packet and cast back to half.
template <>
EIGEN_STRONG_INLINE half predux(const Packet8h& a) {
  const float total = predux(half2float(a));
  return static_cast<half>(total);
}
// Product of the eight half lanes, computed on the float-widened packet and cast back to half.
template <>
EIGEN_STRONG_INLINE half predux_mul(const Packet8h& a) {
  const float product = predux_mul(half2float(a));
  return static_cast<half>(product);
}
// Minimum of the eight half lanes, computed on the float-widened packet and cast back to half.
template <>
EIGEN_STRONG_INLINE half predux_min(const Packet8h& a) {
  const float smallest = predux_min(half2float(a));
  return static_cast<half>(smallest);
}
// Minimum of the eight half lanes with the PropagateNumbers policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateNumbers>(const Packet8h& a) {
  const float smallest = predux_min<PropagateNumbers>(half2float(a));
  return static_cast<half>(smallest);
}
// Minimum of the eight half lanes with the PropagateNaN policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE half predux_min<PropagateNaN>(const Packet8h& a) {
  const float smallest = predux_min<PropagateNaN>(half2float(a));
  return static_cast<half>(smallest);
}
// Maximum of the eight half lanes, computed on the float-widened packet and cast back to half.
template <>
EIGEN_STRONG_INLINE half predux_max(const Packet8h& a) {
  const float largest = predux_max(half2float(a));
  return static_cast<half>(largest);
}
// Maximum of the eight half lanes with the PropagateNumbers policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateNumbers>(const Packet8h& a) {
  const float largest = predux_max<PropagateNumbers>(half2float(a));
  return static_cast<half>(largest);
}
// Maximum of the eight half lanes with the PropagateNaN policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE half predux_max<PropagateNaN>(const Packet8h& a) {
  const float largest = predux_max<PropagateNaN>(half2float(a));
  return static_cast<half>(largest);
}
// True when any byte of the 128-bit half packet has its top bit set.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8h& a) {
  const int top_bits = _mm_movemask_epi8(a);
  return top_bits != 0;
}
#endif // EIGEN_VECTORIZE_AVX512FP16
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet8bf -- -- -- -- -- -- -- -- -- -- -- -- */
// Sum of the eight bfloat16 lanes, computed on the float-widened packet and cast back to bfloat16.
template <>
EIGEN_STRONG_INLINE bfloat16 predux(const Packet8bf& a) {
  const float total = predux<Packet8f>(Bf16ToF32(a));
  return static_cast<bfloat16>(total);
}
// Product of the eight bfloat16 lanes, computed on the float-widened packet and cast back to bfloat16.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_mul(const Packet8bf& a) {
  const float product = predux_mul<Packet8f>(Bf16ToF32(a));
  return static_cast<bfloat16>(product);
}
// Minimum of the eight bfloat16 lanes, computed on the float-widened packet and cast back to bfloat16.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min(const Packet8bf& a) {
  const float smallest = predux_min(Bf16ToF32(a));
  return static_cast<bfloat16>(smallest);
}
// Minimum of the eight bfloat16 lanes with the PropagateNumbers policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNumbers>(const Packet8bf& a) {
  const float smallest = predux_min<PropagateNumbers>(Bf16ToF32(a));
  return static_cast<bfloat16>(smallest);
}
// Minimum of the eight bfloat16 lanes with the PropagateNaN policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_min<PropagateNaN>(const Packet8bf& a) {
  const float smallest = predux_min<PropagateNaN>(Bf16ToF32(a));
  return static_cast<bfloat16>(smallest);
}
// Maximum of the eight bfloat16 lanes, computed on the float-widened packet and cast back to bfloat16.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max(const Packet8bf& a) {
  const float largest = predux_max<Packet8f>(Bf16ToF32(a));
  return static_cast<bfloat16>(largest);
}
// Maximum of the eight bfloat16 lanes with the PropagateNumbers policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNumbers>(const Packet8bf& a) {
  const float largest = predux_max<PropagateNumbers>(Bf16ToF32(a));
  return static_cast<bfloat16>(largest);
}
// Maximum of the eight bfloat16 lanes with the PropagateNaN policy, delegated to the float reduction.
template <>
EIGEN_STRONG_INLINE bfloat16 predux_max<PropagateNaN>(const Packet8bf& a) {
  const float largest = predux_max<PropagateNaN>(Bf16ToF32(a));
  return static_cast<bfloat16>(largest);
}
// True when any byte of the 128-bit bfloat16 packet has its top bit set.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet8bf& a) {
  const int top_bits = _mm_movemask_epi8(a);
  return top_bits != 0;
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_REDUCTIONS_AVX_H

View File

@@ -793,6 +793,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bfloat16 nextafter(const bfloat16& from, c
return numext::bit_cast<bfloat16>(from_bits);
}
// bfloat16 multiply-add specialization: x * y + z is evaluated entirely in float and converted to bfloat16 once,
// which matches the packet operations and avoids intermediate float <-> bfloat16 conversions.
template<>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::bfloat16 madd<Eigen::bfloat16>(const Eigen::bfloat16& x, const Eigen::bfloat16& y, const Eigen::bfloat16& z) {
  const float result = static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z);
  return Eigen::bfloat16(result);
}
} // namespace numext
} // namespace Eigen

View File

@@ -1689,7 +1689,8 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet phypot_complex(const
}
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
struct psign_impl<Packet, std::enable_if_t<!is_scalar<Packet>::value &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
using Scalar = typename unpacket_traits<Packet>::type;
@@ -1705,7 +1706,8 @@ struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<P
};
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
struct psign_impl<Packet, std::enable_if_t<!is_scalar<Packet>::value &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
@@ -1724,7 +1726,8 @@ struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<P
};
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
struct psign_impl<Packet, std::enable_if_t<!is_scalar<Packet>::value &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
!NumTraits<typename unpacket_traits<Packet>::type>::IsSigned &&
NumTraits<typename unpacket_traits<Packet>::type>::IsInteger>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
@@ -1739,7 +1742,8 @@ struct psign_impl<Packet, std::enable_if_t<!NumTraits<typename unpacket_traits<P
// \internal \returns the sign of a complex number z, defined as z / abs(z).
template <typename Packet>
struct psign_impl<Packet, std::enable_if_t<NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
struct psign_impl<Packet, std::enable_if_t<!is_scalar<Packet>::value &&
NumTraits<typename unpacket_traits<Packet>::type>::IsComplex &&
unpacket_traits<Packet>::vectorizable>> {
static EIGEN_DEVICE_FUNC inline Packet run(const Packet& a) {
typedef typename unpacket_traits<Packet>::type Scalar;
@@ -2176,7 +2180,8 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet generic_pow_impl(const Packet& x, c
// Generic implementation of pow(x,y).
template <typename Packet>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Packet& x, const Packet& y) {
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t<!is_scalar<Packet>::value, Packet> generic_pow(
const Packet& x, const Packet& y) {
typedef typename unpacket_traits<Packet>::type Scalar;
const Packet cst_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
@@ -2266,6 +2271,12 @@ EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS Packet generic_pow(const Pac
return pow;
}
// Scalar overload of generic_pow: participates in overload resolution only when is_scalar<Scalar>::value is true,
// and simply forwards to numext::pow(x, y).
template <typename Scalar>
EIGEN_DEFINE_FUNCTION_ALLOWING_MULTIPLE_DEFINITIONS std::enable_if_t<is_scalar<Scalar>::value, Scalar> generic_pow(
const Scalar& x, const Scalar& y) {
return numext::pow(x, y);
}
namespace unary_pow {
template <typename ScalarExponent, bool IsInteger = NumTraits<ScalarExponent>::IsInteger>
@@ -2347,35 +2358,36 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet int_pow(const Packet& x, const Scal
}
template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet gen_pow(const Packet& x,
const typename unpacket_traits<Packet>::type& exponent) {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<!is_scalar<Packet>::value, Packet> gen_pow(
const Packet& x, const typename unpacket_traits<Packet>::type& exponent) {
const Packet exponent_packet = pset1<Packet>(exponent);
return generic_pow_impl(x, exponent_packet);
}
// Scalar overload of gen_pow: participates in overload resolution only when is_scalar<Scalar>::value is true,
// and simply forwards to numext::pow(x, exponent).
template <typename Scalar>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE std::enable_if_t<is_scalar<Scalar>::value, Scalar> gen_pow(
const Scalar& x, const Scalar& exponent) {
return numext::pow(x, exponent);
}
template <typename Packet, typename ScalarExponent>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_nonint_nonint_errors(const Packet& x, const Packet& powx,
const ScalarExponent& exponent) {
using Scalar = typename unpacket_traits<Packet>::type;
// non-integer base and exponent case
const Scalar pos_zero = Scalar(0);
const Scalar all_ones = ptrue<Scalar>(Scalar());
const Scalar pos_one = Scalar(1);
const Scalar pos_inf = NumTraits<Scalar>::infinity();
const Packet cst_pos_zero = pzero(x);
const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet cst_pos_inf = pset1<Packet>(pos_inf);
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
const Packet cst_pos_inf = pset1<Packet>(NumTraits<Scalar>::infinity());
const Packet cst_true = ptrue<Packet>(x);
const bool exponent_is_not_fin = !(numext::isfinite)(exponent);
const bool exponent_is_neg = exponent < ScalarExponent(0);
const bool exponent_is_pos = exponent > ScalarExponent(0);
const Packet exp_is_not_fin = pset1<Packet>(exponent_is_not_fin ? all_ones : pos_zero);
const Packet exp_is_neg = pset1<Packet>(exponent_is_neg ? all_ones : pos_zero);
const Packet exp_is_pos = pset1<Packet>(exponent_is_pos ? all_ones : pos_zero);
const Packet exp_is_not_fin = exponent_is_not_fin ? cst_true : cst_pos_zero;
const Packet exp_is_neg = exponent_is_neg ? cst_true : cst_pos_zero;
const Packet exp_is_pos = exponent_is_pos ? cst_true : cst_pos_zero;
const Packet exp_is_inf = pand(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
const Packet exp_is_nan = pandnot(exp_is_not_fin, por(exp_is_neg, exp_is_pos));
@@ -2411,22 +2423,15 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet handle_negative_exponent(const Pack
// This routine handles negative exponents.
// The return value is either 0, 1, or -1.
const Scalar pos_zero = Scalar(0);
const Scalar all_ones = ptrue<Scalar>(Scalar());
const Scalar pos_one = Scalar(1);
const Packet cst_pos_one = pset1<Packet>(pos_one);
const Packet cst_pos_one = pset1<Packet>(Scalar(1));
const bool exponent_is_odd = exponent % ScalarExponent(2) != ScalarExponent(0);
const Packet exp_is_odd = pset1<Packet>(exponent_is_odd ? all_ones : pos_zero);
const Packet exp_is_odd = exponent_is_odd ? ptrue<Packet>(x) : pzero<Packet>(x);
const Packet abs_x = pabs(x);
const Packet abs_x_is_one = pcmp_eq(abs_x, cst_pos_one);
Packet result = pselect(exp_is_odd, x, abs_x);
result = pand(abs_x_is_one, result);
result = pselect(abs_x_is_one, result, pzero<Packet>(x));
return result;
}

View File

@@ -497,16 +497,56 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half& operator/=(half& a, const half& b) {
a = half(float(a) / float(b));
return a;
}
// Non-negative floating point numbers have a monotonic mapping to non-negative integers.
// This property allows floating point numbers to be reinterpreted as integers for comparisons, which is useful if there
// is no native floating point comparison operator. Floating point signedness is handled by the sign-magnitude
// representation, whereas integers typically use two's complement. Converting the bit pattern from sign-magnitude to
// two's complement allows the transformed bit patterns to be compared as signed integers. All edge cases (+/-0 and +/-
// infinity) are handled automatically, except NaN.
//
// fp16 uses 1 sign bit, 5 exponent bits, and 10 mantissa bits. The bit pattern conveys NaN when all the exponent
// bits (5) are set, and at least one mantissa bit is set. The sign bit is irrelevant for determining NaN. To check for
// NaN, clear the sign bit and check if the integral representation is greater than 0111110000000000 (0x7C00). To test
// for non-NaN, clear the sign bit and check if the integral representation is less than or equal to 0111110000000000.
// Re-encode a 16-bit sign-magnitude pattern as a two's complement integer.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC int16_t mapToSigned(uint16_t a) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  const bool sign_bit_set = (a >> 15) != 0;
  // With the sign bit set, drop it and negate the remaining magnitude as an integer.
  if (sign_bit_set) return static_cast<int16_t>(-(a & kAbsMask));
  // Otherwise the pattern is already a non-negative integer.
  return static_cast<int16_t>(a);
}
// True when neither `a` nor `b` is NaN: with the sign bit cleared, both bit patterns must be <= the +infinity
// pattern, which is checked once on the larger of the two.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool isOrdered(const half& a, const half& b) {
  constexpr uint16_t kAbsMask = (1 << 15) - 1;
  constexpr uint16_t kInf = ((1 << 5) - 1) << 10;
  const auto larger_abs = numext::maxi(a.x & kAbsMask, b.x & kAbsMask);
  return larger_abs <= kInf;
}
// Equality of two half values.
// NOTE(review): two implementations appear back to back in this body. The first return (float conversion via
// numext::equal_strict) makes the integer path below it unreachable as written; confirm which one is meant to stay.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator==(const half& a, const half& b) {
return numext::equal_strict(float(a), float(b));
// Integer path: equal two's-complement images, and neither operand is NaN.
bool result = mapToSigned(a.x) == mapToSigned(b.x);
result &= isOrdered(a, b);
return result;
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) {
return numext::not_equal_strict(float(a), float(b));
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator!=(const half& a, const half& b) { return !(a == b); }
// a < b on the raw fp16 bits: signed comparison of the two's-complement images, false if either operand is NaN.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) {
  const bool ordered = isOrdered(a, b);
  const bool less = mapToSigned(a.x) < mapToSigned(b.x);
  return ordered & less;
}
// a <= b on the raw fp16 bits: signed comparison of the two's-complement images, false if either operand is NaN.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) {
  const bool ordered = isOrdered(a, b);
  const bool less_or_equal = mapToSigned(a.x) <= mapToSigned(b.x);
  return ordered & less_or_equal;
}
// a > b on the raw fp16 bits: signed comparison of the two's-complement images, false if either operand is NaN.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) {
  const bool ordered = isOrdered(a, b);
  const bool greater = mapToSigned(a.x) > mapToSigned(b.x);
  return ordered & greater;
}
// a >= b on the raw fp16 bits: signed comparison of the two's-complement images, false if either operand is NaN.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) {
  const bool ordered = isOrdered(a, b);
  const bool greater_or_equal = mapToSigned(a.x) >= mapToSigned(b.x);
  return ordered & greater_or_equal;
}
// Relational operators implemented by converting both operands to float and comparing the floats.
// NOTE(review): operators with these exact signatures are also defined immediately above (via
// mapToSigned()/isOrdered()); both sets cannot coexist in one translation unit. These look like the pre-change
// lines of a diff hunk; confirm which set is meant to stay.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<(const half& a, const half& b) { return float(a) < float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator<=(const half& a, const half& b) { return float(a) <= float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>(const half& a, const half& b) { return float(a) > float(b); }
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool operator>=(const half& a, const half& b) { return float(a) >= float(b); }
#if EIGEN_COMP_CLANG && defined(EIGEN_GPUCC)
#pragma pop_macro("EIGEN_DEVICE_FUNC")
@@ -706,7 +746,11 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isnan)(const half& a) {
#endif
}
// True when `a` is neither infinite nor NaN.
// NOTE(review): two implementations appear back to back in this body. The first return (isinf/isnan based) makes the
// bit-pattern test below it unreachable as written; confirm which one is meant to stay.
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC bool(isfinite)(const half& a) {
return !(isinf EIGEN_NOT_A_MACRO(a)) && !(isnan EIGEN_NOT_A_MACRO(a));
// Bit-pattern test: with the sign bit masked off (0x7fff), a finite value is strictly below 0x7c00.
#if defined(EIGEN_HAS_ARM64_FP16_SCALAR_ARITHMETIC) || defined(EIGEN_HAS_BUILTIN_FLOAT16)
// On these configurations the raw bits are obtained through bit_cast rather than read from a.x directly.
return (numext::bit_cast<numext::uint16_t>(a.x) & 0x7fff) < 0x7c00;
#else
return (a.x & 0x7fff) < 0x7c00;
#endif
}
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC half abs(const half& a) {
@@ -911,6 +955,12 @@ EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC uint16_t bit_cast<uint16_t, Eigen::half>(c
return Eigen::half_impl::raw_half_as_uint16(src);
}
// half multiply-add specialization: x * y + z is evaluated entirely in float and converted to half once,
// which matches the packet operations and avoids intermediate float <-> half conversions.
template<>
EIGEN_STRONG_INLINE EIGEN_DEVICE_FUNC Eigen::half madd<Eigen::half>(const Eigen::half& x, const Eigen::half& y, const Eigen::half& z) {
  const float result = static_cast<float>(x) * static_cast<float>(y) + static_cast<float>(z);
  return Eigen::half(result);
}
} // namespace numext
} // namespace Eigen

View File

@@ -73,30 +73,13 @@ struct packet_traits<std::complex<float> > : default_packet_traits {
};
template <>
struct unpacket_traits<Packet1cf> {
typedef std::complex<float> type;
typedef Packet1cf half;
typedef Packet2f as_real;
enum {
size = 1,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet1cf> : neon_unpacket_default<Packet1cf, std::complex<float>> {
using as_real = Packet2f;
};
template <>
struct unpacket_traits<Packet2cf> {
typedef std::complex<float> type;
typedef Packet1cf half;
typedef Packet4f as_real;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet2cf> : neon_unpacket_default<Packet2cf, std::complex<float>> {
using half = Packet1cf;
using as_real = Packet4f;
};
template <>
@@ -297,10 +280,12 @@ EIGEN_STRONG_INLINE Packet2cf pandnot<Packet2cf>(const Packet2cf& a, const Packe
// Aligned load of one std::complex<float>: the address is read as floats through pload<Packet2f> and wrapped
// in a Packet1cf.
template <>
EIGEN_STRONG_INLINE Packet1cf pload<Packet1cf>(const std::complex<float>* from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet1cf>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return Packet1cf(pload<Packet2f>((const float*)from));
}
// Aligned load of two std::complex<float>: the address is read as floats through pload<Packet4f> and wrapped
// in a Packet2cf.
template <>
EIGEN_STRONG_INLINE Packet2cf pload<Packet2cf>(const std::complex<float>* from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2cf>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return Packet2cf(pload<Packet4f>(reinterpret_cast<const float*>(from)));
}
@@ -324,10 +309,12 @@ EIGEN_STRONG_INLINE Packet2cf ploaddup<Packet2cf>(const std::complex<float>* fro
// Aligned store of a Packet1cf: writes the underlying float packet (from.v) to `to` viewed as float*.
template <>
EIGEN_STRONG_INLINE void pstore<std::complex<float> >(std::complex<float>* to, const Packet1cf& from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet1cf>::alignment);
EIGEN_DEBUG_ALIGNED_STORE pstore((float*)to, from.v);
}
// Aligned store of a Packet2cf: writes the underlying float packet (from.v) to `to` viewed as float*.
template <>
EIGEN_STRONG_INLINE void pstore<std::complex<float> >(std::complex<float>* to, const Packet2cf& from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2cf>::alignment);
EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<float*>(to), from.v);
}
@@ -538,21 +525,13 @@ struct packet_traits<std::complex<double> > : default_packet_traits {
};
template <>
struct unpacket_traits<Packet1cd> {
typedef std::complex<double> type;
typedef Packet1cd half;
typedef Packet2d as_real;
enum {
size = 1,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet1cd> : neon_unpacket_default<Packet1cd, std::complex<double>> {
using as_real = Packet2d;
};
// Aligned load of one std::complex<double>: the address is read as doubles through pload<Packet2d> and wrapped
// in a Packet1cd.
template <>
EIGEN_STRONG_INLINE Packet1cd pload<Packet1cd>(const std::complex<double>* from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet1cd>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return Packet1cd(pload<Packet2d>(reinterpret_cast<const double*>(from)));
}
@@ -666,6 +645,7 @@ EIGEN_STRONG_INLINE Packet1cd ploaddup<Packet1cd>(const std::complex<double>* fr
// Aligned store of a Packet1cd: writes the underlying double packet (from.v) to `to` viewed as double*.
template <>
EIGEN_STRONG_INLINE void pstore<std::complex<double> >(std::complex<double>* to, const Packet1cd& from) {
// Alignment value comes from the packet's traits.
// NOTE(review): EIGEN_ASSUME_ALIGNED is defined elsewhere; presumably a compiler alignment assumption, confirm.
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet1cd>::alignment);
EIGEN_DEBUG_ALIGNED_STORE pstore(reinterpret_cast<double*>(to), from.v);
}

View File

@@ -205,6 +205,7 @@ struct packet_traits<float> : default_packet_traits {
HasATanh = 1,
HasLog = 1,
HasExp = 1,
HasPow = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
@@ -437,224 +438,74 @@ struct packet_traits<uint64_t> : default_packet_traits {
};
};
// Common defaults for the NEON unpacket_traits specializations, which derive from this struct and redeclare only
// the members that differ (for example `half` or `integer_packet`).
template <typename Packet, typename Scalar>
struct neon_unpacket_default {
  // Element type carried by the packet.
  typedef Scalar type;
  // Half-width packet type; defaults to the packet itself.
  typedef Packet half;
  // Lane count and alignment are both derived from the packet's size in bytes.
  static constexpr int size = sizeof(Packet) / sizeof(Scalar);
  static constexpr int alignment = sizeof(Packet);
  static constexpr bool vectorizable = true;
  // No masked load/store support by default.
  static constexpr bool masked_load_available = false;
  static constexpr bool masked_store_available = false;
};
template <>
struct unpacket_traits<Packet2f> {
typedef float type;
typedef Packet2f half;
typedef Packet2i integer_packet;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet2f> : neon_unpacket_default<Packet2f, float> {
using integer_packet = Packet2i;
};
template <>
struct unpacket_traits<Packet4f> {
typedef float type;
typedef Packet2f half;
typedef Packet4i integer_packet;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet4f> : neon_unpacket_default<Packet4f, float> {
using half = Packet2f;
using integer_packet = Packet4i;
};
template <>
struct unpacket_traits<Packet4c> {
typedef int8_t type;
typedef Packet4c half;
enum {
size = 4,
alignment = Unaligned,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet4c> : neon_unpacket_default<Packet4c, int8_t> {};
template <>
struct unpacket_traits<Packet8c> : neon_unpacket_default<Packet8c, int8_t> {
using half = Packet4c;
};
template <>
struct unpacket_traits<Packet8c> {
typedef int8_t type;
typedef Packet4c half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet16c> : neon_unpacket_default<Packet16c, int8_t> {
using half = Packet8c;
};
template <>
struct unpacket_traits<Packet16c> {
typedef int8_t type;
typedef Packet8c half;
enum {
size = 16,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet4uc> : neon_unpacket_default<Packet4uc, uint8_t> {};
template <>
struct unpacket_traits<Packet8uc> : neon_unpacket_default<Packet8uc, uint8_t> {
using half = Packet4uc;
};
template <>
struct unpacket_traits<Packet4uc> {
typedef uint8_t type;
typedef Packet4uc half;
enum {
size = 4,
alignment = Unaligned,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet16uc> : neon_unpacket_default<Packet16uc, uint8_t> {
using half = Packet8uc;
};
template <>
struct unpacket_traits<Packet8uc> {
typedef uint8_t type;
typedef Packet4uc half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet4s> : neon_unpacket_default<Packet4s, int16_t> {};
template <>
struct unpacket_traits<Packet8s> : neon_unpacket_default<Packet8s, int16_t> {
using half = Packet4s;
};
template <>
struct unpacket_traits<Packet16uc> {
typedef uint8_t type;
typedef Packet8uc half;
enum {
size = 16,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet4us> : neon_unpacket_default<Packet4us, uint16_t> {};
template <>
struct unpacket_traits<Packet8us> : neon_unpacket_default<Packet8us, uint16_t> {
using half = Packet4us;
};
template <>
struct unpacket_traits<Packet4s> {
typedef int16_t type;
typedef Packet4s half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet2i> : neon_unpacket_default<Packet2i, int32_t> {};
template <>
struct unpacket_traits<Packet4i> : neon_unpacket_default<Packet4i, int32_t> {
using half = Packet2i;
};
template <>
struct unpacket_traits<Packet8s> {
typedef int16_t type;
typedef Packet4s half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet2ui> : neon_unpacket_default<Packet2ui, uint32_t> {};
template <>
struct unpacket_traits<Packet4ui> : neon_unpacket_default<Packet4ui, uint32_t> {
using half = Packet2ui;
};
template <>
struct unpacket_traits<Packet4us> {
typedef uint16_t type;
typedef Packet4us half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
struct unpacket_traits<Packet2l> : neon_unpacket_default<Packet2l, int64_t> {};
template <>
struct unpacket_traits<Packet8us> {
typedef uint16_t type;
typedef Packet4us half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet2i> {
typedef int32_t type;
typedef Packet2i half;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet4i> {
typedef int32_t type;
typedef Packet2i half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet2ui> {
typedef uint32_t type;
typedef Packet2ui half;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet4ui> {
typedef uint32_t type;
typedef Packet2ui half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet2l> {
typedef int64_t type;
typedef Packet2l half;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
template <>
struct unpacket_traits<Packet2ul> {
typedef uint64_t type;
typedef Packet2ul half;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
struct unpacket_traits<Packet2ul> : neon_unpacket_default<Packet2ul, uint64_t> {};
template <>
EIGEN_STRONG_INLINE Packet2f pzero(const Packet2f& /*a*/) {
@@ -1287,6 +1138,14 @@ template <>
EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return vfma_f32(c, a, b);
}
// pnmadd(a, b, c) for float packets: returns c - a * b using the NEON fused
// multiply-subtract intrinsics (vfmsq_f32 / vfms_f32 take the accumulator first).
template <>
EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
return vfmsq_f32(c, a, b);
}
// Two-lane (64-bit register) variant of the same operation.
template <>
EIGEN_STRONG_INLINE Packet2f pnmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return vfms_f32(c, a, b);
}
#else
template <>
EIGEN_STRONG_INLINE Packet4f pmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
@@ -1296,7 +1155,31 @@ template <>
EIGEN_STRONG_INLINE Packet2f pmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return vmla_f32(c, a, b);
}
// pnmadd(a, b, c) for float packets in the #else branch: returns c - a * b using
// the multiply-subtract intrinsics vmlsq_f32 / vmls_f32 (accumulator first).
template <>
EIGEN_STRONG_INLINE Packet4f pnmadd(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
return vmlsq_f32(c, a, b);
}
// Two-lane (64-bit register) variant of the same operation.
template <>
EIGEN_STRONG_INLINE Packet2f pnmadd(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return vmls_f32(c, a, b);
}
#endif
// pmsub(a, b, c) for float packets: a * b - c, obtained by negating
// pnmadd(a, b, c), which yields c - a * b.
template <>
EIGEN_STRONG_INLINE Packet4f pmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
return pnegate(pnmadd(a, b, c));
}
// Two-lane variant; same negate-pnmadd construction.
template <>
EIGEN_STRONG_INLINE Packet2f pmsub(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return pnegate(pnmadd(a, b, c));
}
// pnmsub(a, b, c) for float packets: -(a * b + c), obtained by negating
// pmadd(a, b, c).
template <>
EIGEN_STRONG_INLINE Packet4f pnmsub(const Packet4f& a, const Packet4f& b, const Packet4f& c) {
return pnegate(pmadd(a, b, c));
}
// Two-lane variant; same negate-pmadd construction.
template <>
EIGEN_STRONG_INLINE Packet2f pnmsub(const Packet2f& a, const Packet2f& b, const Packet2f& c) {
return pnegate(pmadd(a, b, c));
}
// No FMA instruction for int, so use MLA unconditionally.
template <>
@@ -2385,10 +2268,12 @@ EIGEN_STRONG_INLINE Packet2ul plogical_shift_left(Packet2ul a) {
// Aligned load of two floats. EIGEN_ASSUME_ALIGNED is passed the pointer and the
// packet's declared alignment before the vld1_f32 load.
// NOTE(review): the macro is defined outside this view; presumably a compiler
// alignment hint - confirm.
template <>
EIGEN_STRONG_INLINE Packet2f pload<Packet2f>(const float* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2f>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_f32(from);
}
// Aligned load of four floats (vld1q_f32); same alignment assumption as above.
template <>
EIGEN_STRONG_INLINE Packet4f pload<Packet4f>(const float* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4f>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f32(from);
}
template <>
@@ -2399,10 +2284,12 @@ EIGEN_STRONG_INLINE Packet4c pload<Packet4c>(const int8_t* from) {
}
// Aligned load of eight int8_t values (vld1_s8). The pointer is first declared
// to satisfy unpacket_traits<Packet8c>::alignment via EIGEN_ASSUME_ALIGNED.
template <>
EIGEN_STRONG_INLINE Packet8c pload<Packet8c>(const int8_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet8c>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_s8(from);
}
// Aligned load of sixteen int8_t values (vld1q_s8).
template <>
EIGEN_STRONG_INLINE Packet16c pload<Packet16c>(const int8_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet16c>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s8(from);
}
template <>
@@ -2413,50 +2300,62 @@ EIGEN_STRONG_INLINE Packet4uc pload<Packet4uc>(const uint8_t* from) {
}
// Aligned loads for the remaining integer NEON packets. Every specialization
// follows the same two steps: pass the pointer and the packet's declared
// alignment to EIGEN_ASSUME_ALIGNED, then issue the matching vld1 / vld1q load.
// NOTE(review): EIGEN_ASSUME_ALIGNED and EIGEN_DEBUG_ALIGNED_LOAD are defined
// outside this view - confirm their exact effect there.
template <>
EIGEN_STRONG_INLINE Packet8uc pload<Packet8uc>(const uint8_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet8uc>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_u8(from);
}
template <>
EIGEN_STRONG_INLINE Packet16uc pload<Packet16uc>(const uint8_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet16uc>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u8(from);
}
// 16-bit lanes.
template <>
EIGEN_STRONG_INLINE Packet4s pload<Packet4s>(const int16_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4s>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_s16(from);
}
template <>
EIGEN_STRONG_INLINE Packet8s pload<Packet8s>(const int16_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet8s>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s16(from);
}
template <>
EIGEN_STRONG_INLINE Packet4us pload<Packet4us>(const uint16_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4us>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_u16(from);
}
template <>
EIGEN_STRONG_INLINE Packet8us pload<Packet8us>(const uint16_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet8us>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u16(from);
}
// 32-bit lanes.
template <>
EIGEN_STRONG_INLINE Packet2i pload<Packet2i>(const int32_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2i>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_s32(from);
}
template <>
EIGEN_STRONG_INLINE Packet4i pload<Packet4i>(const int32_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4i>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s32(from);
}
template <>
EIGEN_STRONG_INLINE Packet2ui pload<Packet2ui>(const uint32_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2ui>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_u32(from);
}
template <>
EIGEN_STRONG_INLINE Packet4ui pload<Packet4ui>(const uint32_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4ui>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u32(from);
}
// 64-bit lanes.
template <>
EIGEN_STRONG_INLINE Packet2l pload<Packet2l>(const int64_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2l>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_s64(from);
}
template <>
EIGEN_STRONG_INLINE Packet2ul pload<Packet2ul>(const uint64_t* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2ul>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_u64(from);
}
@@ -2681,10 +2580,12 @@ EIGEN_STRONG_INLINE Packet4ui ploadquad<Packet4ui>(const uint32_t* from) {
// Aligned store of two floats. The destination pointer and the packet's declared
// alignment are passed to EIGEN_ASSUME_ALIGNED before the vst1_f32 store.
// NOTE(review): the macro is defined outside this view; presumably a compiler
// alignment hint - confirm.
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet2f& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2f>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_f32(to, from);
}
// Aligned store of four floats (vst1q_f32); same alignment assumption as above.
template <>
EIGEN_STRONG_INLINE void pstore<float>(float* to, const Packet4f& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4f>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_f32(to, from);
}
template <>
@@ -2693,10 +2594,12 @@ EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet4c& from) {
}
// Aligned store of eight int8_t values (vst1_s8). The destination is first
// declared to satisfy unpacket_traits<Packet8c>::alignment.
template <>
EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet8c& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet8c>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_s8(to, from);
}
// Aligned store of sixteen int8_t values (vst1q_s8).
template <>
EIGEN_STRONG_INLINE void pstore<int8_t>(int8_t* to, const Packet16c& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet16c>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_s8(to, from);
}
template <>
@@ -2705,50 +2608,62 @@ EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet4uc& from) {
}
// Aligned stores for the remaining integer NEON packets. Every specialization
// follows the same two steps: pass the destination and the packet's declared
// alignment to EIGEN_ASSUME_ALIGNED, then issue the matching vst1 / vst1q store.
// NOTE(review): EIGEN_ASSUME_ALIGNED and EIGEN_DEBUG_ALIGNED_STORE are defined
// outside this view - confirm their exact effect there.
template <>
EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet8uc& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet8uc>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_u8(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint8_t>(uint8_t* to, const Packet16uc& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet16uc>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_u8(to, from);
}
// 16-bit lanes.
template <>
EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet4s& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4s>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_s16(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int16_t>(int16_t* to, const Packet8s& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet8s>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_s16(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet4us& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4us>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_u16(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint16_t>(uint16_t* to, const Packet8us& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet8us>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_u16(to, from);
}
// 32-bit lanes.
template <>
EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet2i& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2i>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_s32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<int32_t>(int32_t* to, const Packet4i& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4i>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_s32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet2ui& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2ui>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_u32(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint32_t>(uint32_t* to, const Packet4ui& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4ui>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_u32(to, from);
}
// 64-bit lanes.
template <>
EIGEN_STRONG_INLINE void pstore<int64_t>(int64_t* to, const Packet2l& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2l>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_s64(to, from);
}
template <>
EIGEN_STRONG_INLINE void pstore<uint64_t>(uint64_t* to, const Packet2ul& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2ul>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_u64(to, from);
}
@@ -4769,17 +4684,7 @@ struct packet_traits<bfloat16> : default_packet_traits {
};
template <>
struct unpacket_traits<Packet4bf> {
typedef bfloat16 type;
typedef Packet4bf half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
struct unpacket_traits<Packet4bf> : neon_unpacket_default<Packet4bf, bfloat16> {};
namespace detail {
template <>
@@ -4834,6 +4739,7 @@ EIGEN_STRONG_INLINE bfloat16 pfirst<Packet4bf>(const Packet4bf& from) {
// Aligned load of four bfloat16 values: the buffer is reinterpreted as uint16_t
// and loaded through pload<Packet4us>, then wrapped as a Packet4bf.
template <>
EIGEN_STRONG_INLINE Packet4bf pload<Packet4bf>(const bfloat16* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4bf>::alignment);
return Packet4bf(pload<Packet4us>(reinterpret_cast<const uint16_t*>(from)));
}
@@ -4844,6 +4750,7 @@ EIGEN_STRONG_INLINE Packet4bf ploadu<Packet4bf>(const bfloat16* from) {
// Aligned store of four bfloat16 values: the destination is reinterpreted as
// uint16_t and written with vst1_u16.
template <>
EIGEN_STRONG_INLINE void pstore<bfloat16>(bfloat16* to, const Packet4bf& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4bf>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_u16(reinterpret_cast<uint16_t*>(to), from);
}
@@ -5154,6 +5061,7 @@ struct packet_traits<double> : default_packet_traits {
#if EIGEN_ARCH_ARM64 && !EIGEN_APPLE_DOUBLE_NEON_BUG
HasExp = 1,
HasLog = 1,
HasPow = 1,
HasATan = 1,
HasATanh = 1,
#endif
@@ -5169,17 +5077,8 @@ struct packet_traits<double> : default_packet_traits {
};
template <>
struct unpacket_traits<Packet2d> {
typedef double type;
typedef Packet2d half;
typedef Packet2l integer_packet;
enum {
size = 2,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet2d> : neon_unpacket_default<Packet2d, double> {
using integer_packet = Packet2l;
};
template <>
@@ -5242,13 +5141,28 @@ template <>
EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return vfmaq_f64(c, a, b);
}
// pnmadd(a, b, c) for two doubles: returns c - a * b using the fused
// multiply-subtract intrinsic vfmsq_f64 (accumulator first).
template <>
EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return vfmsq_f64(c, a, b);
}
#else
template <>
EIGEN_STRONG_INLINE Packet2d pmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return vmlaq_f64(c, a, b);
}
// pnmadd(a, b, c) for two doubles in the #else branch: returns c - a * b using
// the multiply-subtract intrinsic vmlsq_f64 (accumulator first).
template <>
EIGEN_STRONG_INLINE Packet2d pnmadd(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return vmlsq_f64(c, a, b);
}
#endif
// pmsub(a, b, c) for two doubles: a * b - c, obtained by negating
// pnmadd(a, b, c), which yields c - a * b.
template <>
EIGEN_STRONG_INLINE Packet2d pmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return pnegate(pnmadd(a, b, c));
}
// pnmsub(a, b, c) for two doubles: -(a * b + c), obtained by negating pmadd.
template <>
EIGEN_STRONG_INLINE Packet2d pnmsub(const Packet2d& a, const Packet2d& b, const Packet2d& c) {
return pnegate(pmadd(a, b, c));
}
template <>
EIGEN_STRONG_INLINE Packet2d pmin<Packet2d>(const Packet2d& a, const Packet2d& b) {
return vminq_f64(a, b);
@@ -5326,6 +5240,7 @@ EIGEN_STRONG_INLINE Packet2d pcmp_eq(const Packet2d& a, const Packet2d& b) {
// Aligned load of two doubles (vld1q_f64). The pointer and the packet's declared
// alignment are passed to EIGEN_ASSUME_ALIGNED before the load.
template <>
EIGEN_STRONG_INLINE Packet2d pload<Packet2d>(const double* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet2d>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f64(from);
}
@@ -5340,6 +5255,7 @@ EIGEN_STRONG_INLINE Packet2d ploaddup<Packet2d>(const double* from) {
}
// Aligned store of two doubles (vst1q_f64). The destination and the packet's
// declared alignment are passed to EIGEN_ASSUME_ALIGNED before the store.
template <>
EIGEN_STRONG_INLINE void pstore<double>(double* to, const Packet2d& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet2d>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_f64(to, from);
}
@@ -5532,29 +5448,10 @@ struct packet_traits<Eigen::half> : default_packet_traits {
};
template <>
struct unpacket_traits<Packet4hf> {
typedef Eigen::half type;
typedef Packet4hf half;
enum {
size = 4,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
};
struct unpacket_traits<Packet4hf> : neon_unpacket_default<Packet4hf, half> {};
template <>
struct unpacket_traits<Packet8hf> {
typedef Eigen::half type;
typedef Packet4hf half;
enum {
size = 8,
alignment = Aligned16,
vectorizable = true,
masked_load_available = false,
masked_store_available = false
};
struct unpacket_traits<Packet8hf> : neon_unpacket_default<Packet8hf, half> {
using half = Packet4hf;
};
template <>
@@ -5657,18 +5554,33 @@ EIGEN_STRONG_INLINE Packet4hf pmadd(const Packet4hf& a, const Packet4hf& b, cons
}
template <>
EIGEN_STRONG_INLINE Packet8hf pmsub(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
return vfmaq_f16(pnegate(c), a, b);
EIGEN_STRONG_INLINE Packet8hf pnmadd(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
return vfmsq_f16(c, a, b);
}
template <>
EIGEN_STRONG_INLINE Packet4hf pnmadd(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
return vfma_f16(c, pnegate(a), b);
return vfms_f16(c, a, b);
}
// pmsub(a, b, c) for eight half-precision lanes: negation of pnmadd(a, b, c).
template <>
EIGEN_STRONG_INLINE Packet8hf pmsub(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
return pnegate(pnmadd(a, b, c));
}
// Four-lane variant; same negate-pnmadd construction.
template <>
EIGEN_STRONG_INLINE Packet4hf pmsub(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
return pnegate(pnmadd(a, b, c));
}
// pnmsub(a, b, c) for eight half-precision lanes: negation of pmadd(a, b, c).
template <>
EIGEN_STRONG_INLINE Packet8hf pnmsub(const Packet8hf& a, const Packet8hf& b, const Packet8hf& c) {
return pnegate(pmadd(a, b, c));
}
template <>
EIGEN_STRONG_INLINE Packet4hf pnmsub(const Packet4hf& a, const Packet4hf& b, const Packet4hf& c) {
return vfma_f16(pnegate(c), pnegate(a), b);
return pnegate(pmadd(a, b, c));
}
template <>
@@ -5872,11 +5784,13 @@ EIGEN_STRONG_INLINE Packet4hf pandnot<Packet4hf>(const Packet4hf& a, const Packe
// Aligned load of eight Eigen::half values: the buffer is reinterpreted as
// float16_t and loaded with vld1q_f16.
template <>
EIGEN_STRONG_INLINE Packet8hf pload<Packet8hf>(const Eigen::half* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet8hf>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1q_f16(reinterpret_cast<const float16_t*>(from));
}
// Aligned load of four Eigen::half values (vld1_f16).
template <>
EIGEN_STRONG_INLINE Packet4hf pload<Packet4hf>(const Eigen::half* from) {
EIGEN_ASSUME_ALIGNED(from, unpacket_traits<Packet4hf>::alignment);
EIGEN_DEBUG_ALIGNED_LOAD return vld1_f16(reinterpret_cast<const float16_t*>(from));
}
@@ -5952,11 +5866,13 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet4hf pinsertlast(const Packet4hf& a,
// Aligned store of eight Eigen::half values: the destination is reinterpreted as
// float16_t and written with vst1q_f16.
template <>
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet8hf& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet8hf>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1q_f16(reinterpret_cast<float16_t*>(to), from);
}
// Aligned store of four Eigen::half values (vst1_f16).
template <>
EIGEN_STRONG_INLINE void pstore<Eigen::half>(Eigen::half* to, const Packet4hf& from) {
EIGEN_ASSUME_ALIGNED(to, unpacket_traits<Packet4hf>::alignment);
EIGEN_DEBUG_ALIGNED_STORE vst1_f16(reinterpret_cast<float16_t*>(to), from);
}

View File

@@ -192,6 +192,7 @@ struct packet_traits<float> : default_packet_traits {
HasExpm1 = 1,
HasNdtri = 1,
HasExp = 1,
HasPow = 1,
HasBessel = 1,
HasSqrt = 1,
HasRsqrt = 1,
@@ -221,6 +222,7 @@ struct packet_traits<double> : default_packet_traits {
HasErf = EIGEN_FAST_MATH,
HasErfc = EIGEN_FAST_MATH,
HasExp = 1,
HasPow = 1,
HasSqrt = 1,
HasRsqrt = 1,
HasCbrt = 1,
@@ -1857,220 +1859,6 @@ EIGEN_STRONG_INLINE void punpackp(Packet4f* vecs) {
vecs[0] = _mm_castsi128_ps(_mm_shuffle_epi32(_mm_castps_si128(vecs[0]), 0x00));
}
// Horizontal sum of the four float lanes.
template <>
EIGEN_STRONG_INLINE float predux<Packet4f>(const Packet4f& a) {
// The SSE3 _mm_hadd_ps path below is disabled because it is extremely slow on all
// existing Intel architectures (from Nehalem to Haswell).
// #ifdef EIGEN_VECTORIZE_SSE3
// Packet4f tmp = _mm_add_ps(a, vec4f_swizzle1(a,2,3,2,3));
// return pfirst<Packet4f>(_mm_hadd_ps(tmp, tmp));
// #else
// Add the upper two lanes onto the lower two, then add lane 1 onto lane 0.
Packet4f tmp = _mm_add_ps(a, _mm_movehl_ps(a, a));
return pfirst<Packet4f>(_mm_add_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1)));
// #endif
}
// Horizontal sum of the two double lanes.
template <>
EIGEN_STRONG_INLINE double predux<Packet2d>(const Packet2d& a) {
// The SSE3 _mm_hadd_pd path below is disabled because it is extremely slow on all
// existing Intel architectures (from Nehalem to Haswell).
// #ifdef EIGEN_VECTORIZE_SSE3
// return pfirst<Packet2d>(_mm_hadd_pd(a, a));
// #else
// Add the high lane onto the low lane and return the low lane.
return pfirst<Packet2d>(_mm_add_sd(a, _mm_unpackhi_pd(a, a)));
// #endif
}
// Horizontal sum of the two 64-bit integer lanes.
template <>
EIGEN_STRONG_INLINE int64_t predux<Packet2l>(const Packet2l& a) {
return pfirst<Packet2l>(_mm_add_epi64(a, _mm_unpackhi_epi64(a, a)));
}
// Horizontal sum of four 32-bit integer lanes (signed and unsigned).
#ifdef EIGEN_VECTORIZE_SSSE3
// SSSE3 path: two rounds of horizontal add leave the total in lane 0.
template <>
EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a) {
Packet4i tmp0 = _mm_hadd_epi32(a, a);
return pfirst<Packet4i>(_mm_hadd_epi32(tmp0, tmp0));
}
template <>
EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a) {
Packet4ui tmp0 = _mm_hadd_epi32(a, a);
return pfirst<Packet4ui>(_mm_hadd_epi32(tmp0, tmp0));
}
#else
// Fallback: add the upper 64 bits onto the lower 64 bits, then sum lanes 0 and 1
// as scalars.
template <>
EIGEN_STRONG_INLINE int predux<Packet4i>(const Packet4i& a) {
Packet4i tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a, a));
return pfirst(tmp) + pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1));
}
template <>
EIGEN_STRONG_INLINE uint32_t predux<Packet4ui>(const Packet4ui& a) {
Packet4ui tmp = _mm_add_epi32(a, _mm_unpackhi_epi64(a, a));
return pfirst(tmp) + pfirst<Packet4ui>(_mm_shuffle_epi32(tmp, 1));
}
#endif
// Reduction for the boolean packet: ORs the upper 64 bits onto the lower 64 bits
// and returns true if either of the two resulting low 32-bit words is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux<Packet16b>(const Packet16b& a) {
Packet4i tmp = _mm_or_si128(a, _mm_unpackhi_epi64(a, a));
return (pfirst(tmp) != 0) || (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) != 0);
}
// Other reduction functions:
// mul
// Horizontal product of the four float lanes: multiply the upper two lanes onto
// the lower two, then multiply lane 1 onto lane 0.
template <>
EIGEN_STRONG_INLINE float predux_mul<Packet4f>(const Packet4f& a) {
Packet4f tmp = _mm_mul_ps(a, _mm_movehl_ps(a, a));
return pfirst<Packet4f>(_mm_mul_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1)));
}
// Horizontal product of the two double lanes.
template <>
EIGEN_STRONG_INLINE double predux_mul<Packet2d>(const Packet2d& a) {
return pfirst<Packet2d>(_mm_mul_sd(a, _mm_unpackhi_pd(a, a)));
}
// Horizontal product of the two 64-bit lanes, computed in scalar code after
// spilling the packet to an aligned buffer.
template <>
EIGEN_STRONG_INLINE int64_t predux_mul<Packet2l>(const Packet2l& a) {
EIGEN_ALIGN16 int64_t aux[2];
pstore(aux, a);
return aux[0] * aux[1];
}
// Horizontal product of four signed 32-bit lanes, computed in scalar code.
template <>
EIGEN_STRONG_INLINE int predux_mul<Packet4i>(const Packet4i& a) {
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., reusing pmul is very slow!).
// TODO try to call _mm_mul_epu32 directly
EIGEN_ALIGN16 int aux[4];
pstore(aux, a);
return (aux[0] * aux[1]) * (aux[2] * aux[3]);
}
// Horizontal product of four unsigned 32-bit lanes, computed in scalar code.
template <>
EIGEN_STRONG_INLINE uint32_t predux_mul<Packet4ui>(const Packet4ui& a) {
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., reusing pmul is very slow!).
// TODO try to call _mm_mul_epu32 directly
EIGEN_ALIGN16 uint32_t aux[4];
pstore(aux, a);
return (aux[0] * aux[1]) * (aux[2] * aux[3]);
}
// Boolean packet: ANDs the upper 64 bits onto the lower 64 bits and returns true
// only if both resulting low 32-bit words equal 0x01010101 (every byte is 0x01).
template <>
EIGEN_STRONG_INLINE bool predux_mul<Packet16b>(const Packet16b& a) {
Packet4i tmp = _mm_and_si128(a, _mm_unpackhi_epi64(a, a));
return ((pfirst<Packet4i>(tmp) == 0x01010101) && (pfirst<Packet4i>(_mm_shuffle_epi32(tmp, 1)) == 0x01010101));
}
// min
// Horizontal minimum of the four float lanes: fold the upper two lanes onto the
// lower two with _mm_min_ps, then fold lane 1 onto lane 0.
template <>
EIGEN_STRONG_INLINE float predux_min<Packet4f>(const Packet4f& a) {
Packet4f tmp = _mm_min_ps(a, _mm_movehl_ps(a, a));
return pfirst<Packet4f>(_mm_min_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1)));
}
// Horizontal minimum of the two double lanes.
template <>
EIGEN_STRONG_INLINE double predux_min<Packet2d>(const Packet2d& a) {
return pfirst<Packet2d>(_mm_min_sd(a, _mm_unpackhi_pd(a, a)));
}
// Horizontal minimum of four signed 32-bit lanes. Uses _mm_min_epi32 when SSE4.1
// is enabled, otherwise a scalar comparison tree over a spilled copy.
template <>
EIGEN_STRONG_INLINE int predux_min<Packet4i>(const Packet4i& a) {
#ifdef EIGEN_VECTORIZE_SSE4_1
Packet4i tmp = _mm_min_epi32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst<Packet4i>(_mm_min_epi32(tmp, _mm_shuffle_epi32(tmp, 1)));
#else
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., it does not like using std::min after the pstore!).
EIGEN_ALIGN16 int aux[4];
pstore(aux, a);
int aux0 = aux[0] < aux[1] ? aux[0] : aux[1];
int aux2 = aux[2] < aux[3] ? aux[2] : aux[3];
return aux0 < aux2 ? aux0 : aux2;
#endif  // EIGEN_VECTORIZE_SSE4_1
}
// Horizontal minimum of four unsigned 32-bit lanes. Uses _mm_min_epu32 when
// SSE4.1 is enabled, otherwise a scalar comparison tree over a spilled copy.
template <>
EIGEN_STRONG_INLINE uint32_t predux_min<Packet4ui>(const Packet4ui& a) {
#ifdef EIGEN_VECTORIZE_SSE4_1
Packet4ui tmp = _mm_min_epu32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst<Packet4ui>(_mm_min_epu32(tmp, _mm_shuffle_epi32(tmp, 1)));
#else
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., it does not like using std::min after the pstore!).
EIGEN_ALIGN16 uint32_t aux[4];
pstore(aux, a);
uint32_t aux0 = aux[0] < aux[1] ? aux[0] : aux[1];
uint32_t aux2 = aux[2] < aux[3] ? aux[2] : aux[3];
return aux0 < aux2 ? aux0 : aux2;
#endif  // EIGEN_VECTORIZE_SSE4_1
}
// max
// Horizontal maximum of the four float lanes: fold the upper two lanes onto the
// lower two with _mm_max_ps, then fold lane 1 onto lane 0.
template <>
EIGEN_STRONG_INLINE float predux_max<Packet4f>(const Packet4f& a) {
Packet4f tmp = _mm_max_ps(a, _mm_movehl_ps(a, a));
return pfirst<Packet4f>(_mm_max_ss(tmp, _mm_shuffle_ps(tmp, tmp, 1)));
}
// Horizontal maximum of the two double lanes.
template <>
EIGEN_STRONG_INLINE double predux_max<Packet2d>(const Packet2d& a) {
return pfirst<Packet2d>(_mm_max_sd(a, _mm_unpackhi_pd(a, a)));
}
// Horizontal maximum of four signed 32-bit lanes. Uses _mm_max_epi32 when SSE4.1
// is enabled, otherwise a scalar comparison tree over a spilled copy.
template <>
EIGEN_STRONG_INLINE int predux_max<Packet4i>(const Packet4i& a) {
#ifdef EIGEN_VECTORIZE_SSE4_1
Packet4i tmp = _mm_max_epi32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst<Packet4i>(_mm_max_epi32(tmp, _mm_shuffle_epi32(tmp, 1)));
#else
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., it does not like using std::min after the pstore!).
EIGEN_ALIGN16 int aux[4];
pstore(aux, a);
int aux0 = aux[0] > aux[1] ? aux[0] : aux[1];
int aux2 = aux[2] > aux[3] ? aux[2] : aux[3];
return aux0 > aux2 ? aux0 : aux2;
#endif  // EIGEN_VECTORIZE_SSE4_1
}
// Horizontal maximum of four unsigned 32-bit lanes. Uses _mm_max_epu32 when
// SSE4.1 is enabled, otherwise a scalar comparison tree over a spilled copy.
template <>
EIGEN_STRONG_INLINE uint32_t predux_max<Packet4ui>(const Packet4ui& a) {
#ifdef EIGEN_VECTORIZE_SSE4_1
Packet4ui tmp = _mm_max_epu32(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 0, 3, 2)));
return pfirst<Packet4ui>(_mm_max_epu32(tmp, _mm_shuffle_epi32(tmp, 1)));
#else
// After some experiments, it seems this is the fastest way to implement it
// for GCC (e.g., it does not like using std::min after the pstore!).
EIGEN_ALIGN16 uint32_t aux[4];
pstore(aux, a);
uint32_t aux0 = aux[0] > aux[1] ? aux[0] : aux[1];
uint32_t aux2 = aux[2] > aux[3] ? aux[2] : aux[3];
return aux0 > aux2 ? aux0 : aux2;
#endif  // EIGEN_VECTORIZE_SSE4_1
}
// not needed yet
// template<> EIGEN_STRONG_INLINE bool predux_all(const Packet4f& x)
// {
// return _mm_movemask_ps(x) == 0xF;
// }
// predux_any: true when the movemask of the packet is non-zero, i.e. when at
// least one lane has its most significant bit set.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2d& x) {
return _mm_movemask_pd(x) != 0x0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4f& x) {
return _mm_movemask_ps(x) != 0x0;
}
// Integer packets are bit-cast to the floating-point type of the same lane
// width so the same movemask instructions can be used.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2l& x) {
return _mm_movemask_pd(_mm_castsi128_pd(x)) != 0x0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4i& x) {
return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
}
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& x) {
return _mm_movemask_ps(_mm_castsi128_ps(x)) != 0x0;
}
// In-place transpose of a 4x4 block of floats held in four Packet4f rows,
// delegated to the _MM_TRANSPOSE4_PS macro.
EIGEN_STRONG_INLINE void ptranspose(PacketBlock<Packet4f, 4>& kernel) {
_MM_TRANSPOSE4_PS(kernel.packet[0], kernel.packet[1], kernel.packet[2], kernel.packet[3]);
}
@@ -2238,38 +2026,38 @@ EIGEN_STRONG_INLINE Packet2d pblend(const Selector<2>& ifPacket, const Packet2d&
}
// Scalar path for pmadd with FMA to ensure consistency with vectorized path.
#ifdef EIGEN_VECTORIZE_FMA
#if defined(EIGEN_VECTORIZE_FMA)
template <>
EIGEN_STRONG_INLINE float pmadd(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, c);
return std::fmaf(a, b, c);
}
template <>
EIGEN_STRONG_INLINE double pmadd(const double& a, const double& b, const double& c) {
return ::fma(a, b, c);
return std::fma(a, b, c);
}
template <>
EIGEN_STRONG_INLINE float pmsub(const float& a, const float& b, const float& c) {
return ::fmaf(a, b, -c);
return std::fmaf(a, b, -c);
}
template <>
EIGEN_STRONG_INLINE double pmsub(const double& a, const double& b, const double& c) {
return ::fma(a, b, -c);
return std::fma(a, b, -c);
}
template <>
EIGEN_STRONG_INLINE float pnmadd(const float& a, const float& b, const float& c) {
return ::fmaf(-a, b, c);
return std::fmaf(-a, b, c);
}
template <>
EIGEN_STRONG_INLINE double pnmadd(const double& a, const double& b, const double& c) {
return ::fma(-a, b, c);
return std::fma(-a, b, c);
}
template <>
EIGEN_STRONG_INLINE float pnmsub(const float& a, const float& b, const float& c) {
return ::fmaf(-a, b, -c);
return std::fmaf(-a, b, -c);
}
template <>
EIGEN_STRONG_INLINE double pnmsub(const double& a, const double& b, const double& c) {
return ::fma(-a, b, -c);
return std::fma(-a, b, -c);
}
#endif

View File

@@ -0,0 +1,324 @@
// This file is part of Eigen, a lightweight C++ template library
// for linear algebra.
//
// Copyright (C) 2025 Charlie Schlosser <cs.schlosser@gmail.com>
//
// This Source Code Form is subject to the terms of the Mozilla
// Public License v. 2.0. If a copy of the MPL was not distributed
// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
#ifndef EIGEN_REDUCTIONS_SSE_H
#define EIGEN_REDUCTIONS_SSE_H
// IWYU pragma: private
#include "../../InternalHeaderCheck.h"
namespace Eigen {
namespace internal {
// Adapts padd<Packet> to the static packetOp(a, b) interface that the
// sse_predux_common specializations invoke on their Op parameter.
template <typename Packet>
struct sse_add_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    return padd<Packet>(lhs, rhs);
  }
};
// Adapts pmul<Packet> to the static packetOp(a, b) interface that the
// sse_predux_common specializations invoke on their Op parameter.
template <typename Packet>
struct sse_mul_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    return pmul<Packet>(lhs, rhs);
  }
};
// Adapts pmin<Packet> to the static packetOp(a, b) interface that the
// sse_predux_common specializations invoke on their Op parameter.
template <typename Packet>
struct sse_min_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    return pmin<Packet>(lhs, rhs);
  }
};
// Same as sse_min_wrapper, but forwards an explicit NaNPropagation mode to
// pmin<NaNPropagation, Packet>.
template <int NaNPropagation, typename Packet>
struct sse_min_prop_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    const Packet smaller = pmin<NaNPropagation, Packet>(lhs, rhs);
    return smaller;
  }
};
// Adapts pmax<Packet> to the static packetOp(a, b) interface that the
// sse_predux_common specializations invoke on their Op parameter.
template <typename Packet>
struct sse_max_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    return pmax<Packet>(lhs, rhs);
  }
};
// Same as sse_max_wrapper, but forwards an explicit NaNPropagation mode to
// pmax<NaNPropagation, Packet>.
template <int NaNPropagation, typename Packet>
struct sse_max_prop_wrapper {
  static EIGEN_STRONG_INLINE Packet packetOp(const Packet& lhs, const Packet& rhs) {
    const Packet larger = pmax<NaNPropagation, Packet>(lhs, rhs);
    return larger;
  }
};
// Generic horizontal-reduction driver. Only declared here; it is specialized
// per packet type further down, and each specialization provides a static
// run(const Packet&) that folds the lanes with Op::packetOp.
template <typename Packet, typename Op>
struct sse_predux_common;
// Named reductions: each binds sse_predux_common to one of the wrapper
// functors and adds nothing else.
// Sum (padd).
template <typename Packet>
struct sse_predux_impl : sse_predux_common<Packet, sse_add_wrapper<Packet>> {};
// Product (pmul).
template <typename Packet>
struct sse_predux_mul_impl : sse_predux_common<Packet, sse_mul_wrapper<Packet>> {};
// Minimum (pmin).
template <typename Packet>
struct sse_predux_min_impl : sse_predux_common<Packet, sse_min_wrapper<Packet>> {};
// Minimum with an explicit NaNPropagation mode.
template <int NaNPropagation, typename Packet>
struct sse_predux_min_prop_impl : sse_predux_common<Packet, sse_min_prop_wrapper<NaNPropagation, Packet>> {};
// Maximum (pmax).
template <typename Packet>
struct sse_predux_max_impl : sse_predux_common<Packet, sse_max_wrapper<Packet>> {};
// Maximum with an explicit NaNPropagation mode.
template <int NaNPropagation, typename Packet>
struct sse_predux_max_prop_impl : sse_predux_common<Packet, sse_max_prop_wrapper<NaNPropagation, Packet>> {};
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet16b -- -- -- -- -- -- -- -- -- -- -- -- */
// Boolean reduction: true when any byte of the packet is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux(const Packet16b& a) {
  // OR the upper 64 bits onto the lower 64 bits, then inspect the two low
  // 32-bit words of the result.
  const Packet4i folded = _mm_or_si128(a, _mm_unpackhi_epi64(a, a));
  if (pfirst(folded) != 0) {
    return true;
  }
  return pfirst<Packet4i>(_mm_shuffle_epi32(folded, 1)) != 0;
}
// Boolean product: true only when both low 32-bit words of the AND-folded
// packet equal 0x01010101, i.e. every byte is 0x01.
template <>
EIGEN_STRONG_INLINE bool predux_mul(const Packet16b& a) {
  // AND the upper 64 bits onto the lower 64 bits.
  const Packet4i folded = _mm_and_si128(a, _mm_unpackhi_epi64(a, a));
  if (pfirst<Packet4i>(folded) != 0x01010101) {
    return false;
  }
  return pfirst<Packet4i>(_mm_shuffle_epi32(folded, 1)) == 0x01010101;
}
// Minimum of a boolean packet; forwards to predux_mul.
template <>
EIGEN_STRONG_INLINE bool predux_min(const Packet16b& a) {
  const bool result = predux_mul(a);
  return result;
}

// Maximum of a boolean packet; forwards to predux.
template <>
EIGEN_STRONG_INLINE bool predux_max(const Packet16b& a) {
  const bool result = predux(a);
  return result;
}

// "Any" of a boolean packet; forwards to predux.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet16b& a) {
  const bool result = predux(a);
  return result;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4i -- -- -- -- -- -- -- -- -- -- -- -- */
// Reduction of four signed 32-bit lanes with an arbitrary binary Op.
template <typename Op>
struct sse_predux_common<Packet4i, Op> {
  static EIGEN_STRONG_INLINE int run(const Packet4i& a) {
    // Step 1: combine the packet with a lane-reversed copy of itself.
    const Packet4i mirrored = Op::packetOp(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 1, 2, 3)));
    // Step 2: combine with the upper lanes interleaved down, so lane 0 ends up
    // holding the result over all four inputs.
    const Packet4i reduced = Op::packetOp(mirrored, _mm_unpackhi_epi32(mirrored, mirrored));
    return _mm_cvtsi128_si32(reduced);
  }
};
// Sum of the four signed 32-bit lanes.
template <>
EIGEN_STRONG_INLINE int predux(const Packet4i& a) {
  using Reducer = sse_predux_impl<Packet4i>;
  return Reducer::run(a);
}

// Product of the four signed 32-bit lanes.
template <>
EIGEN_STRONG_INLINE int predux_mul(const Packet4i& a) {
  using Reducer = sse_predux_mul_impl<Packet4i>;
  return Reducer::run(a);
}

// Min/max specializations are only provided when SSE4.1 is enabled.
#ifdef EIGEN_VECTORIZE_SSE4_1
template <>
EIGEN_STRONG_INLINE int predux_min(const Packet4i& a) {
  using Reducer = sse_predux_min_impl<Packet4i>;
  return Reducer::run(a);
}

template <>
EIGEN_STRONG_INLINE int predux_max(const Packet4i& a) {
  using Reducer = sse_predux_max_impl<Packet4i>;
  return Reducer::run(a);
}
#endif

// True when the movemask of the packet (viewed as four floats) is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4i& a) {
  const int lane_mask = _mm_movemask_ps(_mm_castsi128_ps(a));
  return lane_mask != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4ui -- -- -- -- -- -- -- -- -- -- -- -- */
// Reduction of four unsigned 32-bit lanes with an arbitrary binary Op.
template <typename Op>
struct sse_predux_common<Packet4ui, Op> {
  static EIGEN_STRONG_INLINE uint32_t run(const Packet4ui& a) {
    // Step 1: combine the packet with a lane-reversed copy of itself.
    const Packet4ui mirrored = Op::packetOp(a, _mm_shuffle_epi32(a, _MM_SHUFFLE(0, 1, 2, 3)));
    // Step 2: combine with the upper lanes interleaved down, so lane 0 ends up
    // holding the result over all four inputs.
    const Packet4ui reduced = Op::packetOp(mirrored, _mm_unpackhi_epi32(mirrored, mirrored));
    // The extraction intrinsic returns a signed int; cast back to unsigned.
    return static_cast<uint32_t>(_mm_cvtsi128_si32(reduced));
  }
};
// Sum of the four unsigned 32-bit lanes.
template <>
EIGEN_STRONG_INLINE uint32_t predux(const Packet4ui& a) {
  using Reducer = sse_predux_impl<Packet4ui>;
  return Reducer::run(a);
}

// Product of the four unsigned 32-bit lanes.
template <>
EIGEN_STRONG_INLINE uint32_t predux_mul(const Packet4ui& a) {
  using Reducer = sse_predux_mul_impl<Packet4ui>;
  return Reducer::run(a);
}

// Min/max specializations are only provided when SSE4.1 is enabled.
#ifdef EIGEN_VECTORIZE_SSE4_1
template <>
EIGEN_STRONG_INLINE uint32_t predux_min(const Packet4ui& a) {
  using Reducer = sse_predux_min_impl<Packet4ui>;
  return Reducer::run(a);
}

template <>
EIGEN_STRONG_INLINE uint32_t predux_max(const Packet4ui& a) {
  using Reducer = sse_predux_max_impl<Packet4ui>;
  return Reducer::run(a);
}
#endif

// True when the movemask of the packet (viewed as four floats) is non-zero.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4ui& a) {
  const int lane_mask = _mm_movemask_ps(_mm_castsi128_ps(a));
  return lane_mask != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet2l -- -- -- -- -- -- -- -- -- -- -- -- */
template <typename Op>
struct sse_predux_common<Packet2l, Op> {
  // Collapses the two 64-bit lanes of `a` into one scalar: the high lane is
  // broadcast, combined with `a` through Op::packetOp, and the first lane of
  // the result is returned.
  static EIGEN_STRONG_INLINE int64_t run(const Packet2l& a) {
    const Packet2l high = _mm_unpackhi_epi64(a, a);  // [a1, a1]
    Packet2l folded = Op::packetOp(a, high);         // lane 0 = a0 . a1
    return pfirst(folded);
  }
};
// Horizontal reduction of the two int64_t lanes of `a`, delegated to
// sse_predux_impl.
// NOTE(review): sse_predux_impl is defined outside this chunk; presumably it
// yields the sum of the lanes, per Eigen's predux convention - confirm.
template <>
EIGEN_STRONG_INLINE int64_t predux(const Packet2l& a) {
return sse_predux_impl<Packet2l>::run(a);
}
// Returns true when at least one 64-bit lane of `a` has its most significant
// bit set: the integer packet is reinterpreted as doubles so that
// _mm_movemask_pd can gather the top bit of each lane into an int.
// NOTE(review): only the top bit is inspected, so this presumably expects `a`
// to be a comparison-style mask - confirm.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2l& a) {
return _mm_movemask_pd(_mm_castsi128_pd(a)) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet4f -- -- -- -- -- -- -- -- -- -- -- -- */
template <typename Op>
struct sse_predux_common<Packet4f, Op> {
  // Collapses the four float lanes of `a` into one scalar by applying
  // Op::packetOp in two folding steps; the result is read from lane 0.
  static EIGEN_STRONG_INLINE float run(const Packet4f& a) {
    // Step 1: fold the upper half onto the lower half
    // (lane 0 = a0 . a2, lane 1 = a1 . a3).
    const Packet4f upperHalf = _mm_movehl_ps(a, a);
    const Packet4f halves = Op::packetOp(a, upperHalf);
    // Step 2: bring lane 1 down to lane 0 and fold once more. With SSE3 the
    // odd lanes are duplicated; otherwise a shuffle with selector 1 is used.
#ifdef EIGEN_VECTORIZE_SSE3
    const Packet4f oddLane = _mm_movehdup_ps(halves);
#else
    const Packet4f oddLane = _mm_shuffle_ps(halves, halves, 1);
#endif
    const Packet4f reduced = Op::packetOp(halves, oddLane);
    return _mm_cvtss_f32(reduced);
  }
};
// Horizontal reduction of the four float lanes of `a`, delegated to
// sse_predux_impl.
// NOTE(review): sse_predux_impl is defined outside this chunk; presumably it
// yields the sum of the lanes, per Eigen's predux convention - confirm.
template <>
EIGEN_STRONG_INLINE float predux(const Packet4f& a) {
return sse_predux_impl<Packet4f>::run(a);
}
// Horizontal reduction of `a` delegated to sse_predux_mul_impl
// (presumably the product of the lanes - defined outside this chunk).
template <>
EIGEN_STRONG_INLINE float predux_mul(const Packet4f& a) {
return sse_predux_mul_impl<Packet4f>::run(a);
}
// Default-policy minimum reduction, delegated to sse_predux_min_impl.
template <>
EIGEN_STRONG_INLINE float predux_min(const Packet4f& a) {
return sse_predux_min_impl<Packet4f>::run(a);
}
// Minimum reduction with an explicit NaN policy; the policy tag is forwarded
// to sse_predux_min_prop_impl, which is defined outside this chunk.
// NOTE(review): presumably PropagateNumbers prefers non-NaN lanes - confirm.
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNumbers>(const Packet4f& a) {
return sse_predux_min_prop_impl<PropagateNumbers, Packet4f>::run(a);
}
// NOTE(review): presumably PropagateNaN yields NaN if any lane is NaN - confirm.
template <>
EIGEN_STRONG_INLINE float predux_min<PropagateNaN>(const Packet4f& a) {
return sse_predux_min_prop_impl<PropagateNaN, Packet4f>::run(a);
}
// Default-policy maximum reduction, delegated to sse_predux_max_impl.
template <>
EIGEN_STRONG_INLINE float predux_max(const Packet4f& a) {
return sse_predux_max_impl<Packet4f>::run(a);
}
// Maximum reduction with an explicit NaN policy; the policy tag is forwarded
// to sse_predux_max_prop_impl, which is defined outside this chunk.
// NOTE(review): presumably PropagateNumbers prefers non-NaN lanes - confirm.
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNumbers>(const Packet4f& a) {
return sse_predux_max_prop_impl<PropagateNumbers, Packet4f>::run(a);
}
// NOTE(review): presumably PropagateNaN yields NaN if any lane is NaN - confirm.
template <>
EIGEN_STRONG_INLINE float predux_max<PropagateNaN>(const Packet4f& a) {
return sse_predux_max_prop_impl<PropagateNaN, Packet4f>::run(a);
}
// Returns true when at least one lane of `a` has its sign (most significant)
// bit set; _mm_movemask_ps gathers that bit from each of the four lanes.
// NOTE(review): only the top bit is inspected, so this presumably expects `a`
// to be a comparison-style mask - confirm.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet4f& a) {
return _mm_movemask_ps(a) != 0x0;
}
/* -- -- -- -- -- -- -- -- -- -- -- -- Packet2d -- -- -- -- -- -- -- -- -- -- -- -- */
template <typename Op>
struct sse_predux_common<Packet2d, Op> {
  // Collapses the two double lanes of `a` into one scalar: the high lane is
  // broadcast, combined with `a` through Op::packetOp, and lane 0 of the
  // result is returned.
  static EIGEN_STRONG_INLINE double run(const Packet2d& a) {
    const Packet2d high = _mm_unpackhi_pd(a, a);    // [a1, a1]
    const Packet2d folded = Op::packetOp(a, high);  // lane 0 = a0 . a1
    return _mm_cvtsd_f64(folded);
  }
};
// Horizontal reduction of the two double lanes of `a`, delegated to
// sse_predux_impl.
// NOTE(review): sse_predux_impl is defined outside this chunk; presumably it
// yields the sum of the lanes, per Eigen's predux convention - confirm.
template <>
EIGEN_STRONG_INLINE double predux(const Packet2d& a) {
return sse_predux_impl<Packet2d>::run(a);
}
// Horizontal reduction of `a` delegated to sse_predux_mul_impl
// (presumably the product of the lanes - defined outside this chunk).
template <>
EIGEN_STRONG_INLINE double predux_mul(const Packet2d& a) {
return sse_predux_mul_impl<Packet2d>::run(a);
}
// Default-policy minimum reduction, delegated to sse_predux_min_impl.
template <>
EIGEN_STRONG_INLINE double predux_min(const Packet2d& a) {
return sse_predux_min_impl<Packet2d>::run(a);
}
// Minimum reduction with an explicit NaN policy; the policy tag is forwarded
// to sse_predux_min_prop_impl, which is defined outside this chunk.
// NOTE(review): presumably PropagateNumbers prefers non-NaN lanes - confirm.
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNumbers>(const Packet2d& a) {
return sse_predux_min_prop_impl<PropagateNumbers, Packet2d>::run(a);
}
// NOTE(review): presumably PropagateNaN yields NaN if any lane is NaN - confirm.
template <>
EIGEN_STRONG_INLINE double predux_min<PropagateNaN>(const Packet2d& a) {
return sse_predux_min_prop_impl<PropagateNaN, Packet2d>::run(a);
}
// Default-policy maximum reduction, delegated to sse_predux_max_impl.
template <>
EIGEN_STRONG_INLINE double predux_max(const Packet2d& a) {
return sse_predux_max_impl<Packet2d>::run(a);
}
// Maximum reduction with an explicit NaN policy; the policy tag is forwarded
// to sse_predux_max_prop_impl, which is defined outside this chunk.
// NOTE(review): presumably PropagateNumbers prefers non-NaN lanes - confirm.
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNumbers>(const Packet2d& a) {
return sse_predux_max_prop_impl<PropagateNumbers, Packet2d>::run(a);
}
// NOTE(review): presumably PropagateNaN yields NaN if any lane is NaN - confirm.
template <>
EIGEN_STRONG_INLINE double predux_max<PropagateNaN>(const Packet2d& a) {
return sse_predux_max_prop_impl<PropagateNaN, Packet2d>::run(a);
}
// Returns true when at least one lane of `a` has its sign (most significant)
// bit set; _mm_movemask_pd gathers that bit from each of the two lanes.
// NOTE(review): only the top bit is inspected, so this presumably expects `a`
// to be a comparison-style mask - confirm.
template <>
EIGEN_STRONG_INLINE bool predux_any(const Packet2d& a) {
return _mm_movemask_pd(a) != 0x0;
}
} // end namespace internal
} // end namespace Eigen
#endif // EIGEN_REDUCTIONS_SSE_H

View File

@@ -362,11 +362,7 @@ template <typename Scalar, typename Exponent>
struct functor_traits<scalar_pow_op<Scalar, Exponent>> {
enum {
Cost = 5 * NumTraits<Scalar>::MulCost,
PacketAccess = (!NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasExp &&
packet_traits<Scalar>::HasLog && packet_traits<Scalar>::HasRound && packet_traits<Scalar>::HasCmp &&
// Temporarily disable packet access for half/bfloat16 until
// accuracy is improved.
!is_same<Scalar, half>::value && !is_same<Scalar, bfloat16>::value)
PacketAccess = (!NumTraits<Scalar>::IsComplex && !NumTraits<Scalar>::IsInteger && packet_traits<Scalar>::HasPow)
};
};

View File

@@ -164,6 +164,11 @@ struct selfadjoint_product_impl<Lhs, LhsMode, false, Rhs, 0, true> {
enum { LhsUpLo = LhsMode & (Upper | Lower) };
// Verify that the Rhs is a vector in the correct orientation.
// Otherwise, we break the assumption that we are multiplying
// MxN * Nx1.
static_assert(Rhs::ColsAtCompileTime == 1, "The RHS must be a column vector.");
template <typename Dest>
static EIGEN_DEVICE_FUNC void run(Dest& dest, const Lhs& a_lhs, const Rhs& a_rhs, const Scalar& alpha) {
typedef typename Dest::Scalar ResScalar;

View File

@@ -235,7 +235,7 @@
#define EIGEN_VECTORIZE_SSE4_2
#endif
#ifdef __AVX__
#ifndef EIGEN_USE_SYCL
#if !defined(EIGEN_USE_SYCL) && !EIGEN_COMP_EMSCRIPTEN
#define EIGEN_VECTORIZE_AVX
#endif
#define EIGEN_VECTORIZE_SSE3

View File

@@ -171,6 +171,8 @@ template <typename MatrixType, unsigned int Mode>
class TriangularView;
template <typename MatrixType, unsigned int Mode>
class SelfAdjointView;
template <typename Derived>
class RealView;
template <typename MatrixType>
class SparseView;
template <typename ExpressionType>
@@ -397,8 +399,6 @@ template <typename Scalar_, int Rows_, int Cols_,
: EIGEN_DEFAULT_MATRIX_STORAGE_ORDER_OPTION),
int MaxRows_ = Rows_, int MaxCols_ = Cols_>
class Array;
template <typename ConditionMatrixType, typename ThenMatrixType, typename ElseMatrixType>
class Select;
template <typename MatrixType, typename BinaryOp, int Direction>
class PartialReduxExpr;
template <typename ExpressionType, int Direction>

View File

@@ -993,8 +993,9 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE constexpr void ignore_unused_variable(cons
#endif
#if !defined(EIGEN_OPTIMIZATION_BARRIER)
#if EIGEN_COMP_GNUC
// According to https://gcc.gnu.org/onlinedocs/gcc/Constraints.html:
// Implement the barrier on GNUC compilers or clang-cl.
#if EIGEN_COMP_GNUC || (defined(__clang__) && defined(_MSC_VER))
// According to https://gcc.gnu.org/onlinedocs/gcc/Constraints.html:
// X: Any operand whatsoever.
// r: A register operand is allowed provided that it is in a general
// register.
@@ -1027,37 +1028,37 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE constexpr void ignore_unused_variable(cons
// directly for std::complex<T>, Eigen::half, Eigen::bfloat16. For these,
// you will need to apply to the underlying POD type.
#if EIGEN_ARCH_PPC && EIGEN_COMP_GNUC_STRICT
// This seems to be broken on clang. Packet4f is loaded into a single
// This seems to be broken on clang. Packet4f is loaded into a single
// register rather than a vector, zeroing out some entries. Integer
// types also generate a compile error.
#if EIGEN_OS_MAC
// General, Altivec for Apple (VSX were added in ISA v2.06):
// General, Altivec for Apple (VSX were added in ISA v2.06):
#define EIGEN_OPTIMIZATION_BARRIER(X) __asm__("" : "+r,v"(X));
#else
// General, Altivec, VSX otherwise:
// General, Altivec, VSX otherwise:
#define EIGEN_OPTIMIZATION_BARRIER(X) __asm__("" : "+r,v,wa"(X));
#endif
#elif EIGEN_ARCH_ARM_OR_ARM64
#ifdef __ARM_FP
// General, VFP or NEON.
// General, VFP or NEON.
// Clang doesn't like "r",
// error: non-trivial scalar-to-vector conversion, possible invalid
// constraint for vector type
#define EIGEN_OPTIMIZATION_BARRIER(X) __asm__("" : "+g,w"(X));
#else
// Arm without VFP or NEON.
// Arm without VFP or NEON.
// "w" constraint will not compile.
#define EIGEN_OPTIMIZATION_BARRIER(X) __asm__("" : "+g"(X));
#endif
#elif EIGEN_ARCH_i386_OR_x86_64
// General, SSE.
// General, SSE.
#define EIGEN_OPTIMIZATION_BARRIER(X) __asm__("" : "+g,x"(X));
#else
// Not implemented for other architectures.
// Not implemented for other architectures.
#define EIGEN_OPTIMIZATION_BARRIER(X)
#endif
#else
// Not implemented for other compilers.
// Not implemented for other compilers.
#define EIGEN_OPTIMIZATION_BARRIER(X)
#endif
#endif

View File

@@ -762,7 +762,7 @@ void swap(scoped_array<T>& a, scoped_array<T>& b) {
* This is accomplished through alloca if the latter is supported and if the required number of bytes
* is below EIGEN_STACK_ALLOCATION_LIMIT.
*/
#ifdef EIGEN_ALLOCA
#if defined(EIGEN_ALLOCA) && !defined(EIGEN_NO_ALLOCA)
#if EIGEN_DEFAULT_ALIGN_BYTES > 0
// We always manually re-align the result of EIGEN_ALLOCA.
@@ -785,14 +785,14 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void* eigen_aligned_alloca_helper(void* pt
#define EIGEN_ALIGNED_ALLOCA(SIZE) EIGEN_ALLOCA(SIZE)
#endif
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER) \
Eigen::internal::check_size_for_overflow<TYPE>(SIZE); \
TYPE* NAME = (BUFFER) != 0 ? (BUFFER) \
: reinterpret_cast<TYPE*>((sizeof(TYPE) * SIZE <= EIGEN_STACK_ALLOCATION_LIMIT) \
? EIGEN_ALIGNED_ALLOCA(sizeof(TYPE) * SIZE) \
: Eigen::internal::aligned_malloc(sizeof(TYPE) * SIZE)); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME, _stack_memory_destructor)( \
(BUFFER) == 0 ? NAME : 0, SIZE, sizeof(TYPE) * SIZE > EIGEN_STACK_ALLOCATION_LIMIT)
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER) \
Eigen::internal::check_size_for_overflow<TYPE>(SIZE); \
TYPE* NAME = (BUFFER) != 0 ? (BUFFER) \
: reinterpret_cast<TYPE*>((sizeof(TYPE) * (SIZE) <= EIGEN_STACK_ALLOCATION_LIMIT) \
? EIGEN_ALIGNED_ALLOCA(sizeof(TYPE) * (SIZE)) \
: Eigen::internal::aligned_malloc(sizeof(TYPE) * (SIZE))); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME, _stack_memory_destructor)( \
(BUFFER) == 0 ? NAME : 0, SIZE, sizeof(TYPE) * (SIZE) > EIGEN_STACK_ALLOCATION_LIMIT)
#define ei_declare_local_nested_eval(XPR_T, XPR, N, NAME) \
Eigen::internal::local_nested_eval_wrapper<XPR_T, N> EIGEN_CAT(NAME, _wrapper)( \
@@ -805,10 +805,11 @@ EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void* eigen_aligned_alloca_helper(void* pt
#else
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER) \
Eigen::internal::check_size_for_overflow<TYPE>(SIZE); \
TYPE* NAME = (BUFFER) != 0 ? BUFFER : reinterpret_cast<TYPE*>(Eigen::internal::aligned_malloc(sizeof(TYPE) * SIZE)); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME, _stack_memory_destructor)( \
#define ei_declare_aligned_stack_constructed_variable(TYPE, NAME, SIZE, BUFFER) \
Eigen::internal::check_size_for_overflow<TYPE>(SIZE); \
TYPE* NAME = \
(BUFFER) != 0 ? BUFFER : reinterpret_cast<TYPE*>(Eigen::internal::aligned_malloc(sizeof(TYPE) * (SIZE))); \
Eigen::internal::aligned_stack_memory_handler<TYPE> EIGEN_CAT(NAME, _stack_memory_destructor)( \
(BUFFER) == 0 ? NAME : 0, SIZE, true)
#define ei_declare_local_nested_eval(XPR_T, XPR, N, NAME) \
@@ -1338,6 +1339,21 @@ EIGEN_DEVICE_FUNC void destroy_at(T* p) {
}
#endif
/** \internal
 * Informs the implementation that PTR is aligned to at least ALIGN_BYTES bytes, and
 * reassigns PTR to the same address carrying that assumption.
 *
 * PTR must be a modifiable pointer lvalue. ALIGN_BYTES must be a constant expression
 * (it is used as a template argument in the C++20 branch) and is expressed in bytes
 * in every branch. The behavior is undefined if PTR is not actually aligned to
 * ALIGN_BYTES.
 *
 * The macro can be overridden by defining EIGEN_ASSUME_ALIGNED before this point.
 */
#ifndef EIGEN_ASSUME_ALIGNED
#if defined(__cpp_lib_assume_aligned) && (__cpp_lib_assume_aligned >= 201811L)
// std::assume_aligned<N> takes N in bytes, exactly like __builtin_assume_aligned
// below, so ALIGN_BYTES is forwarded unscaled. Scaling it (e.g. by 8, as if it were
// a bit count) would promise more alignment than the pointer has, which is
// undefined behavior.
#define EIGEN_ASSUME_ALIGNED(PTR, ALIGN_BYTES) \
  { PTR = std::assume_aligned<(ALIGN_BYTES)>(PTR); }
#elif EIGEN_HAS_BUILTIN(__builtin_assume_aligned)
// The builtin returns void*, hence the cast back to the pointer's own type.
#define EIGEN_ASSUME_ALIGNED(PTR, ALIGN_BYTES) \
  { PTR = static_cast<decltype(PTR)>(__builtin_assume_aligned(PTR, (ALIGN_BYTES))); }
#else
#define EIGEN_ASSUME_ALIGNED(PTR, ALIGN_BYTES) /* do nothing */
#endif
#endif
} // end namespace internal
} // end namespace Eigen

View File

@@ -85,6 +85,29 @@ class QuaternionBase : public RotationBase<Derived, 3> {
return derived().coeffs();
}
/** \returns a vector containing the coefficients, rearranged into the order [\c w, \c x, \c y, \c z].
*
* This is the order expected by the \code Quaternion(const Scalar& w, const Scalar& x, const Scalar& y, const Scalar&
* z) \endcode constructor, but not the order of the internal vector representation. Therefore, it returns a newly
* constructed vector.
*
* \sa QuaternionBase::coeffsScalarLast()
* */
EIGEN_DEVICE_FUNC inline typename internal::traits<Derived>::Coefficients coeffsScalarFirst() const {
return derived().coeffsScalarFirst();
}
/** \returns a vector containing the coefficients in their original order [\c x, \c y, \c z, \c w].
*
* This is equivalent to \code coeffs() \endcode, but returns a newly constructed vector for uniformity with \code
* coeffsScalarFirst() \endcode.
*
* \sa QuaternionBase::coeffsScalarFirst()
* */
EIGEN_DEVICE_FUNC inline typename internal::traits<Derived>::Coefficients coeffsScalarLast() const {
return derived().coeffsScalarLast();
}
/** \returns a vector expression of the coefficients (x,y,z,w) */
EIGEN_DEVICE_FUNC inline typename internal::traits<Derived>::Coefficients& coeffs() { return derived().coeffs(); }
@@ -357,12 +380,23 @@ class Quaternion : public QuaternionBase<Quaternion<Scalar_, Options_> > {
EIGEN_DEVICE_FUNC static Quaternion UnitRandom();
EIGEN_DEVICE_FUNC static Quaternion FromCoeffsScalarLast(const Scalar& x, const Scalar& y, const Scalar& z,
const Scalar& w);
EIGEN_DEVICE_FUNC static Quaternion FromCoeffsScalarFirst(const Scalar& w, const Scalar& x, const Scalar& y,
const Scalar& z);
template <typename Derived1, typename Derived2>
EIGEN_DEVICE_FUNC static Quaternion FromTwoVectors(const MatrixBase<Derived1>& a, const MatrixBase<Derived2>& b);
EIGEN_DEVICE_FUNC inline Coefficients& coeffs() { return m_coeffs; }
EIGEN_DEVICE_FUNC inline const Coefficients& coeffs() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarLast() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarFirst() const {
return {m_coeffs.w(), m_coeffs.x(), m_coeffs.y(), m_coeffs.z()};
}
EIGEN_MAKE_ALIGNED_OPERATOR_NEW_IF(bool(NeedsAlignment))
#ifdef EIGEN_QUATERNION_PLUGIN
@@ -437,6 +471,12 @@ class Map<const Quaternion<Scalar_>, Options_> : public QuaternionBase<Map<const
EIGEN_DEVICE_FUNC inline const Coefficients& coeffs() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarLast() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarFirst() const {
return {m_coeffs.w(), m_coeffs.x(), m_coeffs.y(), m_coeffs.z()};
}
protected:
const Coefficients m_coeffs;
};
@@ -473,6 +513,12 @@ class Map<Quaternion<Scalar_>, Options_> : public QuaternionBase<Map<Quaternion<
EIGEN_DEVICE_FUNC inline Coefficients& coeffs() { return m_coeffs; }
EIGEN_DEVICE_FUNC inline const Coefficients& coeffs() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarLast() const { return m_coeffs; }
EIGEN_DEVICE_FUNC inline Coefficients coeffsScalarFirst() const {
return {m_coeffs.w(), m_coeffs.x(), m_coeffs.y(), m_coeffs.z()};
}
protected:
Coefficients m_coeffs;
};
@@ -694,6 +740,35 @@ EIGEN_DEVICE_FUNC Quaternion<Scalar, Options> Quaternion<Scalar, Options>::UnitR
return Quaternion(a * sin(u2), a * cos(u2), b * sin(u3), b * cos(u3));
}
/** Constructs a quaternion from its coefficients in the order [\c x, \c y, \c z, \c w], i.e. vector part [\c x, \c y,
* \c z] first, scalar part \a w LAST.
*
* This factory accepts the parameters in the same order as the underlying coefficient vector. Consider using this
* factory function to make the parameter ordering explicit.
*
* \param x,y,z the vector part of the quaternion
* \param w the scalar part of the quaternion
* \returns the quaternion built by forwarding the coefficients to the four-scalar constructor
*/
template <typename Scalar, int Options>
EIGEN_DEVICE_FUNC Quaternion<Scalar, Options> Quaternion<Scalar, Options>::FromCoeffsScalarLast(const Scalar& x,
const Scalar& y,
const Scalar& z,
const Scalar& w) {
// The four-scalar constructor is called with the scalar part first, so w moves to the front here.
return Quaternion(w, x, y, z);
}
/** Constructs a quaternion from its coefficients in the order [\c w, \c x, \c y, \c z], i.e. scalar part \a w FIRST,
* vector part [\c x, \c y, \c z] last.
*
* This factory accepts the parameters in the same order as the constructor \code Quaternion(const Scalar& w, const
* Scalar& x, const Scalar& y, const Scalar& z) \endcode. Consider using this factory function to make the parameter
* ordering explicit.
*
* \param w the scalar part of the quaternion
* \param x,y,z the vector part of the quaternion
* \returns the quaternion built by forwarding the coefficients, unchanged in order, to the four-scalar constructor
*/
template <typename Scalar, int Options>
EIGEN_DEVICE_FUNC Quaternion<Scalar, Options> Quaternion<Scalar, Options>::FromCoeffsScalarFirst(const Scalar& w,
const Scalar& x,
const Scalar& y,
const Scalar& z) {
// Parameter order already matches the constructor; this is a pure forwarding call.
return Quaternion(w, x, y, z);
}
/** Returns a quaternion representing a rotation between
* the two arbitrary vectors \a a and \a b. In other words, the built
* rotation represent a rotation sending the line of direction \a a

View File

@@ -406,7 +406,7 @@ class SimplicialLLT : public SimplicialCholeskyBase<SimplicialLLT<MatrixType_, U
return *this;
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
/** Performs a symbolic decomposition on the sparsity of \a matrix.
*
* This function is particularly useful when solving for several problems having the same structure.
*
@@ -416,7 +416,7 @@ class SimplicialLLT : public SimplicialCholeskyBase<SimplicialLLT<MatrixType_, U
/** Performs a numeric decomposition of \a matrix
*
* The given matrix must has the same sparcity than the matrix on which the symbolic decomposition has been performed.
* The given matrix must have the same sparsity as the matrix on which the symbolic decomposition has been performed.
*
* \sa analyzePattern()
*/
@@ -494,7 +494,7 @@ class SimplicialLDLT : public SimplicialCholeskyBase<SimplicialLDLT<MatrixType_,
return *this;
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
/** Performs a symbolic decomposition on the sparsity of \a matrix.
*
* This function is particularly useful when solving for several problems having the same structure.
*
@@ -504,7 +504,7 @@ class SimplicialLDLT : public SimplicialCholeskyBase<SimplicialLDLT<MatrixType_,
/** Performs a numeric decomposition of \a matrix
*
* The given matrix must has the same sparcity than the matrix on which the symbolic decomposition has been performed.
* The given matrix must have the same sparsity as the matrix on which the symbolic decomposition has been performed.
*
* \sa analyzePattern()
*/
@@ -575,7 +575,7 @@ class SimplicialNonHermitianLLT
return *this;
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
/** Performs a symbolic decomposition on the sparsity of \a matrix.
*
* This function is particularly useful when solving for several problems having the same structure.
*
@@ -585,7 +585,7 @@ class SimplicialNonHermitianLLT
/** Performs a numeric decomposition of \a matrix
*
* The given matrix must has the same sparcity than the matrix on which the symbolic decomposition has been performed.
* The given matrix must have the same sparsity as the matrix on which the symbolic decomposition has been performed.
*
* \sa analyzePattern()
*/
@@ -664,7 +664,7 @@ class SimplicialNonHermitianLDLT
return *this;
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
/** Performs a symbolic decomposition on the sparsity of \a matrix.
*
* This function is particularly useful when solving for several problems having the same structure.
*
@@ -674,7 +674,7 @@ class SimplicialNonHermitianLDLT
/** Performs a numeric decomposition of \a matrix
*
* The given matrix must has the same sparcity than the matrix on which the symbolic decomposition has been performed.
* The given matrix must have the same sparsity as the matrix on which the symbolic decomposition has been performed.
*
* \sa analyzePattern()
*/
@@ -742,7 +742,7 @@ class SimplicialCholesky : public SimplicialCholeskyBase<SimplicialCholesky<Matr
return *this;
}
/** Performs a symbolic decomposition on the sparcity of \a matrix.
/** Performs a symbolic decomposition on the sparsity of \a matrix.
*
* This function is particularly useful when solving for several problems having the same structure.
*
@@ -757,7 +757,7 @@ class SimplicialCholesky : public SimplicialCholeskyBase<SimplicialCholesky<Matr
/** Performs a numeric decomposition of \a matrix
*
* The given matrix must has the same sparcity than the matrix on which the symbolic decomposition has been performed.
* The given matrix must have the same sparsity as the matrix on which the symbolic decomposition has been performed.
*
* \sa analyzePattern()
*/

View File

@@ -274,6 +274,10 @@ struct simpl_chol_helper {
}
};
// Symbol is ODR-used, so we need a definition.
template <typename Scalar, typename StorageIndex>
constexpr StorageIndex simpl_chol_helper<Scalar, StorageIndex>::kEmpty;
} // namespace internal
template <typename Derived>

View File

@@ -36,10 +36,10 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
Scalar res1(0);
Scalar res2(0);
for (; i; ++i) {
res1 += numext::conj(i.value()) * other.coeff(i.index());
res1 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res1);
++i;
if (i) {
res2 += numext::conj(i.value()) * other.coeff(i.index());
res2 = numext::madd<Scalar>(numext::conj(i.value()), other.coeff(i.index()), res2);
}
}
return res1 + res2;
@@ -67,7 +67,7 @@ inline typename internal::traits<Derived>::Scalar SparseMatrixBase<Derived>::dot
Scalar res(0);
while (i && j) {
if (i.index() == j.index()) {
res += numext::conj(i.value()) * j.value();
res = numext::madd<Scalar>(numext::conj(i.value()), j.value(), res);
++i;
++j;
} else if (i.index() < j.index())

View File

@@ -1130,7 +1130,11 @@ void set_from_triplets(const InputIterator& begin, const InputIterator& end, Spa
using TransposedSparseMatrix =
SparseMatrix<typename SparseMatrixType::Scalar, IsRowMajor ? ColMajor : RowMajor, StorageIndex>;
if (begin == end) return;
if (begin == end) {
// Clear out existing data (if any).
mat.setZero();
return;
}
// There are two strategies to consider for constructing a matrix from unordered triplets:
// A) construct the 'mat' in its native storage order and sort in-place (less memory); or,

View File

@@ -224,33 +224,87 @@ class SparseMatrixBase : public EigenBase<Derived> {
public:
#ifndef EIGEN_NO_IO
friend std::ostream& operator<<(std::ostream& s, const SparseMatrixBase& m) {
typedef typename Derived::Nested Nested;
typedef internal::remove_all_t<Nested> NestedCleaned;
using Nested = typename Derived::Nested;
using NestedCleaned = typename internal::remove_all<Nested>::type;
/// Used to convert zeros to the matrix's numeric type
using Scalar = typename Derived::Scalar;
if (Flags & RowMajorBit) {
Nested nm(m.derived());
internal::evaluator<NestedCleaned> thisEval(nm);
// compute global width
std::size_t width = 0;
{
std::ostringstream ss0;
ss0.copyfmt(s);
ss0 << Scalar(0);
width = ss0.str().size();
for (Index row = 0; row < nm.outerSize(); ++row) {
for (typename internal::evaluator<NestedCleaned>::InnerIterator it(thisEval, row); it; ++it) {
std::ostringstream ss;
ss.copyfmt(s);
ss << it.value();
const std::size_t potential_width = ss.str().size();
if (potential_width > width) width = potential_width;
}
}
}
for (Index row = 0; row < nm.outerSize(); ++row) {
Index col = 0;
for (typename internal::evaluator<NestedCleaned>::InnerIterator it(thisEval, row); it; ++it) {
for (; col < it.index(); ++col) s << "0 ";
for (; col < it.index(); ++col) {
s.width(width);
s << Scalar(0) << " ";
}
s.width(width);
s << it.value() << " ";
++col;
}
for (; col < m.cols(); ++col) s << "0 ";
for (; col < m.cols(); ++col) {
s.width(width);
s << Scalar(0) << " ";
}
s << std::endl;
}
} else {
Nested nm(m.derived());
internal::evaluator<NestedCleaned> thisEval(nm);
if (m.cols() == 1) {
// compute local width (single col)
std::size_t width = 0;
{
std::ostringstream ss0;
ss0.copyfmt(s);
ss0 << Scalar(0);
width = ss0.str().size();
for (typename internal::evaluator<NestedCleaned>::InnerIterator it(thisEval, 0); it; ++it) {
std::ostringstream ss;
ss.copyfmt(s);
ss << it.value();
const std::size_t potential_width = ss.str().size();
if (potential_width > width) width = potential_width;
}
}
Index row = 0;
for (typename internal::evaluator<NestedCleaned>::InnerIterator it(thisEval, 0); it; ++it) {
for (; row < it.index(); ++row) s << "0" << std::endl;
for (; row < it.index(); ++row) {
s.width(width);
s << Scalar(0) << std::endl;
}
s.width(width);
s << it.value() << std::endl;
++row;
}
for (; row < m.rows(); ++row) s << "0" << std::endl;
for (; row < m.rows(); ++row) {
s.width(width);
s << Scalar(0) << std::endl;
}
} else {
SparseMatrix<Scalar, RowMajorBit, StorageIndex> trans = m;
s << static_cast<const SparseMatrixBase<SparseMatrix<Scalar, RowMajorBit, StorageIndex> >&>(trans);

View File

@@ -41,7 +41,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, RowMajor> {
lastVal = it.value();
lastIndex = it.index();
if (lastIndex == i) break;
tmp -= lastVal * other.coeff(lastIndex, col);
tmp = numext::madd<Scalar>(-lastVal, other.coeff(lastIndex, col), tmp);
}
if (Mode & UnitDiag)
other.coeffRef(i, col) = tmp;
@@ -75,7 +75,7 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, RowMajor> {
} else if (it && it.index() == i)
++it;
for (; it; ++it) {
tmp -= it.value() * other.coeff(it.index(), col);
tmp = numext::madd<Scalar>(-it.value(), other.coeff(it.index(), col), tmp);
}
if (Mode & UnitDiag)
@@ -107,7 +107,9 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Lower, ColMajor> {
tmp /= it.value();
}
if (it && it.index() == i) ++it;
for (; it; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
for (; it; ++it) {
other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
}
}
}
}
@@ -135,7 +137,9 @@ struct sparse_solve_triangular_selector<Lhs, Rhs, Mode, Upper, ColMajor> {
other.coeffRef(i, col) /= it.value();
}
LhsIterator it(lhsEval, i);
for (; it && it.index() < i; ++it) other.coeffRef(it.index(), col) -= tmp * it.value();
for (; it && it.index() < i; ++it) {
other.coeffRef(it.index(), col) = numext::madd<Scalar>(-tmp, it.value(), other.coeffRef(it.index(), col));
}
}
}
}
@@ -215,9 +219,13 @@ struct sparse_solve_triangular_sparse_selector<Lhs, Rhs, Mode, UpLo, ColMajor> {
tempVector.restart();
if (IsLower) {
if (it.index() == i) ++it;
for (; it; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
for (; it; ++it) {
tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
}
} else {
for (; it && it.index() < i; ++it) tempVector.coeffRef(it.index()) -= ci * it.value();
for (; it && it.index() < i; ++it) {
tempVector.coeffRef(it.index()) = numext::madd<Scalar>(-ci, it.value(), tempVector.coeffRef(it.index()));
}
}
}
}

View File

@@ -29,12 +29,12 @@ class LinearQuadraticRegulatorTest {
var qElms = VecBuilder.fill(0.02, 0.4);
var rElms = VecBuilder.fill(12.0);
var dt = 0.00505;
var dt = 0.005;
var K = new LinearQuadraticRegulator<>(plant, qElms, rElms, dt).getK();
assertEquals(522.153, K.get(0, 0), 0.1);
assertEquals(38.2, K.get(0, 1), 0.1);
assertEquals(522.87006795347486, K.get(0, 0), 1e-6);
assertEquals(38.239878385020411, K.get(0, 1), 1e-6);
}
@Test
@@ -65,12 +65,12 @@ class LinearQuadraticRegulatorTest {
var qElms = VecBuilder.fill(0.01745, 0.08726);
var rElms = VecBuilder.fill(12.0);
var dt = 0.00505;
var dt = 0.005;
var K = new LinearQuadraticRegulator<>(plant, qElms, rElms, dt).getK();
assertEquals(19.16, K.get(0, 0), 0.1);
assertEquals(3.32, K.get(0, 1), 0.1);
assertEquals(19.339349883583761, K.get(0, 0), 1e-6);
assertEquals(3.3542559517421582, K.get(0, 1), 1e-6);
}
/**

View File

@@ -22,7 +22,7 @@ import java.util.Random;
import org.junit.jupiter.api.Test;
class LinearSystemLoopTest {
private static final double kDt = 0.00505;
private static final double kDt = 0.005;
private static final double kPositionStddev = 0.0001;
private static final Random random = new Random();
@@ -45,12 +45,12 @@ class LinearSystemLoopTest {
(LinearSystem<N2, N1, N1>) m_plant.slice(0),
VecBuilder.fill(0.02, 0.4),
VecBuilder.fill(12.0),
0.00505);
0.005);
@SuppressWarnings("unchecked")
private final LinearSystemLoop<N2, N1, N1> m_loop =
new LinearSystemLoop<>(
(LinearSystem<N2, N1, N1>) m_plant.slice(0), m_controller, m_observer, 12, 0.00505);
(LinearSystem<N2, N1, N1>) m_plant.slice(0), m_controller, m_observer, 12, 0.005);
private static void updateTwoState(
LinearSystem<N2, N1, N1> plant, LinearSystemLoop<N2, N1, N1> loop, double noise) {

View File

@@ -67,7 +67,7 @@ class ExtendedKalmanFilterTest {
@Test
void testInit() {
double dt = 0.00505;
double dt = 0.005;
assertDoesNotThrow(
() -> {
@@ -97,7 +97,7 @@ class ExtendedKalmanFilterTest {
@Test
void testConvergence() {
double dt = 0.00505;
double dt = 0.005;
double rb = 0.8382 / 2.0; // Robot radius
ExtendedKalmanFilter<N5, N2, N3> observer =

View File

@@ -29,7 +29,7 @@ import org.junit.jupiter.api.Test;
class KalmanFilterTest {
private static LinearSystem<N2, N1, N2> elevatorPlant;
private static final double kDt = 0.00505;
private static final double kDt = 0.005;
static {
createElevator();

View File

@@ -0,0 +1,25 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.trajectory.struct;
import static org.junit.jupiter.api.Assertions.assertEquals;
import edu.wpi.first.math.trajectory.ExponentialProfile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.junit.jupiter.api.Test;
/** Round-trip serialization test for the {@code ExponentialProfile.State} struct. */
class ExponentialProfileStateStructTest {
  private static final ExponentialProfile.State STATE = new ExponentialProfile.State(4.0, 5.0);

  /** Packs STATE into a little-endian buffer and checks that unpacking yields an equal state. */
  @Test
  void testRoundtrip() {
    // Size the buffer exactly to what the struct reports.
    ByteBuffer bytes =
        ByteBuffer.allocate(ExponentialProfile.State.struct.getSize())
            .order(ByteOrder.LITTLE_ENDIAN);

    ExponentialProfile.State.struct.pack(bytes, STATE);
    // Move the position back to the start so unpack reads what was just written.
    bytes.rewind();

    assertEquals(STATE, ExponentialProfile.State.struct.unpack(bytes));
  }
}

View File

@@ -0,0 +1,25 @@
// Copyright (c) FIRST and other WPILib contributors.
// Open Source Software; you can modify and/or share it under the terms of
// the WPILib BSD license file in the root directory of this project.
package edu.wpi.first.math.trajectory.struct;
import static org.junit.jupiter.api.Assertions.assertEquals;
import edu.wpi.first.math.trajectory.TrapezoidProfile;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import org.junit.jupiter.api.Test;
/** Round-trip serialization test for the {@code TrapezoidProfile.State} struct. */
class TrapezoidProfileStateStructTest {
  private static final TrapezoidProfile.State STATE = new TrapezoidProfile.State(4.0, 5.0);

  /** Packs STATE into a little-endian buffer and checks that unpacking yields an equal state. */
  @Test
  void testRoundtrip() {
    // Size the buffer exactly to what the struct reports.
    ByteBuffer bytes =
        ByteBuffer.allocate(TrapezoidProfile.State.struct.getSize())
            .order(ByteOrder.LITTLE_ENDIAN);

    TrapezoidProfile.State.struct.pack(bytes, STATE);
    // Move the position back to the start so unpack reads what was just written.
    bytes.rewind();

    assertEquals(STATE, TrapezoidProfile.State.struct.unpack(bytes));
  }
}

View File

@@ -12,7 +12,7 @@ namespace frc {
TEST(DifferentialDriveAccelerationLimiterTest, LowLimits) {
constexpr auto trackwidth = 0.9_m;
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
constexpr auto maxA = 2_mps_sq;
constexpr auto maxAlpha = 2_rad_per_s_sq;
@@ -105,7 +105,7 @@ TEST(DifferentialDriveAccelerationLimiterTest, LowLimits) {
TEST(DifferentialDriveAccelerationLimiterTest, HighLimits) {
constexpr auto trackwidth = 0.9_m;
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
using Kv_t = decltype(1_V / 1_mps);
using Ka_t = decltype(1_V / 1_mps_sq);
@@ -173,7 +173,7 @@ TEST(DifferentialDriveAccelerationLimiterTest, HighLimits) {
TEST(DifferentialDriveAccelerationLimiterTest, SeparateMinMaxLowLimits) {
constexpr auto trackwidth = 0.9_m;
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
constexpr auto minA = -1_mps_sq;
constexpr auto maxA = 2_mps_sq;
constexpr auto maxAlpha = 2_rad_per_s_sq;

View File

@@ -20,7 +20,7 @@ TEST(DifferentialDriveFeedforwardTest, CalculateWithTrackwidth) {
constexpr auto kVAngular = 1_V / 1_rad_per_s;
constexpr auto kAAngular = 1_V / 1_rad_per_s_sq;
constexpr auto trackwidth = 1_m;
constexpr auto dt = 20_ms;
constexpr units::second_t dt = 20_ms;
frc::DifferentialDriveFeedforward differentialDriveFeedforward{
kVLinear, kALinear, kVAngular, kAAngular, trackwidth};
@@ -54,7 +54,7 @@ TEST(DifferentialDriveFeedforwardTest, CalculateWithoutTrackwidth) {
constexpr auto kALinear = 1_V / 1_mps_sq;
constexpr auto kVAngular = 1_V / 1_mps;
constexpr auto kAAngular = 1_V / 1_mps_sq;
constexpr auto dt = 20_ms;
constexpr units::second_t dt = 20_ms;
frc::DifferentialDriveFeedforward differentialDriveFeedforward{
kVLinear, kALinear, kVAngular, kAAngular};

View File

@@ -10,7 +10,7 @@
namespace frc {
TEST(ImplicitModelFollowerTest, SameModel) {
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
using Kv_t = decltype(1_V / 1_mps);
using Ka_t = decltype(1_V / 1_mps_sq);
@@ -54,7 +54,7 @@ TEST(ImplicitModelFollowerTest, SameModel) {
}
TEST(ImplicitModelFollowerTest, SlowerRefModel) {
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
using Kv_t = decltype(1_V / 1_mps);
using Ka_t = decltype(1_V / 1_mps_sq);

View File

@@ -31,10 +31,10 @@ TEST(LinearQuadraticRegulatorTest, ElevatorGains) {
return frc::LinearSystemId::ElevatorSystem(motors, m, r, G).Slice(0);
}();
Matrixd<1, 2> K =
LinearQuadraticRegulator<2, 1>{plant, {0.02, 0.4}, {12.0}, 5.05_ms}.K();
LinearQuadraticRegulator<2, 1>{plant, {0.02, 0.4}, {12.0}, 5_ms}.K();
EXPECT_NEAR(522.15314269, K(0, 0), 1e-6);
EXPECT_NEAR(38.20138596, K(0, 1), 1e-6);
EXPECT_NEAR(522.87006795347486, K(0, 0), 1e-6);
EXPECT_NEAR(38.239878385020411, K(0, 1), 1e-6);
}
TEST(LinearQuadraticRegulatorTest, ArmGains) {
@@ -56,11 +56,11 @@ TEST(LinearQuadraticRegulatorTest, ArmGains) {
}();
Matrixd<1, 2> K =
LinearQuadraticRegulator<2, 1>{plant, {0.01745, 0.08726}, {12.0}, 5.05_ms}
LinearQuadraticRegulator<2, 1>{plant, {0.01745, 0.08726}, {12.0}, 5_ms}
.K();
EXPECT_NEAR(19.16, K(0, 0), 1e-1);
EXPECT_NEAR(3.32, K(0, 1), 1e-1);
EXPECT_NEAR(19.339349883583761, K(0, 0), 1e-6);
EXPECT_NEAR(3.3542559517421582, K(0, 1), 1e-6);
}
TEST(LinearQuadraticRegulatorTest, FourMotorElevator) {

View File

@@ -61,7 +61,7 @@ frc::Vectord<5> GlobalMeasurementModel(
} // namespace
TEST(ExtendedKalmanFilterTest, Init) {
constexpr auto dt = 0.00505_s;
constexpr units::second_t dt = 5_ms;
frc::ExtendedKalmanFilter<5, 2, 3> observer{Dynamics,
LocalMeasurementModel,
@@ -80,7 +80,7 @@ TEST(ExtendedKalmanFilterTest, Init) {
}
TEST(ExtendedKalmanFilterTest, Convergence) {
constexpr auto dt = 0.00505_s;
constexpr units::second_t dt = 5_ms;
constexpr auto rb = 0.8382_m / 2.0; // Robot radius
frc::ExtendedKalmanFilter<5, 2, 3> observer{Dynamics,

View File

@@ -68,7 +68,7 @@ frc::Vectord<5> DriveGlobalMeasurementModel(
}
TEST(MerweUKFTest, DriveInit) {
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
frc::MerweUKF<5, 2, 3> observer{DriveDynamics,
DriveLocalMeasurementModel,
@@ -94,7 +94,7 @@ TEST(MerweUKFTest, DriveInit) {
}
TEST(MerweUKFTest, DriveConvergence) {
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
constexpr auto rb = 0.8382_m / 2.0; // Robot radius
frc::MerweUKF<5, 2, 3> observer{DriveDynamics,
@@ -206,7 +206,7 @@ TEST(MerweUKFTest, LinearUKF) {
}
TEST(MerweUKFTest, RoundTripP) {
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
frc::MerweUKF<2, 2, 2> observer{
[](const frc::Vectord<2>& x, const frc::Vectord<2>& u) { return x; },

View File

@@ -56,7 +56,7 @@ TEST(DiscretizationTest, DiscretizeSlowModelAQ) {
frc::Matrixd<2, 2> contA{{0, 1}, {0, 0}};
frc::Matrixd<2, 2> contQ{{1, 0}, {0, 1}};
constexpr auto dt = 1_s;
constexpr units::second_t dt = 1_s;
// T
// Q_d ≈ ∫ e^(Aτ) Q e^(Aᵀτ) dτ
@@ -88,7 +88,7 @@ TEST(DiscretizationTest, DiscretizeFastModelAQ) {
frc::Matrixd<2, 2> contA{{0, 1}, {0, -1406.29}};
frc::Matrixd<2, 2> contQ{{0.0025, 0}, {0, 1}};
constexpr auto dt = 5_ms;
constexpr units::second_t dt = 5_ms;
// T
// Q_d = ∫ e^(Aτ) Q e^(Aᵀτ) dτ

View File

@@ -12,23 +12,23 @@ file(GLOB wpiutil_jni_src src/main/native/cpp/jni/WPIUtilJNI.cpp)
if(WITH_JAVA)
include(UseJava)
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/jackson/jackson-core-2.15.2.jar")
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/jackson/jackson-core-2.19.2.jar")
set(BASE_URL "https://search.maven.org/remotecontent?filepath=")
set(JAR_ROOT "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/jackson")
message(STATUS "Downloading Jackson jarfiles...")
download_and_check(
"${BASE_URL}com/fasterxml/jackson/core/jackson-core/2.15.2/jackson-core-2.15.2.jar"
"${JAR_ROOT}/jackson-core-2.15.2.jar"
"${BASE_URL}com/fasterxml/jackson/core/jackson-core/2.19.2/jackson-core-2.19.2.jar"
"${JAR_ROOT}/jackson-core-2.19.2.jar"
)
download_and_check(
"${BASE_URL}com/fasterxml/jackson/core/jackson-databind/2.15.2/jackson-databind-2.15.2.jar"
"${JAR_ROOT}/jackson-databind-2.15.2.jar"
"${BASE_URL}com/fasterxml/jackson/core/jackson-databind/2.19.2/jackson-databind-2.19.2.jar"
"${JAR_ROOT}/jackson-databind-2.19.2.jar"
)
download_and_check(
"${BASE_URL}com/fasterxml/jackson/core/jackson-annotations/2.15.2/jackson-annotations-2.15.2.jar"
"${JAR_ROOT}/jackson-annotations-2.15.2.jar"
"${BASE_URL}com/fasterxml/jackson/core/jackson-annotations/2.19.2/jackson-annotations-2.19.2.jar"
"${JAR_ROOT}/jackson-annotations-2.19.2.jar"
)
message(STATUS "All files downloaded.")
@@ -36,14 +36,14 @@ if(WITH_JAVA)
file(GLOB JACKSON_JARS ${WPILIB_BINARY_DIR}/wpiutil/thirdparty/jackson/*.jar)
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/quickbuf/quickbuf-runtime-1.3.3.jar")
if(NOT EXISTS "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/quickbuf/quickbuf-runtime-1.4.jar")
set(BASE_URL "https://search.maven.org/remotecontent?filepath=")
set(JAR_ROOT "${WPILIB_BINARY_DIR}/wpiutil/thirdparty/quickbuf")
message(STATUS "Downloading Quickbuf jarfile...")
download_and_check(
"${BASE_URL}us/hebi/quickbuf/quickbuf-runtime/1.3.3/quickbuf-runtime-1.3.3.jar"
"${JAR_ROOT}/quickbuf-runtime-1.3.3.jar"
"${BASE_URL}us/hebi/quickbuf/quickbuf-runtime/1.4/quickbuf-runtime-1.4.jar"
"${JAR_ROOT}/quickbuf-runtime-1.4.jar"
)
message(STATUS "Downloaded.")

View File

@@ -275,8 +275,8 @@ model {
}
dependencies {
api "com.fasterxml.jackson.core:jackson-annotations:2.15.2"
api "com.fasterxml.jackson.core:jackson-core:2.15.2"
api "com.fasterxml.jackson.core:jackson-databind:2.15.2"
api 'us.hebi.quickbuf:quickbuf-runtime:1.3.3'
api "com.fasterxml.jackson.core:jackson-annotations:2.19.2"
api "com.fasterxml.jackson.core:jackson-core:2.19.2"
api "com.fasterxml.jackson.core:jackson-databind:2.19.2"
api 'us.hebi.quickbuf:quickbuf-runtime:1.4'
}