Update Allowed Naming Conventions For Object Detection Models (#1749)

This commit is contained in:
Sam Freund
2025-02-09 09:12:47 -06:00
committed by GitHub
parent 7067c75525
commit 00fb4bdf07
6 changed files with 218 additions and 36 deletions

View File

@@ -33,6 +33,8 @@ import java.util.Map;
import java.util.Optional;
import java.util.jar.JarEntry;
import java.util.jar.JarFile;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.photonvision.common.hardware.Platform;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
@@ -301,4 +303,66 @@ public class NeuralNetworkModelManager {
logger.error("Error extracting models", e);
}
}
private static Pattern modelPattern =
Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)\\.rknn$");
private static Pattern labelsPattern =
Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\\.txt$");
/**
* Check naming conventions for models and labels.
*
* <p>This is static as it is not dependent on the state of the class.
*
* @param modelName the name of the model
* @param labelsName the name of the labels file
* @throws IllegalArgumentException if the names are invalid
*/
public static void verifyRKNNNames(String modelName, String labelsName) {
// check null
if (modelName == null || labelsName == null) {
throw new IllegalArgumentException("Model name and labels name cannot be null");
}
// These patterns check that the naming convention of
// name-widthResolution-heightResolution-modelType is followed
Matcher modelMatcher = modelPattern.matcher(modelName);
Matcher labelsMatcher = labelsPattern.matcher(labelsName);
if (!modelMatcher.matches() || !labelsMatcher.matches()) {
throw new IllegalArgumentException(
"Model name and labels name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn and name-widthResolution-heightResolution-modelType-labels.txt");
}
if (!modelMatcher.group(1).equals(labelsMatcher.group(1))
|| !modelMatcher.group(2).equals(labelsMatcher.group(2))
|| !modelMatcher.group(3).equals(labelsMatcher.group(3))
|| !modelMatcher.group(4).equals(labelsMatcher.group(4))) {
throw new IllegalArgumentException("Model name and labels name must be matching.");
}
}
/**
* Parse RKNN name and return the name, width, height, and model type.
*
* <p>This is static as it is not dependent on the state of the class.
*
* @param modelName the name of the model
* @throws IllegalArgumentException if the model name does not follow the naming convention
* @return an array containing the name, width, height, and model type
*/
public static String[] parseRKNNName(String modelName) {
Matcher modelMatcher = modelPattern.matcher(modelName);
if (!modelMatcher.matches()) {
throw new IllegalArgumentException(
"Model name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn");
}
return new String[] {
modelMatcher.group(1), modelMatcher.group(2), modelMatcher.group(3), modelMatcher.group(4)
};
}
}

View File

