From 8f560e5b1f171417331d10a0a3ade349bae67d8a Mon Sep 17 00:00:00 2001 From: Jade Date: Tue, 30 Jun 2026 01:02:47 +0800 Subject: [PATCH] Refactor Rubik OD code to generic TFLite OD code (#2516) --- .../NeuralNetworkModelManager.java | 5 ++-- .../vision/objects/RknnModel.java | 4 +-- .../vision/objects/RknnObjectDetector.java | 5 ++-- .../{RubikModel.java => TFLiteModel.java} | 17 +++++++----- ...etector.java => TFLiteObjectDetector.java} | 26 +++++++++---------- photon-server/build.gradle | 4 +++ .../photonvision/server/RequestHandler.java | 5 ++-- 7 files changed, 37 insertions(+), 29 deletions(-) rename photon-core/src/main/java/org/photonvision/vision/objects/{RubikModel.java => TFLiteModel.java} (83%) rename photon-core/src/main/java/org/photonvision/vision/objects/{RubikObjectDetector.java => TFLiteObjectDetector.java} (86%) diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index 24cebab3e..db67fdefa 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -38,9 +38,10 @@ import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelPr import org.photonvision.common.hardware.Platform; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; +import org.photonvision.tflite.TFLiteJNI.TFLiteSource; import org.photonvision.vision.objects.Model; import org.photonvision.vision.objects.RknnModel; -import org.photonvision.vision.objects.RubikModel; +import org.photonvision.vision.objects.TFLiteModel; /** * Manages the loading of neural network models. @@ -360,7 +361,7 @@ public class NeuralNetworkModelManager { models.get(properties.family()).add(new RknnModel(properties)); } case RUBIK -> { - models.get(properties.family()).add(new RubikModel(properties)); + models.get(properties.family()).add(new TFLiteModel(properties, TFLiteSource.RUBIK)); } } logger.info( diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java index 5951a9208..b39c76b2f 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java @@ -19,7 +19,6 @@ package org.photonvision.vision.objects; import java.io.File; import java.nio.file.Path; -import org.opencv.core.Size; import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; import org.photonvision.common.configuration.NeuralNetworkModelManager.Version; import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties; @@ -79,8 +78,7 @@ public class RknnModel implements Model { } public ObjectDetector load() { - return new RknnObjectDetector( - this, new Size(properties.resolutionWidth(), properties.resolutionHeight())); + return new RknnObjectDetector(this); } public String toString() { diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RknnObjectDetector.java b/photon-core/src/main/java/org/photonvision/vision/objects/RknnObjectDetector.java index dfe8d0b3e..35ae682e8 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/RknnObjectDetector.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/RknnObjectDetector.java @@ -61,9 +61,10 @@ public class RknnObjectDetector implements ObjectDetector { * @param inputSize The required image dimensions for the model. Images will be {@link * Letterbox}ed to this shape. */ - public RknnObjectDetector(RknnModel model, Size inputSize) { + public RknnObjectDetector(RknnModel model) { this.model = model; - this.inputSize = inputSize; + this.inputSize = + new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight()); // Create the detector objPointer = diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RubikModel.java b/photon-core/src/main/java/org/photonvision/vision/objects/TFLiteModel.java similarity index 83% rename from photon-core/src/main/java/org/photonvision/vision/objects/RubikModel.java rename to photon-core/src/main/java/org/photonvision/vision/objects/TFLiteModel.java index c020b08fc..3a9031ee3 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/RubikModel.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/TFLiteModel.java @@ -19,22 +19,26 @@ package org.photonvision.vision.objects; import java.io.File; import java.nio.file.Path; -import org.opencv.core.Size; import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; import org.photonvision.common.configuration.NeuralNetworkModelManager.Version; import org.photonvision.common.configuration.NeuralNetworkModelsSettings.ModelProperties; +import org.photonvision.tflite.TFLiteJNI.TFLiteSource; -public class RubikModel implements Model { +public class TFLiteModel implements Model { public final File modelFile; public final ModelProperties properties; + public final TFLiteSource backend; /** - * Rubik model constructor. + * TFLite model constructor. * * @param properties The properties of the model. + * @param backend The backend of the model should run on. * @throws IllegalArgumentException */ - public RubikModel(ModelProperties properties) throws IllegalArgumentException { + public TFLiteModel(ModelProperties properties, TFLiteSource backend) + throws IllegalArgumentException { + this.backend = backend; modelFile = new File(properties.modelPath().toString()); if (!modelFile.exists()) { throw new IllegalArgumentException("Model file does not exist: " + modelFile); @@ -77,11 +81,10 @@ public class RubikModel implements Model { } public ObjectDetector load() { - return new RubikObjectDetector( - this, new Size(this.properties.resolutionWidth(), this.properties.resolutionHeight())); + return new TFLiteObjectDetector(this, backend); } public String toString() { - return "RubikModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}'; + return "TFLiteModel{" + "modelFile=" + modelFile + ", properties=" + properties + '}'; } } diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RubikObjectDetector.java b/photon-core/src/main/java/org/photonvision/vision/objects/TFLiteObjectDetector.java similarity index 86% rename from photon-core/src/main/java/org/photonvision/vision/objects/RubikObjectDetector.java rename to photon-core/src/main/java/org/photonvision/vision/objects/TFLiteObjectDetector.java index bc8aea13a..e8ff36f75 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/RubikObjectDetector.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/TFLiteObjectDetector.java @@ -30,9 +30,9 @@ import org.photonvision.tflite.TFLiteJNI; import org.photonvision.tflite.TFLiteJNI.TFLiteSource; import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult; -/** Manages an object detector using the rubik backend. */ -public class RubikObjectDetector implements ObjectDetector { - private static final Logger logger = new Logger(RubikObjectDetector.class, LogGroup.General); +/** Manages an object detector using the TFLite backend. */ +public class TFLiteObjectDetector implements ObjectDetector { + private static final Logger logger = new Logger(TFLiteObjectDetector.class, LogGroup.General); private static final Cleaner cleaner = Cleaner.create(); @@ -45,26 +45,26 @@ public class RubikObjectDetector implements ObjectDetector { /** Pointer to the native object */ private final long ptr; - private final RubikModel model; + private final TFLiteModel model; private final Size inputSize; /** Returns the model in use by this detector. */ @Override - public RubikModel getModel() { + public TFLiteModel getModel() { return model; } /** - * Creates a new rubikObjectDetector from the given model. + * Creates a new TFLite detector from the given model. * * @param model The model to create the detector from. - * @param inputSize The required image dimensions for the model. Images will be {@link - * Letterbox}ed to this shape. + * @param source The backend to run the detector on. */ - public RubikObjectDetector(RubikModel model, Size inputSize) { + public TFLiteObjectDetector(TFLiteModel model, TFLiteSource backend) { this.model = model; - this.inputSize = inputSize; + this.inputSize = + new Size(model.properties.resolutionWidth(), model.properties.resolutionHeight()); // Create the detector try { @@ -72,7 +72,7 @@ public class RubikObjectDetector implements ObjectDetector { TFLiteJNI.create( model.modelFile.getPath().toString(), model.properties.version().ordinal(), - TFLiteSource.RUBIK.value()); + backend.value()); } catch (Exception e) { logger.error("Failed to create detector from path " + model.modelFile.getPath(), e); throw new RuntimeException( @@ -83,7 +83,7 @@ public class RubikObjectDetector implements ObjectDetector { logger.error( "Failed to create detector from path " + model.modelFile.getPath() - + ". Please ensure the model is valid and compatible with the Rubik backend."); + + ". Please ensure the model is valid and compatible with the TFLite backend."); throw new RuntimeException( "Failed to create detector from path " + model.modelFile.getPath()); } else if (!TFLiteJNI.isQuantized(ptr)) { @@ -107,7 +107,7 @@ public class RubikObjectDetector implements ObjectDetector { } /** - * Detects objects in the given input image using the rubikDetector. + * Detects objects in the given input image using the TFLite detector. * * @param in The input image to perform object detection on. * @param nmsThresh The threshold value for non-maximum suppression. diff --git a/photon-server/build.gradle b/photon-server/build.gradle index 5df27ea1d..46b0a1ed9 100644 --- a/photon-server/build.gradle +++ b/photon-server/build.gradle @@ -17,6 +17,10 @@ dependencies { // Needed for Javalin Runtime Logging implementation "org.slf4j:slf4j-simple:2.0.7" + + implementation("org.photonvision:tflite_jni-java:$tfliteVersion") { + transitive = false + } } group = 'org.photonvision' diff --git a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java index 0cd5380b0..f8d965810 100644 --- a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java +++ b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java @@ -55,12 +55,13 @@ import org.photonvision.common.networking.NetworkManager; import org.photonvision.common.util.ShellExec; import org.photonvision.common.util.TimedTaskManager; import org.photonvision.common.util.file.ProgramDirectoryUtilities; +import org.photonvision.tflite.TFLiteJNI.TFLiteSource; import org.photonvision.vision.calibration.CameraCalibrationCoefficients; import org.photonvision.vision.camera.CameraQuirk; import org.photonvision.vision.camera.PVCameraInfo; import org.photonvision.vision.objects.ObjectDetector; import org.photonvision.vision.objects.RknnModel; -import org.photonvision.vision.objects.RubikModel; +import org.photonvision.vision.objects.TFLiteModel; import org.photonvision.vision.processes.VisionSourceManager; import org.zeroturnaround.zip.ZipUtil; @@ -679,7 +680,7 @@ public class RequestHandler { try { objDetector = switch (family) { - case RUBIK -> new RubikModel(modelProperties).load(); + case RUBIK -> new TFLiteModel(modelProperties, TFLiteSource.RUBIK).load(); case RKNN -> new RknnModel(modelProperties).load(); }; } catch (RuntimeException e) {