Upload a new object detection model to this device that can be used in a pipeline. Naming convention
- should be name-verticalResolution-horizontalResolution-modelType. Additionally, the labels
- file ought to have the same name as the RKNN file, with -labels appended to the end. For
- example, if the RKNN file is named note-640-640-yolov5s.rknn, the labels file should be
- named note-640-640-yolov5s-labels.txt. Note that ONLY 640x640 YOLOv5 & YOLOv8 models
- trained and converted to `.rknn` format for RK3588 CPUs are currently supported!
+ should be name-verticalResolution-horizontalResolution-modelType. The
+ name should only include alphanumeric characters, periods, and underscores. Additionally,
+ the labels file ought to have the same name as the RKNN file, with -labels appended to the
+ end. For example, if the RKNN file is named note-640-640-yolov5s.rknn, the labels file
+ should be named note-640-640-yolov5s-labels.txt. Note that ONLY 640x640 YOLOv5 & YOLOv8
+ models trained and converted to `.rknn` format for RK3588 CPUs are currently supported!
diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
index 32d082241..966098977 100644
--- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
+++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java
@@ -33,6 +33,8 @@ import java.util.Map;
import java.util.Optional;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
import org.photonvision.common.hardware.Platform;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
@@ -301,4 +303,66 @@ public class NeuralNetworkModelManager {
logger.error("Error extracting models", e);
}
}
+
+ private static Pattern modelPattern =
+ Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)\\.rknn$");
+
+ private static Pattern labelsPattern =
+ Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\\.txt$");
+
+ /**
+ * Check naming conventions for models and labels.
+ *
+ * This is static as it is not dependent on the state of the class.
+ *
+ * @param modelName the name of the model
+ * @param labelsName the name of the labels file
+ * @throws IllegalArgumentException if the names are invalid
+ */
+ public static void verifyRKNNNames(String modelName, String labelsName) {
+ // check null
+ if (modelName == null || labelsName == null) {
+ throw new IllegalArgumentException("Model name and labels name cannot be null");
+ }
+
+ // These patterns check that the naming convention of
+ // name-widthResolution-heightResolution-modelType is followed
+
+ Matcher modelMatcher = modelPattern.matcher(modelName);
+ Matcher labelsMatcher = labelsPattern.matcher(labelsName);
+
+ if (!modelMatcher.matches() || !labelsMatcher.matches()) {
+ throw new IllegalArgumentException(
+ "Model name and labels name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn and name-widthResolution-heightResolution-modelType-labels.txt");
+ }
+
+ if (!modelMatcher.group(1).equals(labelsMatcher.group(1))
+ || !modelMatcher.group(2).equals(labelsMatcher.group(2))
+ || !modelMatcher.group(3).equals(labelsMatcher.group(3))
+ || !modelMatcher.group(4).equals(labelsMatcher.group(4))) {
+ throw new IllegalArgumentException("Model name and labels name must be matching.");
+ }
+ }
+
+ /**
+ * Parse RKNN name and return the name, width, height, and model type.
+ *
+ *
This is static as it is not dependent on the state of the class.
+ *
+ * @param modelName the name of the model
+ * @throws IllegalArgumentException if the model name does not follow the naming convention
+ * @return an array containing the name, width, height, and model type
+ */
+ public static String[] parseRKNNName(String modelName) {
+ Matcher modelMatcher = modelPattern.matcher(modelName);
+
+ if (!modelMatcher.matches()) {
+ throw new IllegalArgumentException(
+ "Model name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn");
+ }
+
+ return new String[] {
+ modelMatcher.group(1), modelMatcher.group(2), modelMatcher.group(3), modelMatcher.group(4)
+ };
+ }
}
diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java
index 0703e3625..038c6695a 100644
--- a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java
+++ b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java
@@ -23,6 +23,7 @@ import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import org.opencv.core.Size;
+import org.photonvision.common.configuration.NeuralNetworkModelManager;
import org.photonvision.jni.RknnObjectDetector;
import org.photonvision.rknn.RknnJNI;
@@ -67,10 +68,8 @@ public class RknnModel implements Model {
public RknnModel(File modelFile, String labels) throws IllegalArgumentException, IOException {
this.modelFile = modelFile;
- String[] parts = modelFile.getName().split("-");
- if (parts.length != 4) {
- throw new IllegalArgumentException("Invalid model file name: " + modelFile);
- }
+ // parseRKNNName throws an IllegalArgumentException if the model name is invalid
+ String[] parts = NeuralNetworkModelManager.parseRKNNName(modelFile.getName());
this.version = getModelVersion(parts[3]);
diff --git a/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java b/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java
new file mode 100644
index 000000000..c389e87d6
--- /dev/null
+++ b/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright (C) Photon Vision.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package org.photonvision.vision.pipeline;
+
+import static org.junit.jupiter.api.Assertions.assertArrayEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.util.LinkedList;
+import java.util.stream.Stream;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
+import org.photonvision.common.configuration.NeuralNetworkModelManager;
+
+public class ObjectDetectionTest {
+ private static LinkedList passNames =
+ new LinkedList(
+ java.util.Arrays.asList(
+ new String[] {"note-640-640-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"},
+ new String[] {"object-640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"},
+ new String[] {
+ "example_1.2-640-640-yolov5l.rknn", "example_1.2-640-640-yolov5l-labels.txt"
+ },
+ new String[] {"demo_3.5-640-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"},
+ new String[] {"sample-640-640-yolov5x.rknn", "sample-640-640-yolov5x-labels.txt"},
+ new String[] {
+ "test_case-640-640-yolov8s.rknn", "test_case-640-640-yolov8s-labels.txt"
+ },
+ new String[] {
+ "model_ABC-640-640-yolov5n.rknn", "model_ABC-640-640-yolov5n-labels.txt"
+ },
+ new String[] {"my_model-640-640-yolov8x.rknn", "my_model-640-640-yolov8x-labels.txt"},
+ new String[] {"name_1.0-640-640-yolov5n.rknn", "name_1.0-640-640-yolov5n-labels.txt"},
+ new String[] {
+ "valid_name-640-640-yolov8s.rknn", "valid_name-640-640-yolov8s-labels.txt"
+ },
+ new String[] {
+ "test.model-640-640-yolov5l.rknn", "test.model-640-640-yolov5l-labels.txt"
+ },
+ new String[] {
+ "case1_test-640-640-yolov8m.rknn", "case1_test-640-640-yolov8m-labels.txt"
+ },
+ new String[] {"A123-640-640-yolov5x.rknn", "A123-640-640-yolov5x-labels.txt"},
+ new String[] {
+ "z_y_test.model-640-640-yolov8n.rknn", "z_y_test.model-640-640-yolov8n-labels.txt"
+ }));
+ private static LinkedList parsedPassNames =
+ new LinkedList(
+ java.util.Arrays.asList(
+ new String[] {"note", "640", "640", "yolov5s"},
+ new String[] {"object", "640", "640", "yolov8n"},
+ new String[] {"example_1.2", "640", "640", "yolov5l"},
+ new String[] {"demo_3.5", "640", "640", "yolov8m"},
+ new String[] {"sample", "640", "640", "yolov5x"},
+ new String[] {"test_case", "640", "640", "yolov8s"},
+ new String[] {"model_ABC", "640", "640", "yolov5n"},
+ new String[] {"my_model", "640", "640", "yolov8x"},
+ new String[] {"name_1.0", "640", "640", "yolov5n"},
+ new String[] {"valid_name", "640", "640", "yolov8s"},
+ new String[] {"test.model", "640", "640", "yolov5l"},
+ new String[] {"case1_test", "640", "640", "yolov8m"},
+ new String[] {"A123", "640", "640", "yolov5x"},
+ new String[] {"z_y_test.model", "640", "640", "yolov8n"}));
+ private static LinkedList failNames =
+ new LinkedList(
+ java.util.Arrays.asList(
+ new String[] {"note-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"},
+ new String[] {"640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"},
+ new String[] {"example_1.2.rknn", "example_1.2-640-640-yolov5l-labels.txt"},
+ new String[] {"demo_3.5-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"},
+ new String[] {"sample-640.rknn", "sample-640-640-yolov5x-labels.txt"},
+ new String[] {"test_case.txt", "test_case-640-640-yolov8s-labels.txt"},
+ new String[] {"model_ABC.onnx", "model_ABC-640-640-yolov5n-labels.txt"},
+ new String[] {"my_model", "my_model-640-640-yolov8x-labels.txt"},
+ new String[] {"name_1.0-yolov5n.rknn", "wrong-labels.txt"},
+ new String[] {"", "valid_name-640-640-yolov8s-labels.txt"},
+ new String[] {null, "test.model-640-640-yolov5l-labels.txt"},
+ new String[] {"case1_test-640-640-yolov8m.rknn", null},
+ new String[] {"A123-640-640.rknn", "different-labels.txt"},
+ new String[] {"z_y_test.model", ""}));
+
+ // Test the model name validation for names that ought to pass
+ @ParameterizedTest
+ @MethodSource("verifyPassNameProvider")
+ public void testRKNNVerificationPass(String[] names) {
+ NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1]);
+ }
+
+ // // Test the model name validation for names that ought to fail
+ @ParameterizedTest
+ @MethodSource("verifyFailNameProvider")
+ public void testRNNVerificationFail(String[] names) {
+ assertThrows(
+ IllegalArgumentException.class,
+ () -> NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1]));
+ }
+
+ // Test the model name parsing
+ @ParameterizedTest
+ @MethodSource("parseNameProvider")
+ public void testRKNNNameParsing(String[] expected, String name) {
+ String[] parsed = NeuralNetworkModelManager.parseRKNNName(name);
+ assertArrayEquals(expected, parsed);
+ }
+
+ static Stream verifyPassNameProvider() {
+ return passNames.stream().map(array -> Arguments.of((Object) array));
+ }
+
+ static Stream verifyFailNameProvider() {
+ return failNames.stream().map(array -> Arguments.of((Object) array));
+ }
+
+ static Stream parseNameProvider() {
+ // return a stream of parsed pass names, and the first element of each pass name
+ return passNames.stream()
+ .map(name -> Arguments.of(parsedPassNames.get(passNames.indexOf(name)), name[0]));
+ }
+}
diff --git a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java
index afe8518c9..6becc0000 100644
--- a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java
+++ b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java
@@ -29,7 +29,6 @@ import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Optional;
-import java.util.regex.Pattern;
import javax.imageio.ImageIO;
import org.apache.commons.io.FileUtils;
import org.opencv.core.Mat;
@@ -573,25 +572,9 @@ public class RequestHandler {
}
// verify naming convention
- // this check will need to be modified if different model types are added
- Pattern modelPattern =
- Pattern.compile("^[a-zA-Z0-9]+-\\d+-\\d+-yolov(?:5|8|11)[a-z]*\\.rknn$");
-
- Pattern labelsPattern =
- Pattern.compile("^[a-zA-Z0-9]+-\\d+-\\d+-yolov(?:5|8|11)[a-z]*-labels\\.txt$");
-
- if (!modelPattern.matcher(modelFile.filename()).matches()
- || !labelsPattern.matcher(labelsFile.filename()).matches()
- || !(modelFile
- .filename()
- .substring(0, modelFile.filename().indexOf("-"))
- .equals(labelsFile.filename().substring(0, labelsFile.filename().indexOf("-"))))) {
- ctx.status(400);
- ctx.result("The uploaded files were not named correctly.");
- logger.error("The uploaded object detection model files were not named correctly.");
- return;
- }
+ // throws IllegalArgumentException if the model name is invalid
+ NeuralNetworkModelManager.verifyRKNNNames(modelFile.filename(), labelsFile.filename());
// TODO move into neural network manager