Refactor Rubik OD code to generic TFLite OD code (#2516)

This commit is contained in:
Jade
2026-06-30 01:02:47 +08:00
committed by GitHub
parent bd8fa28ab7
commit 8f560e5b1f
7 changed files with 37 additions and 29 deletions

View File

@@ -38,9 +38,10 @@ import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelPr
import org.photonvision.common.hardware.Platform;
import org.photonvision.common.logging.LogGroup;
import org.photonvision.common.logging.Logger;
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
import org.photonvision.vision.objects.Model;
import org.photonvision.vision.objects.RknnModel;
import org.photonvision.vision.objects.RubikModel;
import org.photonvision.vision.objects.TFLiteModel;
/**
* Manages the loading of neural network models.
@@ -360,7 +361,7 @@ public class NeuralNetworkModelManager {
models.get(properties.family()).add(new RknnModel(properties));
}
case RUBIK -> {
models.get(properties.family()).add(new RubikModel(properties));
models.get(properties.family()).add(new TFLiteModel(properties, TFLiteSource.RUBIK));
}
}
logger.info(

View File

@@ -19,7 +19,6 @@ package org.photonvision.vision.objects;
import java.io.File;
import java.nio.file.Path;
import org.opencv.core.Size;
import org.photonvision.common.configuration.NeuralNetworkModelManager.Family;
import org.photonvision.common.configuration.NeuralNetworkModelManager.Version;
import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties;
@@ -79,8 +78,7 @@ public class RknnModel implements Model {
}
public ObjectDetector load() {
return new RknnObjectDetector(
this, new Size(properties.resolutionWidth(), properties.resolutionHeight()));
return new RknnObjectDetector(this);
}
public String toString() {

View File

@@ -61,9 +61,10 @@ public class RknnObjectDetector implements ObjectDetector {
* @param inputSize The required image dimensions for the model. Images will be {@link
* Letterbox}ed to this shape.
*/
public RknnObjectDetector(RknnModel model, Size inputSize) {
public RknnObjectDetector(RknnModel model) {
this.model = model;
this.inputSize = inputSize;
this.inputSize =
new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight());
// Create the detector
objPointer =

View File

@@ -19,22 +19,26 @@ package org.photonvision.vision.objects;
import java.io.File;
import java.nio.file.Path;
import org.opencv.core.Size;
import org.photonvision.common.configuration.NeuralNetworkModelManager.Family;
import org.photonvision.common.configuration.NeuralNetworkModelManager.Version;
import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties;
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
public class RubikModel implements Model {
public class TFLiteModel implements Model {
public final File modelFile;
public final ModelProperties properties;
public final TFLiteSource backend;
/**
* Rubik model constructor.
* TFLite model constructor.
*
* @param properties The properties of the model.
* @param backend The backend of the model should run on.
* @throws IllegalArgumentException
*/
public RubikModel(ModelProperties properties) throws IllegalArgumentException {
public TFLiteModel(ModelProperties properties, TFLiteSource backend)
throws IllegalArgumentException {
this.backend = backend;
modelFile = new File(properties.modelPath().toString());
if (!modelFile.exists()) {
throw new IllegalArgumentException("Model file does not exist: " + modelFile);
@@ -77,11 +81,10 @@ public class RubikModel implements Model {
}
public ObjectDetector load() {
return new RubikObjectDetector(
this, new Size(this.properties.resolutionWidth(), this.properties.resolutionHeight()));
return new TFLiteObjectDetector(this, backend);
}
public String toString() {
return "RubikModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}';
return "TFLiteModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}';
}
}

View File

@@ -30,9 +30,9 @@ import org.photonvision.tflite.TFLiteJNI;
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult;
/** Manages an object detector using the rubik backend. */
public class RubikObjectDetector implements ObjectDetector {
private static final Logger logger = new Logger(RubikObjectDetector.class, LogGroup.General);
/** Manages an object detector using the TFLite backend. */
public class TFLiteObjectDetector implements ObjectDetector {
private static final Logger logger = new Logger(TFLiteObjectDetector.class, LogGroup.General);
private static final Cleaner cleaner = Cleaner.create();
@@ -45,26 +45,26 @@ public class RubikObjectDetector implements ObjectDetector {
/** Pointer to the native object */
private final long ptr;
private final RubikModel model;
private final TFLiteModel model;
private final Size inputSize;
/** Returns the model in use by this detector. */
@Override
public RubikModel getModel() {
public TFLiteModel getModel() {
return model;
}
/**
* Creates a new rubikObjectDetector from the given model.
* Creates a new TFLite detector from the given model.
*
* @param model The model to create the detector from.
* @param inputSize The required image dimensions for the model. Images will be {@link
* Letterbox}ed to this shape.
* @param source The backend to run the detector on.
*/
public RubikObjectDetector(RubikModel model, Size inputSize) {
public TFLiteObjectDetector(TFLiteModel model, TFLiteSource backend) {
this.model = model;
this.inputSize = inputSize;
this.inputSize =
new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight());
// Create the detector
try {
@@ -72,7 +72,7 @@ public class RubikObjectDetector implements ObjectDetector {
TFLiteJNI.create(
model.modelFile.getPath().toString(),
model.properties.version().ordinal(),
TFLiteSource.RUBIK.value());
backend.value());
} catch (Exception e) {
logger.error("Failed to create detector from path " + model.modelFile.getPath(), e);
throw new RuntimeException(
@@ -83,7 +83,7 @@ public class RubikObjectDetector implements ObjectDetector {
logger.error(
"Failed to create detector from path "
+ model.modelFile.getPath()
+ ". Please ensure the model is valid and compatible with the Rubik backend.");
+ ". Please ensure the model is valid and compatible with the TFLite backend.");
throw new RuntimeException(
"Failed to create detector from path " + model.modelFile.getPath());
} else if (!TFLiteJNI.isQuantized(ptr)) {
@@ -107,7 +107,7 @@ public class RubikObjectDetector implements ObjectDetector {
}
/**
* Detects objects in the given input image using the rubikDetector.
* Detects objects in the given input image using the TFLite detector.
*
* @param in The input image to perform object detection on.
* @param nmsThresh The threshold value for non-maximum suppression.

View File

@@ -17,6 +17,10 @@ dependencies {
// Needed for Javalin Runtime Logging
implementation "org.slf4j:slf4j-simple:2.0.7"
implementation("org.photonvision:tflite_jni-java:$tfliteVersion") {
transitive = false
}
}
group = 'org.photonvision'

View File

@@ -55,12 +55,13 @@ import org.photonvision.common.networking.NetworkManager;
import org.photonvision.common.util.ShellExec;
import org.photonvision.common.util.TimedTaskManager;
import org.photonvision.common.util.file.ProgramDirectoryUtilities;
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
import org.photonvision.vision.calibration.CameraCalibrationCoefficients;
import org.photonvision.vision.camera.CameraQuirk;
import org.photonvision.vision.camera.PVCameraInfo;
import org.photonvision.vision.objects.ObjectDetector;
import org.photonvision.vision.objects.RknnModel;
import org.photonvision.vision.objects.RubikModel;
import org.photonvision.vision.objects.TFLiteModel;
import org.photonvision.vision.processes.VisionSourceManager;
import org.zeroturnaround.zip.ZipUtil;
@@ -679,7 +680,7 @@ public class RequestHandler {
try {
objDetector =
switch (family) {
case RUBIK -> new RubikModel(modelProperties).load();
case RUBIK -> new TFLiteModel(modelProperties, TFLiteSource.RUBIK).load();
case RKNN -> new RknnModel(modelProperties).load();
};
} catch (RuntimeException e) {