mirror of
https://github.com/PhotonVision/photonvision
synced 2026-06-30 02:31:40 +00:00
Refactor Rubik OD code to generic TFLite OD code (#2516)
This commit is contained in:
@@ -38,9 +38,10 @@ import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelPr
|
||||
import org.photonvision.common.hardware.Platform;
|
||||
import org.photonvision.common.logging.LogGroup;
|
||||
import org.photonvision.common.logging.Logger;
|
||||
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
|
||||
import org.photonvision.vision.objects.Model;
|
||||
import org.photonvision.vision.objects.RknnModel;
|
||||
import org.photonvision.vision.objects.RubikModel;
|
||||
import org.photonvision.vision.objects.TFLiteModel;
|
||||
|
||||
/**
|
||||
* Manages the loading of neural network models.
|
||||
@@ -360,7 +361,7 @@ public class NeuralNetworkModelManager {
|
||||
models.get(properties.family()).add(new RknnModel(properties));
|
||||
}
|
||||
case RUBIK -> {
|
||||
models.get(properties.family()).add(new RubikModel(properties));
|
||||
models.get(properties.family()).add(new TFLiteModel(properties, TFLiteSource.RUBIK));
|
||||
}
|
||||
}
|
||||
logger.info(
|
||||
|
||||
@@ -19,7 +19,6 @@ package org.photonvision.vision.objects;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import org.opencv.core.Size;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelManager.Family;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelManager.Version;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties;
|
||||
@@ -79,8 +78,7 @@ public class RknnModel implements Model {
|
||||
}
|
||||
|
||||
public ObjectDetector load() {
|
||||
return new RknnObjectDetector(
|
||||
this, new Size(properties.resolutionWidth(), properties.resolutionHeight()));
|
||||
return new RknnObjectDetector(this);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
|
||||
@@ -61,9 +61,10 @@ public class RknnObjectDetector implements ObjectDetector {
|
||||
* @param inputSize The required image dimensions for the model. Images will be {@link
|
||||
* Letterbox}ed to this shape.
|
||||
*/
|
||||
public RknnObjectDetector(RknnModel model, Size inputSize) {
|
||||
public RknnObjectDetector(RknnModel model) {
|
||||
this.model = model;
|
||||
this.inputSize = inputSize;
|
||||
this.inputSize =
|
||||
new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight());
|
||||
|
||||
// Create the detector
|
||||
objPointer =
|
||||
|
||||
@@ -19,22 +19,26 @@ package org.photonvision.vision.objects;
|
||||
|
||||
import java.io.File;
|
||||
import java.nio.file.Path;
|
||||
import org.opencv.core.Size;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelManager.Family;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelManager.Version;
|
||||
import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties;
|
||||
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
|
||||
|
||||
public class RubikModel implements Model {
|
||||
public class TFLiteModel implements Model {
|
||||
public final File modelFile;
|
||||
public final ModelProperties properties;
|
||||
public final TFLiteSource backend;
|
||||
|
||||
/**
|
||||
* Rubik model constructor.
|
||||
* TFLite model constructor.
|
||||
*
|
||||
* @param properties The properties of the model.
|
||||
* @param backend The backend of the model should run on.
|
||||
* @throws IllegalArgumentException
|
||||
*/
|
||||
public RubikModel(ModelProperties properties) throws IllegalArgumentException {
|
||||
public TFLiteModel(ModelProperties properties, TFLiteSource backend)
|
||||
throws IllegalArgumentException {
|
||||
this.backend = backend;
|
||||
modelFile = new File(properties.modelPath().toString());
|
||||
if (!modelFile.exists()) {
|
||||
throw new IllegalArgumentException("Model file does not exist: " + modelFile);
|
||||
@@ -77,11 +81,10 @@ public class RubikModel implements Model {
|
||||
}
|
||||
|
||||
public ObjectDetector load() {
|
||||
return new RubikObjectDetector(
|
||||
this, new Size(this.properties.resolutionWidth(), this.properties.resolutionHeight()));
|
||||
return new TFLiteObjectDetector(this, backend);
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return "RubikModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}';
|
||||
return "TFLiteModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}';
|
||||
}
|
||||
}
|
||||
@@ -30,9 +30,9 @@ import org.photonvision.tflite.TFLiteJNI;
|
||||
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
|
||||
import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult;
|
||||
|
||||
/** Manages an object detector using the rubik backend. */
|
||||
public class RubikObjectDetector implements ObjectDetector {
|
||||
private static final Logger logger = new Logger(RubikObjectDetector.class, LogGroup.General);
|
||||
/** Manages an object detector using the TFLite backend. */
|
||||
public class TFLiteObjectDetector implements ObjectDetector {
|
||||
private static final Logger logger = new Logger(TFLiteObjectDetector.class, LogGroup.General);
|
||||
|
||||
private static final Cleaner cleaner = Cleaner.create();
|
||||
|
||||
@@ -45,26 +45,26 @@ public class RubikObjectDetector implements ObjectDetector {
|
||||
/** Pointer to the native object */
|
||||
private final long ptr;
|
||||
|
||||
private final RubikModel model;
|
||||
private final TFLiteModel model;
|
||||
|
||||
private final Size inputSize;
|
||||
|
||||
/** Returns the model in use by this detector. */
|
||||
@Override
|
||||
public RubikModel getModel() {
|
||||
public TFLiteModel getModel() {
|
||||
return model;
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a new rubikObjectDetector from the given model.
|
||||
* Creates a new TFLite detector from the given model.
|
||||
*
|
||||
* @param model The model to create the detector from.
|
||||
* @param inputSize The required image dimensions for the model. Images will be {@link
|
||||
* Letterbox}ed to this shape.
|
||||
* @param source The backend to run the detector on.
|
||||
*/
|
||||
public RubikObjectDetector(RubikModel model, Size inputSize) {
|
||||
public TFLiteObjectDetector(TFLiteModel model, TFLiteSource backend) {
|
||||
this.model = model;
|
||||
this.inputSize = inputSize;
|
||||
this.inputSize =
|
||||
new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight());
|
||||
|
||||
// Create the detector
|
||||
try {
|
||||
@@ -72,7 +72,7 @@ public class RubikObjectDetector implements ObjectDetector {
|
||||
TFLiteJNI.create(
|
||||
model.modelFile.getPath().toString(),
|
||||
model.properties.version().ordinal(),
|
||||
TFLiteSource.RUBIK.value());
|
||||
backend.value());
|
||||
} catch (Exception e) {
|
||||
logger.error("Failed to create detector from path " + model.modelFile.getPath(), e);
|
||||
throw new RuntimeException(
|
||||
@@ -83,7 +83,7 @@ public class RubikObjectDetector implements ObjectDetector {
|
||||
logger.error(
|
||||
"Failed to create detector from path "
|
||||
+ model.modelFile.getPath()
|
||||
+ ". Please ensure the model is valid and compatible with the Rubik backend.");
|
||||
+ ". Please ensure the model is valid and compatible with the TFLite backend.");
|
||||
throw new RuntimeException(
|
||||
"Failed to create detector from path " + model.modelFile.getPath());
|
||||
} else if (!TFLiteJNI.isQuantized(ptr)) {
|
||||
@@ -107,7 +107,7 @@ public class RubikObjectDetector implements ObjectDetector {
|
||||
}
|
||||
|
||||
/**
|
||||
* Detects objects in the given input image using the rubikDetector.
|
||||
* Detects objects in the given input image using the TFLite detector.
|
||||
*
|
||||
* @param in The input image to perform object detection on.
|
||||
* @param nmsThresh The threshold value for non-maximum suppression.
|
||||
@@ -17,6 +17,10 @@ dependencies {
|
||||
|
||||
// Needed for Javalin Runtime Logging
|
||||
implementation "org.slf4j:slf4j-simple:2.0.7"
|
||||
|
||||
implementation("org.photonvision:tflite_jni-java:$tfliteVersion") {
|
||||
transitive = false
|
||||
}
|
||||
}
|
||||
|
||||
group = 'org.photonvision'
|
||||
|
||||
@@ -55,12 +55,13 @@ import org.photonvision.common.networking.NetworkManager;
|
||||
import org.photonvision.common.util.ShellExec;
|
||||
import org.photonvision.common.util.TimedTaskManager;
|
||||
import org.photonvision.common.util.file.ProgramDirectoryUtilities;
|
||||
import org.photonvision.tflite.TFLiteJNI.TFLiteSource;
|
||||
import org.photonvision.vision.calibration.CameraCalibrationCoefficients;
|
||||
import org.photonvision.vision.camera.CameraQuirk;
|
||||
import org.photonvision.vision.camera.PVCameraInfo;
|
||||
import org.photonvision.vision.objects.ObjectDetector;
|
||||
import org.photonvision.vision.objects.RknnModel;
|
||||
import org.photonvision.vision.objects.RubikModel;
|
||||
import org.photonvision.vision.objects.TFLiteModel;
|
||||
import org.photonvision.vision.processes.VisionSourceManager;
|
||||
import org.zeroturnaround.zip.ZipUtil;
|
||||
|
||||
@@ -679,7 +680,7 @@ public class RequestHandler {
|
||||
try {
|
||||
objDetector =
|
||||
switch (family) {
|
||||
case RUBIK -> new RubikModel(modelProperties).load();
|
||||
case RUBIK -> new TFLiteModel(modelProperties, TFLiteSource.RUBIK).load();
|
||||
case RKNN -> new RknnModel(modelProperties).load();
|
||||
};
|
||||
} catch (RuntimeException e) {
|
||||
|
||||
Reference in New Issue
Block a user