@@ -23,6 +23,7 @@ import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.List;
import org.opencv.core.Size;
import org.photonvision.common.configuration.NeuralNetworkModelManager;
import org.photonvision.jni.RknnObjectDetector;
import org.photonvision.rknn.RknnJNI;
@@ -67,10 +68,8 @@ public class RknnModel implements Model {
public RknnModel(File modelFile, String labels) throws IllegalArgumentException, IOException {
this.modelFile = modelFile;
String[] parts = modelFile.getName().split("-");
if (parts.length != 4) {
throw new IllegalArgumentException("Invalid model file name: " + modelFile);
}
// parseRKNNName throws an IllegalArgumentException if the model name is invalid
String[] parts = NeuralNetworkModelManager.parseRKNNName(modelFile.getName());
this.version = getModelVersion(parts[3]);

View File

@@ -0,0 +1,134 @@
/*
* Copyright (C) Photon Vision.
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
package org.photonvision.vision.pipeline;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import java.util.LinkedList;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;
import org.photonvision.common.configuration.NeuralNetworkModelManager;
public class ObjectDetectionTest {
private static LinkedList<String[]> passNames =
new LinkedList<String[]>(
java.util.Arrays.asList(
new String[] {"note-640-640-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"},
new String[] {"object-640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"},
new String[] {
"example_1.2-640-640-yolov5l.rknn", "example_1.2-640-640-yolov5l-labels.txt"
},
new String[] {"demo_3.5-640-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"},
new String[] {"sample-640-640-yolov5x.rknn", "sample-640-640-yolov5x-labels.txt"},
new String[] {
"test_case-640-640-yolov8s.rknn", "test_case-640-640-yolov8s-labels.txt"
},
new String[] {
"model_ABC-640-640-yolov5n.rknn", "model_ABC-640-640-yolov5n-labels.txt"
},
new String[] {"my_model-640-640-yolov8x.rknn", "my_model-640-640-yolov8x-labels.txt"},
new String[] {"name_1.0-640-640-yolov5n.rknn", "name_1.0-640-640-yolov5n-labels.txt"},
new String[] {
"valid_name-640-640-yolov8s.rknn", "valid_name-640-640-yolov8s-labels.txt"
},
new String[] {
"test.model-640-640-yolov5l.rknn", "test.model-640-640-yolov5l-labels.txt"
},
new String[] {
"case1_test-640-640-yolov8m.rknn", "case1_test-640-640-yolov8m-labels.txt"
},
new String[] {"A123-640-640-yolov5x.rknn", "A123-640-640-yolov5x-labels.txt"},
new String[] {
"z_y_test.model-640-640-yolov8n.rknn", "z_y_test.model-640-640-yolov8n-labels.txt"
}));
private static LinkedList<String[]> parsedPassNames =
new LinkedList<String[]>(
java.util.Arrays.asList(
new String[] {"note", "640", "640", "yolov5s"},
new String[] {"object", "640", "640", "yolov8n"},
new String[] {"example_1.2", "640", "640", "yolov5l"},
new String[] {"demo_3.5", "640", "640", "yolov8m"},
new String[] {"sample", "640", "640", "yolov5x"},
new String[] {"test_case", "640", "640", "yolov8s"},
new String[] {"model_ABC", "640", "640", "yolov5n"},
new String[] {"my_model", "640", "640", "yolov8x"},
new String[] {"name_1.0", "640", "640", "yolov5n"},
new String[] {"valid_name", "640", "640", "yolov8s"},
new String[] {"test.model", "640", "640", "yolov5l"},
new String[] {"case1_test", "640", "640", "yolov8m"},
new String[] {"A123", "640", "640", "yolov5x"},
new String[] {"z_y_test.model", "640", "640", "yolov8n"}));
private static LinkedList<String[]> failNames =
new LinkedList<String[]>(
java.util.Arrays.asList(
new String[] {"note-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"},
new String[] {"640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"},
new String[] {"example_1.2.rknn", "example_1.2-640-640-yolov5l-labels.txt"},
new String[] {"demo_3.5-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"},
new String[] {"sample-640.rknn", "sample-640-640-yolov5x-labels.txt"},
new String[] {"test_case.txt", "test_case-640-640-yolov8s-labels.txt"},
new String[] {"model_ABC.onnx", "model_ABC-640-640-yolov5n-labels.txt"},
new String[] {"my_model", "my_model-640-640-yolov8x-labels.txt"},
new String[] {"name_1.0-yolov5n.rknn", "wrong-labels.txt"},
new String[] {"", "valid_name-640-640-yolov8s-labels.txt"},
new String[] {null, "test.model-640-640-yolov5l-labels.txt"},
new String[] {"case1_test-640-640-yolov8m.rknn", null},
new String[] {"A123-640-640.rknn", "different-labels.txt"},
new String[] {"z_y_test.model", ""}));
// Test the model name validation for names that ought to pass
@ParameterizedTest
@MethodSource("verifyPassNameProvider")
public void testRKNNVerificationPass(String[] names) {
NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1]);
}
// // Test the model name validation for names that ought to fail
@ParameterizedTest
@MethodSource("verifyFailNameProvider")
public void testRNNVerificationFail(String[] names) {
assertThrows(
IllegalArgumentException.class,
() -> NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1]));
}
// Test the model name parsing
@ParameterizedTest
@MethodSource("parseNameProvider")
public void testRKNNNameParsing(String[] expected, String name) {
String[] parsed = NeuralNetworkModelManager.parseRKNNName(name);
assertArrayEquals(expected, parsed);
}
static Stream<Arguments> verifyPassNameProvider() {
return passNames.stream().map(array -> Arguments.of((Object) array));
}
static Stream<Arguments> verifyFailNameProvider() {
return failNames.stream().map(array -> Arguments.of((Object) array));
}
static Stream<Arguments> parseNameProvider() {
// return a stream of parsed pass names, and the first element of each pass name
return passNames.stream()
.map(name -> Arguments.of(parsedPassNames.get(passNames.indexOf(name)), name[0]));
}
}