diff --git a/docs/source/docs/objectDetection/about-object-detection.md b/docs/source/docs/objectDetection/about-object-detection.md index a155a2f98..8ef78db97 100644 --- a/docs/source/docs/objectDetection/about-object-detection.md +++ b/docs/source/docs/objectDetection/about-object-detection.md @@ -41,23 +41,14 @@ Power users only. This requires some setup, such as obtaining your own dataset a Before beginning, it is necessary to install the [rknn-toolkit2](https://github.com/airockchip/rknn-toolkit2). Then, install the relevant [Ultralytics repository](https://github.com/airockchip?tab=repositories&q=yolo&type=&language=&sort=) from this list. After training your model, export it to `rknn`. This will give you an `onnx` file, formatted for conversion. Copy this file to the relevant folder in [rknn_model_zoo](https://github.com/airockchip/rknn_model_zoo), and use the conversion script located there to convert it. If necessary, modify the script to provide the path to your training database for quantization. -## Uploading Custom Models +## Managing Custom Models :::{warning} PhotonVision currently ONLY supports 640x640 Ultralytics YOLOv5, YOLOv8, and YOLOv11 models trained and converted to `.rknn` format for RK3588 CPUs! Other models require different post-processing code and will NOT work. The model conversion process is also highly particular. Proceed with care. ::: :::{warning} -Non-quantized models are not supported! If you have the option, make sure quantization is enabled when exporting to .rknn format. This will represent the weights and activations of the model as 8-bit integers, instead of 32-bit floats which PhotonVision doesn't support. Quantized models are also much faster. +Non-quantized models are not supported! If you have the option, make sure quantization is enabled when exporting to .rknn format. This will represent the weights and activations of the model as 8-bit integers, instead of 32-bit floats which PhotonVision doesn't support. 
Quantized models are also much faster for a negligible loss in accuracy. ::: -In the settings, under `Device Control`, there's an option to upload a new object detection model. Naming convention -should be `name-verticalResolution-horizontalResolution-yolovXXX`. The -`name` should only include alphanumeric characters, periods, and underscores. Additionally, the labels -file ought to have the same name as the RKNN file, with `-labels` appended to the end. For -example, if the RKNN file is named `Algae_1.03.2025-640-640-yolov5s.rknn`, the labels file should be -named `Algae_1.03.2025-640-640-yolov5s-labels.txt`. - -:::{note} -Currently there is no way to delete custom models in the GUI, though this is a planned feature. To do this, you have to SSH into the coprocessor and delete the files manually from `/opt/photonvision/photonvision_config/models`. -::: +Custom models can now be managed from the Object Detection tab in settings. You can upload a custom model by clicking the "Upload Model" button, selecting your `.rknn` file, and filling out the property fields. Models can also be exported, both individually and in bulk. Models exported in bulk can be imported using the `import bulk` button. Models exported individually must be re-imported as an individual model, and all the relevant metadata is stored in the filename of the model. 
diff --git a/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue b/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue index afe880958..45d2b4d0e 100644 --- a/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue +++ b/photon-client/src/components/dashboard/tabs/ObjectDetectionTab.vue @@ -8,6 +8,7 @@ import { computed } from "vue"; import { useStateStore } from "@/stores/StateStore"; import { useSettingsStore } from "@/stores/settings/GeneralSettingsStore"; import { useDisplay } from "vuetify"; +import type { ObjectDetectionModelProperties } from "@/types/SettingTypes"; // TODO fix pipeline typing in order to fix this, the store settings call should be able to infer that only valid pipeline type settings are exposed based on pre-checks for the entire config section // Defer reference to store access method @@ -32,17 +33,32 @@ const interactiveCols = computed(() => ); // Filters out models that are not supported by the current backend, and returns a flattened list. -const supportedModels = computed(() => { +const supportedModels = computed(() => { const { availableModels, supportedBackends } = useSettingsStore().general; - return supportedBackends.flatMap((backend) => availableModels[backend] || []); + const isSupported = (model: ObjectDetectionModelProperties) => { + // Check if model's family is in the list of supported backends + return supportedBackends.some((backend: string) => backend.toLowerCase() === model.family.toLowerCase()); + }; + + // Filter models where the family is supported and flatten the list + return availableModels.filter(isSupported); }); const selectedModel = computed({ get: () => { - const index = supportedModels.value.indexOf(currentPipelineSettings.value.model); + const currentModel = currentPipelineSettings.value.model; + if (!currentModel) return undefined; + + const index = supportedModels.value.findIndex((model) => model.modelPath === currentModel.modelPath); return index === -1 ? 
undefined : index; }, - set: (v) => v && useCameraSettingsStore().changeCurrentPipelineSetting({ model: supportedModels.value[v] }, false) + + set: (v) => { + if (v !== undefined && v >= 0 && v < supportedModels.value.length) { + const newModel = supportedModels.value[v]; + useCameraSettingsStore().changeCurrentPipelineSetting({ model: newModel }, true); + } + } }); @@ -53,8 +69,9 @@ const selectedModel = computed({ label="Model" tooltip="The model used to detect objects in the camera feed" :select-cols="interactiveCols" - :items="supportedModels" + :items="supportedModels.map((model) => model.nickname)" /> + diff --git a/photon-client/src/components/settings/DeviceControlCard.vue b/photon-client/src/components/settings/DeviceControlCard.vue index ca2d51e3c..a3ab2c081 100644 --- a/photon-client/src/components/settings/DeviceControlCard.vue +++ b/photon-client/src/components/settings/DeviceControlCard.vue @@ -218,7 +218,7 @@ const nukePhotonConfigDirectory = () => { .catch((error) => { if (error.response) { useStateStore().showSnackbarMessage({ - message: "The backend is unable to fulfil the request to reset the device.", + message: "The backend is unable to fulfill the request to reset the device.", color: "error" }); } else if (error.request) { diff --git a/photon-client/src/components/settings/ObjectDetectionCard.vue b/photon-client/src/components/settings/ObjectDetectionCard.vue index 69c2942c9..0a418f00e 100644 --- a/photon-client/src/components/settings/ObjectDetectionCard.vue +++ b/photon-client/src/components/settings/ObjectDetectionCard.vue @@ -3,40 +3,37 @@ import { ref, computed, inject } from "vue"; import axios from "axios"; import { useStateStore } from "@/stores/StateStore"; import { useSettingsStore } from "@/stores/settings/GeneralSettingsStore"; +import type { ObjectDetectionModelProperties } from "@/types/SettingTypes"; +import pvInput from "@/components/common/pv-input.vue"; const showImportDialog = ref(false); -const importRKNNFile = ref(null); 
-const importLabelsFile = ref(null); +const showInfo = ref({ show: false, model: {} as ObjectDetectionModelProperties }); +const confirmDeleteDialog = ref({ show: false, model: {} as ObjectDetectionModelProperties }); +const showRenameDialog = ref({ + show: false, + model: {} as ObjectDetectionModelProperties, + newName: "" +}); -const host = inject("backendHost"); +const address = inject("backendHost"); -const areValidFileNames = (weights: string | null, labels: string | null) => { - const weightsRegex = /^([a-zA-Z0-9._]+)-(\d+)-(\d+)-(yolov(?:5|8|11)[nsmlx]*)\.rknn$/; - const labelsRegex = /^([a-zA-Z0-9._]+)-(\d+)-(\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\.txt$/; - - if (weights && labels) { - const weightsMatch = weights.match(weightsRegex); - const labelsMatch = labels.match(labelsRegex); - - if (weightsMatch && labelsMatch) { - return ( - weightsMatch[1] === labelsMatch[1] && - weightsMatch[2] === labelsMatch[2] && - weightsMatch[3] === labelsMatch[3] && - weightsMatch[4] === labelsMatch[4] - ); - } - } - return false; -}; +const importModelFile = ref(null); +const importLabels = ref(null); +const importHeight = ref(null); +const importWidth = ref(null); +const importVersion = ref(null); // TODO gray out the button when model is uploading const handleImport = async () => { - if (importRKNNFile.value === null || importLabelsFile.value === null) return; + if (importModelFile.value === null) return; const formData = new FormData(); - formData.append("rknn", importRKNNFile.value); - formData.append("labels", importLabelsFile.value); + + formData.append("modelFile", importModelFile.value); + formData.append("labels", importLabels.value?.toString() || ""); + formData.append("height", importHeight.value?.toString() || ""); + formData.append("width", importWidth.value?.toString() || ""); + formData.append("version", importVersion.value?.toString() || ""); useStateStore().showSnackbarMessage({ message: "Importing Object Detection Model...", @@ -45,7 +42,7 @@ const 
handleImport = async () => { }); axios - .post("/utils/importObjectDetectionModel", formData, { + .post("/objectdetection/import", formData, { headers: { "Content-Type": "multipart/form-data" } }) .then((response) => { @@ -75,87 +72,460 @@ const handleImport = async () => { showImportDialog.value = false; - importRKNNFile.value = null; - importLabelsFile.value = null; + importModelFile.value = null; + importLabels.value = null; + importHeight.value = null; + importWidth.value = null; + importVersion.value = null; +}; + +const deleteModel = async (model: ObjectDetectionModelProperties) => { + useStateStore().showSnackbarMessage({ + message: "Deleting Object Detection Model...", + color: "secondary", + timeout: -1 + }); + + axios + .post("/objectdetection/delete", { + modelPath: model.modelPath + }) + .then((response) => { + useStateStore().showSnackbarMessage({ + message: response.data.text || response.data, + color: "success" + }); + }) + .catch((error) => { + if (error.response) { + useStateStore().showSnackbarMessage({ + color: "error", + message: error.response.data.text || error.response.data + }); + } else if (error.request) { + useStateStore().showSnackbarMessage({ + color: "error", + message: "Error while trying to process the request! The backend didn't respond." + }); + } else { + useStateStore().showSnackbarMessage({ + color: "error", + message: "An error occurred while trying to process the request." 
+ }); + } + }); + confirmDeleteDialog.value.show = false; +}; + +const renameModel = async (model: ObjectDetectionModelProperties, newName: string) => { + useStateStore().showSnackbarMessage({ + message: "Renaming Object Detection Model...", + color: "secondary", + timeout: -1 + }); + + axios + .post("/objectdetection/rename", { + modelPath: model.modelPath.replace("file:", ""), + newName: newName + }) + .then((response) => { + useStateStore().showSnackbarMessage({ + message: response.data.text || response.data, + color: "success" + }); + }) + .catch((error) => { + if (error.response) { + useStateStore().showSnackbarMessage({ + color: "error", + message: error.response.data.text || error.response.data + }); + } else if (error.request) { + useStateStore().showSnackbarMessage({ + color: "error", + message: "Error while trying to process the request! The backend didn't respond." + }); + } else { + useStateStore().showSnackbarMessage({ + color: "error", + message: "An error occurred while trying to process the request." + }); + } + }); + showRenameDialog.value.show = false; }; // Filters out models that are not supported by the current backend, and returns a flattened list. 
const supportedModels = computed(() => { const { availableModels, supportedBackends } = useSettingsStore().general; - return supportedBackends.flatMap((backend) => availableModels[backend] || []); + const isSupported = (model: ObjectDetectionModelProperties) => { + // Check if model's family is in the list of supported backends + return supportedBackends.some((backend: string) => backend.toLowerCase() === model.family.toLowerCase()); + }; + + // Keep only the models whose family is one of the supported backends + return availableModels.filter(isSupported); }); + +const exportModels = ref(); +const openExportPrompt = () => { + exportModels.value.click(); +}; + +const exportIndividualModel = ref(); +const openExportIndividualModelPrompt = () => { + exportIndividualModel.value.click(); +}; + +const showNukeDialog = ref(false); +const expected = "Delete Models"; +const yesDeleteMyModelsText = ref(""); +const nukeModels = () => { + axios + .post("/objectdetection/nuke") + .then(() => { + useStateStore().showSnackbarMessage({ + message: "Successfully dispatched the clear models command.", + color: "success" + }); + }) + .catch((error) => { + if (error.response) { + useStateStore().showSnackbarMessage({ + message: "The backend is unable to fulfill the request to clear the models.", + color: "error" + }); + } else if (error.request) { + useStateStore().showSnackbarMessage({ + message: "Error while trying to process the request! 
The backend didn't respond.", + color: "error" + }); + } else { + useStateStore().showSnackbarMessage({ + message: "An error occurred while trying to process the request.", + color: "error" + }); + } + }); + showNukeDialog.value = false; +}; + +const showBulkImportDialog = ref(false); +const importFile = ref(null); +const handleBulkImport = () => { + if (importFile.value === null) return; + + const formData = new FormData(); + formData.append("data", importFile.value); + + axios + .post(`/objectdetection/bulkimport`, formData, { + headers: { "Content-Type": "multipart/form-data" }, + onUploadProgress: ({ progress }) => { + const uploadPercentage = (progress || 0) * 100.0; + if (uploadPercentage < 99.5) { + useStateStore().showSnackbarMessage({ + message: "Object Detection Models Upload in Process, " + uploadPercentage.toFixed(2) + "% complete", + color: "secondary", + timeout: -1 + }); + } else { + useStateStore().showSnackbarMessage({ + message: "Importing New Object Detection Models...", + color: "secondary", + timeout: -1 + }); + } + } + }) + .then((response) => { + useStateStore().showSnackbarMessage({ + message: response.data.text || response.data, + color: "success" + }); + }) + .catch((error) => { + if (error.response) { + useStateStore().showSnackbarMessage({ + color: "error", + message: error.response.data.text || error.response.data + }); + } else if (error.request) { + useStateStore().showSnackbarMessage({ + color: "error", + message: "Error while trying to process the request! The backend didn't respond." + }); + } else { + useStateStore().showSnackbarMessage({ + color: "error", + message: "An error occurred while trying to process the request." 
+ }); + } + }); + + showBulkImportDialog.value = false; + importFile.value = null; +}; diff --git a/photon-client/src/stores/settings/CameraSettingsStore.ts b/photon-client/src/stores/settings/CameraSettingsStore.ts index 48706263b..ef7799e3a 100644 --- a/photon-client/src/stores/settings/CameraSettingsStore.ts +++ b/photon-client/src/stores/settings/CameraSettingsStore.ts @@ -211,6 +211,7 @@ export const useCameraSettingsStore = defineStore("cameraSettings", { cameraUniqueName: cameraUniqueName } }; + if (updateStore) { this.changePipelineSettingsInStore(settings, cameraUniqueName); } diff --git a/photon-client/src/stores/settings/GeneralSettingsStore.ts b/photon-client/src/stores/settings/GeneralSettingsStore.ts index 3f374b147..f9614bcd3 100644 --- a/photon-client/src/stores/settings/GeneralSettingsStore.ts +++ b/photon-client/src/stores/settings/GeneralSettingsStore.ts @@ -27,7 +27,7 @@ export const useSettingsStore = defineStore("settings", { hardwareModel: undefined, hardwarePlatform: undefined, mrCalWorking: true, - availableModels: {}, + availableModels: [], supportedBackends: [] }, network: { diff --git a/photon-client/src/types/PipelineTypes.ts b/photon-client/src/types/PipelineTypes.ts index ca2e3f5ac..1b5631194 100644 --- a/photon-client/src/types/PipelineTypes.ts +++ b/photon-client/src/types/PipelineTypes.ts @@ -1,4 +1,5 @@ import type { WebsocketNumberPair } from "@/types/WebsocketDataTypes"; +import type { ObjectDetectionModelProperties } from "@/types/SettingTypes"; export enum PipelineType { DriverMode = 1, @@ -296,8 +297,9 @@ export interface ObjectDetectionPipelineSettings extends PipelineSettings { confidence: number; nms: number; box_thresh: number; - model: string; + model: ObjectDetectionModelProperties; } + export type ConfigurableObjectDetectionPipelineSettings = Partial< Omit > & @@ -313,7 +315,7 @@ export const DefaultObjectDetectionPipelineSettings: ObjectDetectionPipelineSett confidence: 0.9, nms: 0.45, box_thresh: 0.25, - model: "" + 
model: {} as ObjectDetectionModelProperties }; export interface Calibration3dPipelineSettings extends PipelineSettings { diff --git a/photon-client/src/types/SettingTypes.ts b/photon-client/src/types/SettingTypes.ts index 23e6595b5..0933d6b6e 100644 --- a/photon-client/src/types/SettingTypes.ts +++ b/photon-client/src/types/SettingTypes.ts @@ -8,10 +8,20 @@ export interface GeneralSettings { hardwareModel?: string; hardwarePlatform?: string; mrCalWorking: boolean; - availableModels: Record; + availableModels: ObjectDetectionModelProperties[]; supportedBackends: string[]; } +export interface ObjectDetectionModelProperties { + modelPath: string; + nickname: string; + labels: string[]; + resolutionWidth: number; + resolutionHeight: number; + family: "RKNN"; + version: "YOLOV5" | "YOLOV8" | "YOLOV11"; +} + export interface MetricData { cpuTemp?: string; cpuUtil?: string; diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/ConfigManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/ConfigManager.java index 02de8fb79..6f4371302 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/ConfigManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/ConfigManager.java @@ -34,6 +34,7 @@ import org.opencv.core.Size; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; import org.photonvision.common.util.file.FileUtils; +import org.photonvision.common.util.file.JacksonUtils; import org.photonvision.vision.processes.VisionSource; import org.zeroturnaround.zip.ZipUtil; @@ -51,7 +52,8 @@ public class ConfigManager { private final Thread settingsSaveThread; private long saveRequestTimestamp = -1; - // special case flag to disable flushing settings to disk at shutdown. Avoids the jvm shutdown + // special case flag to disable flushing settings to disk at shutdown. 
Avoids + // the jvm shutdown // hook overwriting the settings we just uploaded private boolean flushOnShutdown = true; private boolean allowWriteTask = true; @@ -62,7 +64,8 @@ public class ConfigManager { ATOMIC_ZIP } - // This logic decides which kind of ConfigManager we load as the default. If we want to switch + // This logic decides which kind of ConfigManager we load as the default. If we + // want to switch // back to the legacy config manager, change this constant private static final ConfigSaveStrategy m_saveStrat = ConfigSaveStrategy.SQL; @@ -109,18 +112,21 @@ public class ConfigManager { } catch (IOException e) { logger.error("Exception moving cameras to cameras_bak!", e); - // Try to just copy from cams to cams-bak instead of moving? Windows sometimes needs us to + // Try to just copy from cams to cams-bak instead of moving? Windows sometimes + // needs us to // do that try { org.apache.commons.io.FileUtils.copyDirectory(maybeCams, maybeCamsBak); } catch (IOException e1) { - // So we can't move to cams_bak, and we can't copy and delete either? We just have to give + // So we can't move to cams_bak, and we can't copy and delete either? We just + // have to give // up here on preserving the old folder logger.error("Exception while backup-copying cameras to cameras_bak!", e); e1.printStackTrace(); } - // Delete the directory because we were successfully able to load the config but were unable + // Delete the directory because we were successfully able to load the config but + // were unable // to save or copy the folder. 
if (maybeCams.exists()) FileUtils.deleteDirectory(maybeCams.toPath()); } @@ -217,6 +223,29 @@ public class ConfigManager { return out; } + public File getObjectDetectionExportAsZip() { + File out = + Path.of(System.getProperty("java.io.tmpdir"), "photonvision-object-detection-models.zip") + .toFile(); + // We create the properties file inside of the models directory so that when we zip it, it's + // included in the zip and simplifies packaging + File tempProperties = + Path.of(getModelsDirectory().toString(), "photonvision-object-detection-models.json") + .toFile(); + try { + JacksonUtils.serialize( + tempProperties.toPath(), this.getConfig().neuralNetworkPropertyManager()); + ZipUtil.pack(getModelsDirectory(), out); + // Now delete the tempProperties + if (tempProperties.exists()) { + Files.delete(tempProperties.toPath()); + } + } catch (Exception e) { + e.printStackTrace(); + } + return out; + } + public void setNetworkSettings(NetworkConfig networkConfig) { getConfig().setNetworkConfig(networkConfig); requestSave(); @@ -294,6 +323,10 @@ public class ConfigManager { return m_provider.saveUploadedAprilTagFieldLayout(uploadPath); } + public boolean saveUploadedNeuralNetworkProperties(Path uploadPath) { + return m_provider.saveUploadedNeuralNetworkProperties(uploadPath); + } + public void requestSave() { logger.trace("Requesting save..."); saveRequestTimestamp = System.currentTimeMillis(); diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/ConfigProvider.java b/photon-core/src/main/java/org/photonvision/common/configuration/ConfigProvider.java index 9964d0ca9..6e10e5ffc 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/ConfigProvider.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/ConfigProvider.java @@ -41,4 +41,6 @@ public abstract class ConfigProvider { public abstract boolean saveUploadedNetworkConfig(Path uploadPath); public abstract boolean saveUploadedAprilTagFieldLayout(Path 
uploadPath); + + public abstract boolean saveUploadedNeuralNetworkProperties(Path uploadPath); } diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/LegacyConfigProvider.java b/photon-core/src/main/java/org/photonvision/common/configuration/LegacyConfigProvider.java index f82e6f914..7c2308f93 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/LegacyConfigProvider.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/LegacyConfigProvider.java @@ -213,7 +213,12 @@ class LegacyConfigProvider extends ConfigProvider { this.config = new PhotonConfiguration( - hardwareConfig, hardwareSettings, networkConfig, atfl, cameraConfigurations); + hardwareConfig, + hardwareSettings, + networkConfig, + atfl, + new NeuralNetworkPropertyManager(), + cameraConfigurations); } @Override @@ -481,4 +486,12 @@ class LegacyConfigProvider extends ConfigProvider { public void unloadCameraConfigs() { this.config.getCameraConfigurations().clear(); } + + @Override + public boolean saveUploadedNeuralNetworkProperties(Path uploadPath) { + // Not implemented: nobody with the legacy config is expected to have one of these. + // Report failure to the caller rather than calling System.exit, which would terminate + // the whole JVM. + return false; + } } diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java index 966098977..e8dc0f389 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkModelManager.java @@ -25,16 +25,15 @@ import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.StandardCopyOption; import java.util.ArrayList; -import java.util.Arrays; import java.util.Enumeration; import java.util.HashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; 
import java.util.jar.JarEntry; import java.util.jar.JarFile; -import java.util.regex.Matcher; -import java.util.regex.Pattern; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; import org.photonvision.common.hardware.Platform; import org.photonvision.common.logging.LogGroup; import org.photonvision.common.logging.Logger; @@ -48,25 +47,43 @@ import org.photonvision.vision.objects.RknnModel; * also supports shipping pre-trained models as resources in the JAR. If the model has already been * extracted to the filesystem, it will not be extracted again. * - *

Each model must have a corresponding labels file. The labels file format is - * simply a list of string names per label, one label per line. The labels file must have the same - * name as the model file, but with the suffix -labels.txt instead of .rknn - * . + *

Each model must have a corresponding {@link ModelProperties} entry in {@link + * NeuralNetworkPropertyManager}. */ public class NeuralNetworkModelManager { /** Singleton instance of the NeuralNetworkModelManager */ private static NeuralNetworkModelManager INSTANCE; + /** + * Builds the properties of the shipped object detection models. This is a method rather than a + * constant because the model paths are resolved against the given models directory. + */ + private NeuralNetworkPropertyManager getShippedProperties(File modelsDirectory) { + NeuralNetworkPropertyManager nnProps = new NeuralNetworkPropertyManager(); + + nnProps.addModelProperties( + new ModelProperties( + Path.of(modelsDirectory.getAbsolutePath(), "algaeV1-640-640-yolov8n.rknn"), + "Algae v8n", + new LinkedList(List.of("Algae")), + 640, + 640, + Family.RKNN, + Version.YOLOV8)); + + return nnProps; + } + /** * Private constructor to prevent instantiation * * @return The NeuralNetworkModelManager instance */ private NeuralNetworkModelManager() { - ArrayList backends = new ArrayList<>(); + ArrayList backends = new ArrayList<>(); if (Platform.isRK3588()) { - backends.add(NeuralNetworkBackend.RKNN); + backends.add(Family.RKNN); } supportedBackends = backends; @@ -87,17 +104,17 @@ public class NeuralNetworkModelManager { /** Logger for the NeuralNetworkModelManager */ private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config); - public enum NeuralNetworkBackend { - RKNN(".rknn"); - - private String format; - - private NeuralNetworkBackend(String format) { - this.format = format; - } + public enum Family { + RKNN } - private final List supportedBackends; + public enum Version { + YOLOV5, + YOLOV8, + YOLOV11 + } + + private final List supportedBackends; /** * Retrieves the list of supported backends. @@ -113,30 +130,7 @@ public class NeuralNetworkModelManager { * *

The first model in the list is the default model. */ - private Map> models; - - /** - * Retrieves the deep neural network models available, in a format that can be used by the - * frontend. - * - * @return A map containing the available models, where the key is the backend and the value is a - * list of model names. - */ - public HashMap> getModels() { - HashMap> modelMap = new HashMap<>(); - if (models == null) { - return modelMap; - } - - models.forEach( - (backend, backendModels) -> { - ArrayList modelNames = new ArrayList<>(); - backendModels.forEach(model -> modelNames.add(model.getName())); - modelMap.put(backend.toString(), modelNames); - }); - - return modelMap; - } + private Map> models; /** * Retrieves the model with the specified name, assuming it is available under a supported @@ -144,19 +138,19 @@ public class NeuralNetworkModelManager { * *

If this method returns `Optional.of(..)` then the model should be safe to load. * - * @param modelName the name of the model to retrieve + * @param modelUID the unique identifier of the model to retrieve * @return an Optional containing the model if found, or an empty Optional if not found */ - public Optional getModel(String modelName) { + public Optional getModel(String modelUID) { if (models == null) { return Optional.empty(); } // Check if the model exists in any supported backend - for (NeuralNetworkBackend backend : supportedBackends) { + for (Family backend : supportedBackends) { if (models.containsKey(backend)) { Optional model = - models.get(backend).stream().filter(m -> m.getName().equals(modelName)).findFirst(); + models.get(backend).stream().filter(m -> m.getUID().equals(modelUID)).findFirst(); if (model.isPresent()) { return model; } @@ -168,66 +162,64 @@ public class NeuralNetworkModelManager { /** The default model when no model is specified. */ public Optional getDefaultModel() { - if (models == null) { - return Optional.empty(); - } - - if (supportedBackends.isEmpty()) { + if (models == null || supportedBackends.isEmpty()) { return Optional.empty(); } return models.get(supportedBackends.get(0)).stream().findFirst(); } - private void loadModel(File model) { + // Do checking later on, when we create the model object + private void loadModel(ModelProperties properties) { if (models == null) { models = new HashMap<>(); } - // Get the model extension and check if it is supported - String modelExtension = model.getName().substring(model.getName().lastIndexOf('.')); - if (modelExtension.equals(".txt")) { + if (properties == null) { + logger.error( + "Model properties are null, this could mean the models config was unable to be found in the database"); return; } - Optional backend = - Arrays.stream(NeuralNetworkBackend.values()) - .filter(b -> b.format.equals(modelExtension)) - .findFirst(); - - if (!backend.isPresent()) { - logger.warn("Model " + 
model.getName() + " has an unknown extension."); + if (!supportedBackends.contains(properties.family())) { + logger.warn( + "Model " + + properties.nickname() + + " has an unknown extension or is not supported on this hardware."); return; } - String labels = model.getAbsolutePath().replace(backend.get().format, "-labels.txt"); - if (!models.containsKey(backend.get())) { - models.put(backend.get(), new ArrayList<>()); + if (!models.containsKey(properties.family())) { + models.put(properties.family(), new ArrayList<>()); } try { - switch (backend.get()) { + switch (properties.family()) { case RKNN -> { - models.get(backend.get()).add(new RknnModel(model, labels)); - logger.info( - "Loaded model " + model.getName() + " for backend " + backend.get().toString()); + models.get(properties.family()).add(new RknnModel(properties)); } } + logger.info( + "Loaded model " + + properties.nickname() + + " for backend " + + properties.family().toString()); } catch (IllegalArgumentException e) { - logger.error("Failed to load model " + model.getName(), e); - } catch (IOException e) { - logger.error("Failed to read labels for model " + model.getName(), e); + logger.error("Failed to load model " + properties.nickname(), e); } } /** * Discovers DNN models from the specified folder. * - * @param modelsDirectory The folder where the models are stored + *

This makes the assumption that all of the models have their properties stored in the + * database */ - public void discoverModels(File modelsDirectory) { + public void discoverModels() { logger.info("Supported backends: " + supportedBackends); + File modelsDirectory = ConfigManager.getInstance().getModelsDirectory(); + if (!modelsDirectory.exists()) { logger.error("Models folder " + modelsDirectory.getAbsolutePath() + " does not exist."); return; @@ -238,7 +230,13 @@ public class NeuralNetworkModelManager { try { Files.walk(modelsDirectory.toPath()) .filter(Files::isRegularFile) - .forEach(path -> loadModel(path.toFile())); + .forEach( + path -> + loadModel( + ConfigManager.getInstance() + .getConfig() + .neuralNetworkPropertyManager() + .getModel(path))); } catch (IOException e) { logger.error("Failed to discover models at " + modelsDirectory.getAbsolutePath(), e); } @@ -246,8 +244,7 @@ public class NeuralNetworkModelManager { // After loading all of the models, sort them by name to ensure a consistent // ordering models.forEach( - (backend, backendModels) -> - backendModels.sort((a, b) -> a.getName().compareTo(b.getName()))); + (backend, backendModels) -> backendModels.sort((a, b) -> a.getUID().compareTo(b.getUID()))); // Log StringBuilder sb = new StringBuilder(); @@ -255,17 +252,17 @@ public class NeuralNetworkModelManager { models.forEach( (backend, backendModels) -> { sb.append(backend).append(" ["); - backendModels.forEach(model -> sb.append(model.getName()).append(", ")); + backendModels.forEach(model -> sb.append(model.getUID()).append(", ")); sb.append("] "); }); } /** - * Extracts models from the JAR and copies them to disk. - * - * @param modelsDirectory the directory on disk to save models + * Extracts models from the JAR and copies them to disk. Also copies properties into the database. 
*/ - public void extractModels(File modelsDirectory) { + public void extractModels() { + File modelsDirectory = ConfigManager.getInstance().getModelsDirectory(); + if (!modelsDirectory.exists() && !modelsDirectory.mkdirs()) { throw new RuntimeException("Failed to create directory: " + modelsDirectory); } @@ -302,67 +299,85 @@ public class NeuralNetworkModelManager { } catch (IOException | URISyntaxException e) { logger.error("Error extracting models", e); } + + ConfigManager.getInstance() + .getConfig() + .setNeuralNetworkProperties( + getShippedProperties(modelsDirectory) + .sum(ConfigManager.getInstance().getConfig().neuralNetworkPropertyManager())); } - private static Pattern modelPattern = - Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)\\.rknn$"); + public boolean clearModels() { + File modelsDirectory = ConfigManager.getInstance().getModelsDirectory(); - private static Pattern labelsPattern = - Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\\.txt$"); - - /** - * Check naming conventions for models and labels. - * - *

This is static as it is not dependent on the state of the class. - * - * @param modelName the name of the model - * @param labelsName the name of the labels file - * @throws IllegalArgumentException if the names are invalid - */ - public static void verifyRKNNNames(String modelName, String labelsName) { - // check null - if (modelName == null || labelsName == null) { - throw new IllegalArgumentException("Model name and labels name cannot be null"); + if (modelsDirectory.exists()) { + try { + Files.walk(modelsDirectory.toPath()) + .sorted((a, b) -> b.compareTo(a)) + .forEach( + path -> { + try { + Files.delete(path); + } catch (IOException e) { + logger.error("Failed to delete file: " + path, e); + } + }); + } catch (IOException e) { + logger.error("Failed to delete models directory", e); + return false; + } } - // These patterns check that the naming convention of - // name-widthResolution-heightResolution-modelType is followed - - Matcher modelMatcher = modelPattern.matcher(modelName); - Matcher labelsMatcher = labelsPattern.matcher(labelsName); - - if (!modelMatcher.matches() || !labelsMatcher.matches()) { - throw new IllegalArgumentException( - "Model name and labels name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn and name-widthResolution-heightResolution-modelType-labels.txt"); - } - - if (!modelMatcher.group(1).equals(labelsMatcher.group(1)) - || !modelMatcher.group(2).equals(labelsMatcher.group(2)) - || !modelMatcher.group(3).equals(labelsMatcher.group(3)) - || !modelMatcher.group(4).equals(labelsMatcher.group(4))) { - throw new IllegalArgumentException("Model name and labels name must be matching."); - } + // Delete model info + return ConfigManager.getInstance().getConfig().neuralNetworkPropertyManager().clear(); } - /** - * Parse RKNN name and return the name, width, height, and model type. - * - *

This is static as it is not dependent on the state of the class. - * - * @param modelName the name of the model - * @throws IllegalArgumentException if the model name does not follow the naming convention - * @return an array containing the name, width, height, and model type - */ - public static String[] parseRKNNName(String modelName) { - Matcher modelMatcher = modelPattern.matcher(modelName); + public File exportSingleModel(String modelPath) { + try { + File modelFile = new File(modelPath); + if (!modelFile.exists()) { + logger.error("Model file does not exist: " + modelFile.getAbsolutePath()); + return null; + } - if (!modelMatcher.matches()) { - throw new IllegalArgumentException( - "Model name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn"); + ModelProperties properties = + ConfigManager.getInstance() + .getConfig() + .neuralNetworkPropertyManager() + .getModel(Path.of(modelPath)); + + String fileName = ""; + String suffix = modelFile.getName().substring(modelFile.getName().lastIndexOf('.')); + if (properties != null) { + fileName = + String.format( + "%s-%s-%s-%dx%d-%s", + properties.nickname().replace(" ", ""), + properties.family(), + properties.version(), + properties.resolutionWidth(), + properties.resolutionHeight(), + String.join("_", properties.labels())); + } else { + fileName = new File(modelPath).getName(); + } + + try { + var out = Files.createTempFile(fileName, suffix); + Files.copy( + modelFile.toPath(), + out, + StandardCopyOption.REPLACE_EXISTING, + StandardCopyOption.COPY_ATTRIBUTES); + return out.toFile(); + } catch (Exception e) { + e.printStackTrace(); + logger.error("Failed to export model file: " + modelFile.getAbsolutePath(), e); + return null; + } + } catch (Exception e) { + logger.error("Failed to export model file: " + modelPath, e); + return null; } - - return new String[] { - modelMatcher.group(1), modelMatcher.group(2), modelMatcher.group(3), modelMatcher.group(4) - }; } } diff --git 
a/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkPropertyManager.java b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkPropertyManager.java new file mode 100644 index 000000000..47067bfa2 --- /dev/null +++ b/photon-core/src/main/java/org/photonvision/common/configuration/NeuralNetworkPropertyManager.java @@ -0,0 +1,163 @@ +/* + * Copyright (C) Photon Vision. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package org.photonvision.common.configuration; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import java.nio.file.Path; +import java.util.HashMap; +import java.util.LinkedList; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Version; + +public class NeuralNetworkPropertyManager { + /* + * The properties of the model. This is used to determine which model to load. + * The only family + * currently supported is RKNN. 
+ */ + public record ModelProperties( + @JsonProperty("modelPath") Path modelPath, + @JsonProperty("nickname") String nickname, + @JsonProperty("labels") LinkedList labels, + @JsonProperty("resolutionWidth") int resolutionWidth, + @JsonProperty("resolutionHeight") int resolutionHeight, + @JsonProperty("family") Family family, + @JsonProperty("version") Version version) { + @JsonCreator + public ModelProperties { + // Record constructor is automatically annotated with @JsonCreator + } + } + + // The path to the model is used as the key in the map because it is unique to + // the model, and should not change + @JsonProperty("modelPathToProperties") + private HashMap modelPathToProperties = + new HashMap(); + + /** + * Constructor for the NeuralNetworkProperties class. + * + *

This object holds a map from model path to {@link ModelProperties} + */ + public NeuralNetworkPropertyManager() {} + + /** + * Constructor for the NeuralNetworkPropertyManager class. + * + *

This object holds a map from model path to {@link ModelProperties}. + * + * @param modelPropertiesList Initial entries keyed by model path. They are copied into this + * manager; a null argument leaves the manager empty. + */ + public NeuralNetworkPropertyManager(HashMap modelPropertiesList) { + // Copy rather than alias so later changes to the caller's map do not leak in. + if (modelPropertiesList != null) { + modelPathToProperties.putAll(modelPropertiesList); + } + } + + /** Returns the backing map's string form wrapped in "NeuralNetworkProperties [...]". */ + @Override + public String toString() { + String toReturn = ""; + + toReturn += "NeuralNetworkProperties ["; + + toReturn += modelPathToProperties.toString() + "]"; + + return toReturn; + } + + /** + * Add a model's properties, keyed by its model path. An existing entry for the same path is + * replaced. + * + * @param modelProperties the properties to store + */ + public void addModelProperties(ModelProperties modelProperties) { + modelPathToProperties.put(modelProperties.modelPath, modelProperties); + } + + /** + * Add two Neural Network Properties together. + * + *

Entries from {@code nnProps} overwrite entries in this manager that share the same model path. + * This manager is modified in place. + * + * @param nnProps the manager whose entries are merged into this one + * @return itself, so it can be chained and used fluently + */ + public NeuralNetworkPropertyManager sum(NeuralNetworkPropertyManager nnProps) { + modelPathToProperties.putAll(nnProps.modelPathToProperties); + + return this; + } + + /** + * Remove a model from the stored models. + * + * @param modelPath the path the model is stored under + * @return True if the model was removed, false if it was not found + */ + public boolean removeModel(Path modelPath) { + return modelPathToProperties.remove(modelPath) != null; + } + + /** + * Get the model properties for a given model path. + * + * @param modelPath the path the model is stored under + * @return the {@link ModelProperties} stored under that path, or null if there is none + */ + public ModelProperties getModel(Path modelPath) { + return modelPathToProperties.get(modelPath); + } + + /** + * Get all models + * + * @return An array of all stored {@link ModelProperties} + */ + public ModelProperties[] getModels() { + return modelPathToProperties.values().toArray(new ModelProperties[0]); + } + + /** + * Change the nickname of a {@link ModelProperties} object.
+ * + * @param modelPath + * @param newName + * @return True if the model was found and renamed, false if it was not found + */ + public boolean renameModel(Path modelPath, String newName) { + ModelProperties temp = modelPathToProperties.get(modelPath); + if (temp != null) { + modelPathToProperties.remove(modelPath); + modelPathToProperties.put( + modelPath, + new ModelProperties( + temp.modelPath, + newName, + temp.labels, + temp.resolutionWidth, + temp.resolutionHeight, + temp.family, + temp.version)); + return true; + } + return false; + } + + public boolean clear() { + modelPathToProperties.clear(); + return true; + } +} diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/PhotonConfiguration.java b/photon-core/src/main/java/org/photonvision/common/configuration/PhotonConfiguration.java index 66018ec6c..9b01cff12 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/PhotonConfiguration.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/PhotonConfiguration.java @@ -28,14 +28,22 @@ public class PhotonConfiguration { private final HardwareSettings hardwareSettings; private NetworkConfig networkConfig; private AprilTagFieldLayout atfl; + private NeuralNetworkPropertyManager neuralNetworkProperties; private HashMap cameraConfigurations; public PhotonConfiguration( HardwareConfig hardwareConfig, HardwareSettings hardwareSettings, NetworkConfig networkConfig, - AprilTagFieldLayout atfl) { - this(hardwareConfig, hardwareSettings, networkConfig, atfl, new HashMap<>()); + AprilTagFieldLayout atfl, + NeuralNetworkPropertyManager neuralNetworkProperties) { + this( + hardwareConfig, + hardwareSettings, + networkConfig, + atfl, + neuralNetworkProperties, + new HashMap<>()); } public PhotonConfiguration( @@ -43,10 +51,12 @@ public class PhotonConfiguration { HardwareSettings hardwareSettings, NetworkConfig networkConfig, AprilTagFieldLayout atfl, + NeuralNetworkPropertyManager neuralNetworkProperties, 
HashMap cameraConfigurations) { this.hardwareConfig = hardwareConfig; this.hardwareSettings = hardwareSettings; this.networkConfig = networkConfig; + this.neuralNetworkProperties = neuralNetworkProperties; this.cameraConfigurations = cameraConfigurations; this.atfl = atfl; } @@ -56,7 +66,8 @@ public class PhotonConfiguration { new HardwareConfig(), new HardwareSettings(), new NetworkConfig(), - new AprilTagFieldLayout(List.of(), 0, 0)); + new AprilTagFieldLayout(List.of(), 0, 0), + new NeuralNetworkPropertyManager()); } public HardwareConfig getHardwareConfig() { @@ -75,6 +86,10 @@ public class PhotonConfiguration { return atfl; } + public NeuralNetworkPropertyManager neuralNetworkPropertyManager() { + return neuralNetworkProperties; + } + public void setApriltagFieldLayout(AprilTagFieldLayout atfl) { this.atfl = atfl; } @@ -83,6 +98,10 @@ public class PhotonConfiguration { this.networkConfig = networkConfig; } + public void setNeuralNetworkProperties(NeuralNetworkPropertyManager neuralNetworkProperties) { + this.neuralNetworkProperties = neuralNetworkProperties; + } + public HashMap getCameraConfigurations() { return cameraConfigurations; } @@ -121,6 +140,8 @@ public class PhotonConfiguration { + networkConfig + "\n atfl=" + atfl + + "\n neuralNetworkProperties=" + + neuralNetworkProperties + "\n cameraConfigurations=" + cameraConfigurations + "\n]"; diff --git a/photon-core/src/main/java/org/photonvision/common/configuration/SqlConfigProvider.java b/photon-core/src/main/java/org/photonvision/common/configuration/SqlConfigProvider.java index bae7c4ced..b9c6c17ec 100644 --- a/photon-core/src/main/java/org/photonvision/common/configuration/SqlConfigProvider.java +++ b/photon-core/src/main/java/org/photonvision/common/configuration/SqlConfigProvider.java @@ -56,6 +56,7 @@ public class SqlConfigProvider extends ConfigProvider { static final String HARDWARE_CONFIG = "hardwareConfig"; static final String HARDWARE_SETTINGS = "hardwareSettings"; static final String 
ATFL_CONFIG_FILE = "apriltagFieldLayout"; + static final String NEURAL_NETWORK_PROPERTIES = "neuralNetworkProperties"; } private static final String dbName = "photon.sqlite"; @@ -263,6 +264,7 @@ public class SqlConfigProvider extends ConfigProvider { HardwareSettings hardwareSettings; NetworkConfig networkConfig; AprilTagFieldLayout atfl; + NeuralNetworkPropertyManager nnProps; try { hardwareConfig = @@ -310,6 +312,16 @@ public class SqlConfigProvider extends ConfigProvider { } } + try { + nnProps = + JacksonUtils.deserialize( + getOneConfigFile(conn, GlobalKeys.NEURAL_NETWORK_PROPERTIES), + NeuralNetworkPropertyManager.class); + } catch (IOException e) { + logger.error("Could not deserialize neural network properties! Loading defaults", e); + nnProps = new NeuralNetworkPropertyManager(); + } + var cams = loadCameraConfigs(conn); try { @@ -319,7 +331,8 @@ public class SqlConfigProvider extends ConfigProvider { } this.config = - new PhotonConfiguration(hardwareConfig, hardwareSettings, networkConfig, atfl, cams); + new PhotonConfiguration( + hardwareConfig, hardwareSettings, networkConfig, atfl, nnProps, cams); } } @@ -442,6 +455,7 @@ public class SqlConfigProvider extends ConfigProvider { private boolean skipSavingHWSet = false; private boolean skipSavingNWCfg = false; private boolean skipSavingAPRTG = false; + private boolean skipSavingNNProps = false; private void saveGlobal(Connection conn) { PreparedStatement statement1 = null; @@ -483,6 +497,16 @@ public class SqlConfigProvider extends ConfigProvider { statement3.close(); } + if (!skipSavingNNProps) { + statement3 = conn.prepareStatement(sqlString); + addFile( + statement3, + GlobalKeys.NEURAL_NETWORK_PROPERTIES, + JacksonUtils.serializeToString(config.neuralNetworkPropertyManager())); + statement3.executeUpdate(); + statement3.close(); + } + } catch (SQLException | IOException e) { logger.error("Err saving global", e); try { @@ -565,6 +589,12 @@ public class SqlConfigProvider extends ConfigProvider { return 
saveOneFile(GlobalKeys.ATFL_CONFIG_FILE, uploadPath); } + @Override + public boolean saveUploadedNeuralNetworkProperties(Path uploadPath) { + skipSavingNNProps = true; + return saveOneFile(GlobalKeys.NEURAL_NETWORK_PROPERTIES, uploadPath); + } + private HashMap loadCameraConfigs(Connection conn) { HashMap loadedConfigurations = new HashMap<>(); diff --git a/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIGeneralSettings.java b/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIGeneralSettings.java index 00ece4d1c..cb981e369 100644 --- a/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIGeneralSettings.java +++ b/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIGeneralSettings.java @@ -17,16 +17,15 @@ package org.photonvision.common.dataflow.websocket; -import java.util.ArrayList; import java.util.List; -import java.util.Map; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager; public class UIGeneralSettings { public UIGeneralSettings( String version, String gpuAcceleration, boolean mrCalWorking, - Map> availableModels, + NeuralNetworkPropertyManager.ModelProperties[] availableModels, List supportedBackends, String hardwareModel, String hardwarePlatform) { @@ -42,7 +41,7 @@ public class UIGeneralSettings { public String version; public String gpuAcceleration; public boolean mrCalWorking; - public Map> availableModels; + public NeuralNetworkPropertyManager.ModelProperties[] availableModels; public List supportedBackends; public String hardwareModel; public String hardwarePlatform; diff --git a/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIPhotonConfiguration.java b/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIPhotonConfiguration.java index 4d5758bf7..971d853a3 100644 --- a/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIPhotonConfiguration.java +++ 
b/photon-core/src/main/java/org/photonvision/common/dataflow/websocket/UIPhotonConfiguration.java @@ -54,7 +54,7 @@ public class UIPhotonConfiguration { // TODO add support for other types of GPU accel LibCameraJNILoader.isSupported() ? "Zerocopy Libcamera Working" : "", MrCalJNILoader.getInstance().isLoaded(), - NeuralNetworkModelManager.getInstance().getModels(), + c.neuralNetworkPropertyManager().getModels(), NeuralNetworkModelManager.getInstance().getSupportedBackends(), c.getHardwareConfig().deviceName().isEmpty() ? Platform.getHardwareModel() diff --git a/photon-core/src/main/java/org/photonvision/common/util/file/JacksonUtils.java b/photon-core/src/main/java/org/photonvision/common/util/file/JacksonUtils.java index 40d8e6a4f..d97906031 100644 --- a/photon-core/src/main/java/org/photonvision/common/util/file/JacksonUtils.java +++ b/photon-core/src/main/java/org/photonvision/common/util/file/JacksonUtils.java @@ -17,9 +17,15 @@ package org.photonvision.common.util.file; +import com.fasterxml.jackson.core.JsonGenerator; +import com.fasterxml.jackson.core.JsonParser; import com.fasterxml.jackson.core.json.JsonReadFeature; +import com.fasterxml.jackson.databind.DeserializationContext; import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.fasterxml.jackson.databind.JsonSerializer; import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializerProvider; import com.fasterxml.jackson.databind.deser.std.StdDeserializer; import com.fasterxml.jackson.databind.json.JsonMapper; import com.fasterxml.jackson.databind.jsontype.BasicPolymorphicTypeValidator; @@ -31,6 +37,7 @@ import java.io.FileDescriptor; import java.io.FileOutputStream; import java.io.IOException; import java.nio.file.Path; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; import org.eclipse.jetty.io.EofException; @@ -38,41 +45,89 @@ import 
org.eclipse.jetty.io.EofException; public class JacksonUtils { public static class UIMap extends HashMap {} + // Custom Path serializer that outputs just the path string without file:/ prefix + public static class PathSerializer extends JsonSerializer { + @Override + public void serialize(Path value, JsonGenerator gen, SerializerProvider serializers) + throws IOException { + if (value == null) { + gen.writeNull(); + } else { + gen.writeString(value.toString()); + } + } + } + + // Custom Path deserializer that reads path strings + public static class PathDeserializer extends JsonDeserializer { + @Override + public Path deserialize(JsonParser p, DeserializationContext ctxt) throws IOException { + String pathString = p.getValueAsString(); + if (pathString == null || pathString.isEmpty()) { + return null; + } + + // Handle case where old serialized data might still have file:/ prefix + if (pathString.startsWith("file:/")) { + pathString = pathString.substring(6); // Remove "file:/" prefix + } + + return Paths.get(pathString); + } + } + + // Custom Path key deserializer for Maps with Path keys + public static class PathKeyDeserializer extends com.fasterxml.jackson.databind.KeyDeserializer { + @Override + public Object deserializeKey(String key, DeserializationContext ctxt) throws IOException { + if (key == null || key.isEmpty()) { + return null; + } + + // Handle case where old serialized data might still have file:/ prefix + if (key.startsWith("file:/")) { + key = key.substring(6); // Remove "file:/" prefix + } + + return Paths.get(key); + } + } + + // Helper method to create ObjectMapper with Path serialization support + private static ObjectMapper createObjectMapperWithPathSupport(Class baseType) { + PolymorphicTypeValidator ptv = + BasicPolymorphicTypeValidator.builder().allowIfBaseType(baseType).build(); + + SimpleModule pathModule = new SimpleModule(); + pathModule.addSerializer(Path.class, new PathSerializer()); + pathModule.addDeserializer(Path.class, new 
PathDeserializer()); + pathModule.addKeyDeserializer(Path.class, new PathKeyDeserializer()); + + return JsonMapper.builder() + .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, true) + .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) + .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) + .addModule(pathModule) + .build(); + } + public static void serialize(Path path, T object) throws IOException { serialize(path, object, true); } public static String serializeToString(T object) throws IOException { - PolymorphicTypeValidator ptv = - BasicPolymorphicTypeValidator.builder().allowIfBaseType(object.getClass()).build(); - ObjectMapper objectMapper = - JsonMapper.builder() - .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) - .build(); + ObjectMapper objectMapper = createObjectMapperWithPathSupport(object.getClass()); return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(object); } public static void serialize(Path path, T object, boolean forceSync) throws IOException { - PolymorphicTypeValidator ptv = - BasicPolymorphicTypeValidator.builder().allowIfBaseType(object.getClass()).build(); - ObjectMapper objectMapper = - JsonMapper.builder() - .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) - .build(); + ObjectMapper objectMapper = createObjectMapperWithPathSupport(object.getClass()); String json = objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(object); saveJsonString(json, path, forceSync); } public static T deserialize(Map s, Class ref) throws IOException { - PolymorphicTypeValidator ptv = - BasicPolymorphicTypeValidator.builder().allowIfBaseType(ref).build(); - ObjectMapper objectMapper = - JsonMapper.builder() - .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, true) - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) - .build(); - + ObjectMapper 
objectMapper = createObjectMapperWithPathSupport(ref); return objectMapper.convertValue(s, ref); } @@ -81,28 +136,14 @@ public class JacksonUtils { throw new EofException("Provided empty string for class " + ref.getName()); } - PolymorphicTypeValidator ptv = - BasicPolymorphicTypeValidator.builder().allowIfBaseType(ref).build(); - ObjectMapper objectMapper = - JsonMapper.builder() - .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, true) - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - .enable(DeserializationFeature.READ_UNKNOWN_ENUM_VALUES_AS_NULL) - .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) - .build(); + ObjectMapper objectMapper = createObjectMapperWithPathSupport(ref); + objectMapper.enable(DeserializationFeature.READ_UNKNOWN_ENUM_VALUES_AS_NULL); return objectMapper.readValue(s, ref); } public static T deserialize(Path path, Class ref) throws IOException { - PolymorphicTypeValidator ptv = - BasicPolymorphicTypeValidator.builder().allowIfBaseType(ref).build(); - ObjectMapper objectMapper = - JsonMapper.builder() - .configure(JsonReadFeature.ALLOW_JAVA_COMMENTS, true) - .configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false) - .activateDefaultTyping(ptv, ObjectMapper.DefaultTyping.JAVA_LANG_OBJECT) - .build(); + ObjectMapper objectMapper = createObjectMapperWithPathSupport(ref); File jsonFile = new File(path.toString()); if (jsonFile.exists() && jsonFile.length() > 0) { return objectMapper.readValue(jsonFile, ref); @@ -115,6 +156,12 @@ public class JacksonUtils { ObjectMapper objectMapper = new ObjectMapper(); SimpleModule module = new SimpleModule(); module.addDeserializer(ref, deserializer); + + // Add Path support to custom deserializer case as well + module.addSerializer(Path.class, new PathSerializer()); + module.addDeserializer(Path.class, new PathDeserializer()); + module.addKeyDeserializer(Path.class, new PathKeyDeserializer()); + objectMapper.registerModule(module); File jsonFile = new 
File(path.toString()); @@ -135,6 +182,12 @@ public class JacksonUtils { ObjectMapper objectMapper = new ObjectMapper(); SimpleModule module = new SimpleModule(); module.addSerializer(ref, serializer); + + // Add Path support to custom serializer case as well + module.addSerializer(Path.class, new PathSerializer()); + module.addDeserializer(Path.class, new PathDeserializer()); + module.addKeyDeserializer(Path.class, new PathKeyDeserializer()); + objectMapper.registerModule(module); String json = objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(object); saveJsonString(json, path, forceSync); diff --git a/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java b/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java index e4bb3d52c..2e05f13d1 100644 --- a/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java +++ b/photon-core/src/main/java/org/photonvision/jni/RknnObjectDetector.java @@ -68,7 +68,11 @@ public class RknnObjectDetector implements ObjectDetector { // Create the detector objPointer = - RknnJNI.create(model.modelFile.getPath(), model.labels.size(), model.version.ordinal(), -1); + RknnJNI.create( + model.modelFile.getPath(), + model.properties.labels().size(), + model.properties.version().ordinal(), + -1); if (objPointer <= 0) { throw new RuntimeException( "Failed to create detector from path " + model.modelFile.getPath()); @@ -87,7 +91,7 @@ public class RknnObjectDetector implements ObjectDetector { */ @Override public List getClasses() { - return model.labels; + return model.properties.labels(); } /** diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/Model.java b/photon-core/src/main/java/org/photonvision/vision/objects/Model.java index a634898fd..cb60bb712 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/Model.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/Model.java @@ -17,8 +17,17 @@ package org.photonvision.vision.objects; 
+import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; + public interface Model { public ObjectDetector load(); - public String getName(); + public String getUID(); + + public String getNickname(); + + public Family getFamily(); + + public ModelProperties getProperties(); } diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/NullModel.java b/photon-core/src/main/java/org/photonvision/vision/objects/NullModel.java index e3ecf42ef..7238ecd69 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/NullModel.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/NullModel.java @@ -19,6 +19,8 @@ package org.photonvision.vision.objects; import java.util.List; import org.opencv.core.Mat; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; import org.photonvision.vision.pipe.impl.NeuralNetworkPipeResult; /** @@ -41,10 +43,25 @@ public class NullModel implements Model, ObjectDetector { } @Override - public String getName() { + public String getUID() { return "NullModel"; } + @Override + public String getNickname() { + return "NullModel"; + } + + @Override + public Family getFamily() { + return null; + } + + @Override + public ModelProperties getProperties() { + return null; + } + @Override public void release() { // Do nothing diff --git a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java index 038c6695a..98bda55d6 100644 --- a/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java +++ b/photon-core/src/main/java/org/photonvision/vision/objects/RknnModel.java @@ -18,77 +18,68 @@ package org.photonvision.vision.objects; import java.io.File; -import java.io.IOException; -import 
java.nio.file.Files; -import java.nio.file.Paths; -import java.util.List; import org.opencv.core.Size; -import org.photonvision.common.configuration.NeuralNetworkModelManager; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Family; +import org.photonvision.common.configuration.NeuralNetworkModelManager.Version; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; import org.photonvision.jni.RknnObjectDetector; -import org.photonvision.rknn.RknnJNI; public class RknnModel implements Model { public final File modelFile; - public final RknnJNI.ModelVersion version; - public final List labels; - public final Size inputSize; - - /** - * Determines the model version based on the model's filename. - * - *

"yolov5" -> "YOLO_V5" - * - *

"yolov8" -> "YOLO_V8" - * - *

"yolov11" -> "YOLO_V11" - * - * @param modelName The model's filename - * @return The model version - */ - private static RknnJNI.ModelVersion getModelVersion(String modelName) - throws IllegalArgumentException { - if (modelName.contains("yolov5")) { - return RknnJNI.ModelVersion.YOLO_V5; - } else if (modelName.contains("yolov8")) { - return RknnJNI.ModelVersion.YOLO_V8; - } else if (modelName.contains("yolov11")) { - return RknnJNI.ModelVersion.YOLO_V11; - } else { - throw new IllegalArgumentException("Unknown model version for model " + modelName); - } - } + public final ModelProperties properties; /** * rknn model constructor. * - * @param modelFile path to model on disk. Format: `name-width-height-model.rknn` - * @param labels path to labels file on disk + * @param properties The properties of the model. * @throws IllegalArgumentException */ - public RknnModel(File modelFile, String labels) throws IllegalArgumentException, IOException { - this.modelFile = modelFile; - - // parseRKNNName throws an IllegalArgumentException if the model name is invalid - String[] parts = NeuralNetworkModelManager.parseRKNNName(modelFile.getName()); - - this.version = getModelVersion(parts[3]); - - int width = Integer.parseInt(parts[1]); - int height = Integer.parseInt(parts[2]); - this.inputSize = new Size(width, height); - - try { - this.labels = Files.readAllLines(Paths.get(labels)); - } catch (IOException e) { - throw new IllegalArgumentException("Failed to read labels file " + labels, e); + public RknnModel(ModelProperties properties) throws IllegalArgumentException { + modelFile = new File(properties.modelPath().toString()); + if (!modelFile.exists()) { + throw new IllegalArgumentException("Model file does not exist: " + modelFile); } + + if (properties.labels() == null || properties.labels().isEmpty()) { + throw new IllegalArgumentException("Labels must be provided"); + } + + if (properties.resolutionWidth() <= 0 || properties.resolutionHeight() <= 0) { + throw new 
IllegalArgumentException("Resolution must be greater than 0"); + } + + if (properties.family() != Family.RKNN) { + throw new IllegalArgumentException("Model family must be RKNN"); + } + + if (properties.version() != Version.YOLOV5 + && properties.version() != Version.YOLOV8 + && properties.version() != Version.YOLOV11) { + throw new IllegalArgumentException("Model version must be YOLOV5, YOLOV8, or YOLOV11"); + } + + this.properties = properties; } - public String getName() { - return modelFile.getName(); + /** Return the unique identifier for the model. In this case, it's the model's path. */ + public String getUID() { + return properties.modelPath().toString(); + } + + public String getNickname() { + return properties.nickname(); + } + + public Family getFamily() { + return properties.family(); + } + + public ModelProperties getProperties() { + return properties; } public ObjectDetector load() { - return new RknnObjectDetector(this, inputSize); + return new RknnObjectDetector( + this, new Size(properties.resolutionWidth(), properties.resolutionHeight())); } } diff --git a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java index 572e08d6b..b62869ec1 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipeline.java @@ -56,7 +56,10 @@ public class ObjectDetectionPipeline @Override protected void setPipeParamsImpl() { Optional selectedModel = - NeuralNetworkModelManager.getInstance().getModel(settings.model); + settings.model != null + ? 
NeuralNetworkModelManager.getInstance() + .getModel(settings.model.modelPath().toString()) + : Optional.empty(); // If the desired model couldn't be found, log an error and try to use the default model if (selectedModel.isEmpty()) { diff --git a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java index ba1fcaf64..67428559a 100644 --- a/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java +++ b/photon-core/src/main/java/org/photonvision/vision/pipeline/ObjectDetectionPipelineSettings.java @@ -18,12 +18,13 @@ package org.photonvision.vision.pipeline; import org.photonvision.common.configuration.NeuralNetworkModelManager; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager; import org.photonvision.vision.objects.Model; public class ObjectDetectionPipelineSettings extends AdvancedPipelineSettings { public double confidence; public double nms; // non maximal suppression - public String model; + public NeuralNetworkPropertyManager.ModelProperties model; public ObjectDetectionPipelineSettings() { super(); @@ -35,6 +36,9 @@ public class ObjectDetectionPipelineSettings extends AdvancedPipelineSettings { confidence = .9; nms = .45; model = - NeuralNetworkModelManager.getInstance().getDefaultModel().map(Model::getName).orElse(""); + NeuralNetworkModelManager.getInstance() + .getDefaultModel() + .map(Model::getProperties) + .orElse(null); } } diff --git a/photon-core/src/main/java/org/photonvision/vision/processes/VisionModuleChangeSubscriber.java b/photon-core/src/main/java/org/photonvision/vision/processes/VisionModuleChangeSubscriber.java index 7a5da5331..1b40864ff 100644 --- a/photon-core/src/main/java/org/photonvision/vision/processes/VisionModuleChangeSubscriber.java +++ b/photon-core/src/main/java/org/photonvision/vision/processes/VisionModuleChangeSubscriber.java 
@@ -17,12 +17,15 @@ package org.photonvision.vision.processes; +import com.fasterxml.jackson.databind.ObjectMapper; import edu.wpi.first.math.Pair; import java.util.ArrayList; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.concurrent.locks.ReentrantLock; import org.opencv.core.Point; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; import org.photonvision.common.dataflow.DataChangeSubscriber; import org.photonvision.common.dataflow.events.DataChangeEvent; import org.photonvision.common.dataflow.events.IncomingWebSocketEvent; @@ -55,7 +58,8 @@ public class VisionModuleChangeSubscriber extends DataChangeSubscriber { @Override public void onDataChangeEvent(DataChangeEvent event) { - // Camera index -1 means a "multicast event" (i.e. the event is received by all cameras) + // Camera index -1 means a "multicast event" (i.e. the event is received by all + // cameras) if (event instanceof IncomingWebSocketEvent wsEvent && wsEvent.cameraUniqueName != null && wsEvent.cameraUniqueName.equals(parentModule.uniqueName())) { @@ -289,6 +293,11 @@ public class VisionModuleChangeSubscriber extends DataChangeSubscriber { } else { propField.setBoolean(currentSettings, (Boolean) newPropValue); } + } else if (propField.getType() == ModelProperties.class + && newPropValue instanceof LinkedHashMap) { + ObjectMapper mapper = new ObjectMapper(); + ModelProperties modelProps = mapper.convertValue(newPropValue, ModelProperties.class); + propField.set(currentSettings, modelProps); } else { propField.set(currentSettings, newPropValue); } diff --git a/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java b/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java deleted file mode 100644 index c389e87d6..000000000 --- a/photon-core/src/test/java/org/photonvision/vision/pipeline/ObjectDetectionTest.java +++ /dev/null @@ -1,134 +0,0 @@ -/* - * 
Copyright (C) Photon Vision. - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU General Public License for more details. - * - * You should have received a copy of the GNU General Public License - * along with this program. If not, see . - */ - -package org.photonvision.vision.pipeline; - -import static org.junit.jupiter.api.Assertions.assertArrayEquals; -import static org.junit.jupiter.api.Assertions.assertThrows; - -import java.util.LinkedList; -import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; -import org.photonvision.common.configuration.NeuralNetworkModelManager; - -public class ObjectDetectionTest { - private static LinkedList passNames = - new LinkedList( - java.util.Arrays.asList( - new String[] {"note-640-640-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"}, - new String[] {"object-640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"}, - new String[] { - "example_1.2-640-640-yolov5l.rknn", "example_1.2-640-640-yolov5l-labels.txt" - }, - new String[] {"demo_3.5-640-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"}, - new String[] {"sample-640-640-yolov5x.rknn", "sample-640-640-yolov5x-labels.txt"}, - new String[] { - "test_case-640-640-yolov8s.rknn", "test_case-640-640-yolov8s-labels.txt" - }, - new String[] { - "model_ABC-640-640-yolov5n.rknn", "model_ABC-640-640-yolov5n-labels.txt" - }, - new String[] {"my_model-640-640-yolov8x.rknn", "my_model-640-640-yolov8x-labels.txt"}, - new 
String[] {"name_1.0-640-640-yolov5n.rknn", "name_1.0-640-640-yolov5n-labels.txt"}, - new String[] { - "valid_name-640-640-yolov8s.rknn", "valid_name-640-640-yolov8s-labels.txt" - }, - new String[] { - "test.model-640-640-yolov5l.rknn", "test.model-640-640-yolov5l-labels.txt" - }, - new String[] { - "case1_test-640-640-yolov8m.rknn", "case1_test-640-640-yolov8m-labels.txt" - }, - new String[] {"A123-640-640-yolov5x.rknn", "A123-640-640-yolov5x-labels.txt"}, - new String[] { - "z_y_test.model-640-640-yolov8n.rknn", "z_y_test.model-640-640-yolov8n-labels.txt" - })); - private static LinkedList parsedPassNames = - new LinkedList( - java.util.Arrays.asList( - new String[] {"note", "640", "640", "yolov5s"}, - new String[] {"object", "640", "640", "yolov8n"}, - new String[] {"example_1.2", "640", "640", "yolov5l"}, - new String[] {"demo_3.5", "640", "640", "yolov8m"}, - new String[] {"sample", "640", "640", "yolov5x"}, - new String[] {"test_case", "640", "640", "yolov8s"}, - new String[] {"model_ABC", "640", "640", "yolov5n"}, - new String[] {"my_model", "640", "640", "yolov8x"}, - new String[] {"name_1.0", "640", "640", "yolov5n"}, - new String[] {"valid_name", "640", "640", "yolov8s"}, - new String[] {"test.model", "640", "640", "yolov5l"}, - new String[] {"case1_test", "640", "640", "yolov8m"}, - new String[] {"A123", "640", "640", "yolov5x"}, - new String[] {"z_y_test.model", "640", "640", "yolov8n"})); - private static LinkedList failNames = - new LinkedList( - java.util.Arrays.asList( - new String[] {"note-yolov5s.rknn", "note-640-640-yolov5s-labels.txt"}, - new String[] {"640-640-yolov8n.rknn", "object-640-640-yolov8n-labels.txt"}, - new String[] {"example_1.2.rknn", "example_1.2-640-640-yolov5l-labels.txt"}, - new String[] {"demo_3.5-640-yolov8m.rknn", "demo_3.5-640-640-yolov8m-labels.txt"}, - new String[] {"sample-640.rknn", "sample-640-640-yolov5x-labels.txt"}, - new String[] {"test_case.txt", "test_case-640-640-yolov8s-labels.txt"}, - new String[] 
{"model_ABC.onnx", "model_ABC-640-640-yolov5n-labels.txt"}, - new String[] {"my_model", "my_model-640-640-yolov8x-labels.txt"}, - new String[] {"name_1.0-yolov5n.rknn", "wrong-labels.txt"}, - new String[] {"", "valid_name-640-640-yolov8s-labels.txt"}, - new String[] {null, "test.model-640-640-yolov5l-labels.txt"}, - new String[] {"case1_test-640-640-yolov8m.rknn", null}, - new String[] {"A123-640-640.rknn", "different-labels.txt"}, - new String[] {"z_y_test.model", ""})); - - // Test the model name validation for names that ought to pass - @ParameterizedTest - @MethodSource("verifyPassNameProvider") - public void testRKNNVerificationPass(String[] names) { - NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1]); - } - - // // Test the model name validation for names that ought to fail - @ParameterizedTest - @MethodSource("verifyFailNameProvider") - public void testRNNVerificationFail(String[] names) { - assertThrows( - IllegalArgumentException.class, - () -> NeuralNetworkModelManager.verifyRKNNNames(names[0], names[1])); - } - - // Test the model name parsing - @ParameterizedTest - @MethodSource("parseNameProvider") - public void testRKNNNameParsing(String[] expected, String name) { - String[] parsed = NeuralNetworkModelManager.parseRKNNName(name); - assertArrayEquals(expected, parsed); - } - - static Stream verifyPassNameProvider() { - return passNames.stream().map(array -> Arguments.of((Object) array)); - } - - static Stream verifyFailNameProvider() { - return failNames.stream().map(array -> Arguments.of((Object) array)); - } - - static Stream parseNameProvider() { - // return a stream of parsed pass names, and the first element of each pass name - return passNames.stream() - .map(name -> Arguments.of(parsedPassNames.get(passNames.indexOf(name)), name[0])); - } -} diff --git a/photon-server/src/main/java/org/photonvision/Main.java b/photon-server/src/main/java/org/photonvision/Main.java index 2e07c5270..7e527e3f8 100644 --- 
a/photon-server/src/main/java/org/photonvision/Main.java +++ b/photon-server/src/main/java/org/photonvision/Main.java @@ -269,8 +269,8 @@ public class Main { logger.info("Loading ML models..."); var modelManager = NeuralNetworkModelManager.getInstance(); - modelManager.extractModels(ConfigManager.getInstance().getModelsDirectory()); - modelManager.discoverModels(ConfigManager.getInstance().getModelsDirectory()); + modelManager.extractModels(); + modelManager.discoverModels(); logger.debug("Loading HardwareManager..."); // Force load the hardware manager diff --git a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java index 6f7bd04c0..b00331558 100644 --- a/photon-server/src/main/java/org/photonvision/server/RequestHandler.java +++ b/photon-server/src/main/java/org/photonvision/server/RequestHandler.java @@ -25,8 +25,10 @@ import java.io.*; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedList; import java.util.Optional; import javax.imageio.ImageIO; import org.apache.commons.io.FileUtils; @@ -37,6 +39,7 @@ import org.opencv.imgcodecs.Imgcodecs; import org.photonvision.common.configuration.ConfigManager; import org.photonvision.common.configuration.NetworkConfig; import org.photonvision.common.configuration.NeuralNetworkModelManager; +import org.photonvision.common.configuration.NeuralNetworkPropertyManager.ModelProperties; import org.photonvision.common.dataflow.DataChangeDestination; import org.photonvision.common.dataflow.DataChangeService; import org.photonvision.common.dataflow.events.IncomingWebSocketEvent; @@ -50,6 +53,7 @@ import org.photonvision.common.logging.Logger; import org.photonvision.common.networking.NetworkManager; import org.photonvision.common.util.ShellExec; import 
org.photonvision.common.util.TimedTaskManager; +import org.photonvision.common.util.file.JacksonUtils; import org.photonvision.common.util.file.ProgramDirectoryUtilities; import org.photonvision.vision.calibration.CameraCalibrationCoefficients; import org.photonvision.vision.camera.CameraQuirk; @@ -550,55 +554,108 @@ public class RequestHandler { public static void onImportObjectDetectionModelRequest(Context ctx) { try { // Retrieve the uploaded files - var modelFile = ctx.uploadedFile("rknn"); - var labelsFile = ctx.uploadedFile("labels"); + var modelFile = ctx.uploadedFile("modelFile"); - if (modelFile == null || labelsFile == null) { + // Strip any whitespaces on either side of the commas + LinkedList labels = new LinkedList<>(); + String rawLabels = ctx.formParam("labels"); + if (rawLabels != null) { + for (String label : rawLabels.split(",")) { + labels.add(label.trim()); + } + } + int width = Integer.parseInt(ctx.formParam("width")); + int height = Integer.parseInt(ctx.formParam("height")); + NeuralNetworkModelManager.Version version = + switch (ctx.formParam("version").toString()) { + case "YOLOv5" -> NeuralNetworkModelManager.Version.YOLOV5; + case "YOLOv8" -> NeuralNetworkModelManager.Version.YOLOV8; + case "YOLO11" -> NeuralNetworkModelManager.Version.YOLOV11; + // Add more versions as necessary for new models + default -> { + ctx.status(400); + ctx.result("The provided version was not valid"); + logger.error("The provided version was not valid"); + yield null; + } + }; + + if (modelFile == null) { ctx.status(400); ctx.result( - "No File was sent with the request. Make sure that the model and labels files are sent at the keys 'rknn' and 'labels'"); + "No File was sent with the request. Make sure that the model file is sent at the key 'modelFile'"); logger.error( - "No File was sent with the request. Make sure that the model and labels files are sent at the keys 'rknn' and 'labels'"); + "No File was sent with the request. 
Make sure that the model file is sent at the key 'modelFile'"); return; } - if (!modelFile.extension().contains("rknn") || !labelsFile.extension().contains("txt")) { + if (labels == null || labels.isEmpty()) { ctx.status(400); - ctx.result( - "The uploaded files were not of type 'rknn' and 'txt'. The uploaded files should be a .rknn and .txt file."); - logger.error( - "The uploaded files were not of type 'rknn' and 'txt'. The uploaded files should be a .rknn and .txt file."); + ctx.result("The provided labels were malformed"); + logger.error("The provided labels were malformed"); return; } - // verify naming convention + if (width < 0 || height < 0 || width != Math.floor(width) || height != Math.floor(height)) { + ctx.status(400); + ctx.result( + "The provided width and height were malformed. They must be integers greater than one."); + logger.error( + "The provided width and height were malformed. They must be integers greater than one."); + return; + } - // throws IllegalArgumentException if the model name is invalid - NeuralNetworkModelManager.verifyRKNNNames(modelFile.filename(), labelsFile.filename()); + // If adding additional platforms, check platform matches + if (!modelFile.extension().contains("rknn")) { + ctx.status(400); + ctx.result( + "The uploaded file was not of type 'rknn'. The uploaded file should be a .rknn file."); + logger.error( + "The uploaded file was not of type 'rknn'. The uploaded file should be a .rknn file."); + return; + } - // TODO move into neural network manager - - var modelPath = + Path modelPath = Paths.get( ConfigManager.getInstance().getModelsDirectory().toString(), modelFile.filename()); - var labelsPath = - Paths.get( - ConfigManager.getInstance().getModelsDirectory().toString(), labelsFile.filename()); + + if (modelPath.toFile().exists()) { + ctx.status(400); + ctx.result( + "The model file already exists. Please delete the existing model file before uploading a new one."); + logger.error( + "The model file already exists. 
Please delete the existing model file before uploading a new one."); + return; + } try (FileOutputStream out = new FileOutputStream(modelPath.toFile())) { modelFile.content().transferTo(out); } - try (FileOutputStream out = new FileOutputStream(labelsPath.toFile())) { - labelsFile.content().transferTo(out); - } + ConfigManager.getInstance() + .getConfig() + .neuralNetworkPropertyManager() + .addModelProperties( + new ModelProperties( + modelPath, + modelFile.filename().replaceAll(".rknn", ""), + labels, + width, + height, + NeuralNetworkModelManager.Family.RKNN, // This can be determined by platform if + // additional platforms are + // supported + version)); - NeuralNetworkModelManager.getInstance() - .discoverModels(ConfigManager.getInstance().getModelsDirectory()); + logger.debug( + ConfigManager.getInstance().getConfig().neuralNetworkPropertyManager().toString()); + + NeuralNetworkModelManager.getInstance().discoverModels(); ctx.status(200).result("Successfully uploaded object detection model"); } catch (Exception e) { ctx.status(500).result("Error processing files: " + e.getMessage()); + logger.error("Error processing new object detection model", e); } DataChangeService.getInstance() @@ -608,6 +665,258 @@ public class RequestHandler { UIPhotonConfiguration.programStateToUi(ConfigManager.getInstance().getConfig()))); } + public static void onExportObjectDetectionModelsRequest(Context ctx) { + logger.info("Exporting Object Detection Models to ZIP Archive"); + + try { + var zip = ConfigManager.getInstance().getObjectDetectionExportAsZip(); + var stream = new FileInputStream(zip); + logger.info("Uploading object detection models with size " + stream.available()); + + ctx.contentType("application/zip"); + ctx.header( + "Content-Disposition", + "attachment; filename=\"photonvision-object-detection-models-export.zip\""); + + ctx.result(stream); + ctx.status(200); + } catch (IOException e) { + logger.error("Unable to export object detection models archive, bad recode 
from zip to byte"); + ctx.status(500); + ctx.result("There was an error while exporting the object detection models archive"); + } + } + + public static void onExportIndividualObjectDetectionModelRequest(Context ctx) { + logger.info("Exporting Individual Object Detection Model"); + + try { + String modelPath = ctx.queryParam("modelPath"); + + if (modelPath == null || modelPath.isEmpty()) { + ctx.status(400); + ctx.result("The provided model path was malformed"); + logger.error("The provided model path was malformed"); + return; + } + + File modelFile = NeuralNetworkModelManager.getInstance().exportSingleModel(modelPath); + + var stream = new FileInputStream(modelFile); + logger.info("Uploading object detection model with size " + stream.available()); + + ctx.contentType("application/octet-stream"); + ctx.header("Content-Disposition", "attachment; filename=" + modelFile.getName()); + + ctx.result(stream); + ctx.status(200); + } catch (IOException e) { + logger.error("Unable to export object detection model, " + e); + ctx.status(500); + ctx.result("There was an error while exporting the object detection model"); + } + } + + public static void onBulkImportObjectDetectionModelRequest(Context ctx) { + var file = ctx.uploadedFile("data"); + + if (file == null) { + ctx.status(400); + ctx.result( + "No File was sent with the request. Make sure that the object detection zip is sent at the key 'data'"); + logger.error( + "No File was sent with the request. Make sure that the object detection zip file is sent at the key 'data'"); + return; + } + + if (!file.extension().contains("zip")) { + ctx.status(400); + ctx.result( + "The uploaded file was not of type 'zip'. The uploaded file should be a .zip file."); + logger.error( + "The uploaded file was not of type 'zip'. 
The uploaded file should be a .zip file."); + return; + } + + // Create a temp file + var tempFilePath = handleTempFileCreation(file); + + if (tempFilePath.isEmpty()) { + ctx.status(500); + ctx.result("There was an error while creating a temporary copy of the file"); + logger.error("There was an error while creating a temporary copy of the file"); + return; + } + + Path tempDir = null; + // Extract .rknn files from zip and move to models directory + try { + tempDir = Files.createTempDirectory("photonvision-od-models"); + ZipUtil.unpack(tempFilePath.get(), tempDir.toFile()); + + Path targetModelsDir = ConfigManager.getInstance().getModelsDirectory().toPath(); + + // Copy all files from the source models directory to the target models + // directory + try (var stream = Files.list(tempDir)) { + for (Path modelFile : stream.toList()) { + if (Files.isRegularFile(modelFile) + && !modelFile.getFileName().toString().endsWith(".json")) { + logger.debug("Copying model file: " + modelFile.getFileName()); + Files.copy( + modelFile, + Path.of(targetModelsDir.toString(), modelFile.getFileName().toString()), + StandardCopyOption.REPLACE_EXISTING); + } + } + } + logger.info("Successfully copied models from " + tempDir + " to " + targetModelsDir); + + } catch (Exception e) { + ctx.status(500); + ctx.result("There was an error while extracting and coyping the object detection models"); + logger.error( + "There was an error while extracting and copying the object detection models", e); + return; + } + + if (ConfigManager.getInstance() + .saveUploadedNeuralNetworkProperties( + Path.of(tempDir.toString(), "photonvision-object-detection-models.json"))) { + ctx.status(200); + ctx.result("Successfully saved the uploaded object detection models, rebooting..."); + logger.info("Successfully saved the uploaded object detection models, rebooting..."); + restartProgram(); + } else { + ctx.status(500); + ctx.result("There was an error while saving the uploaded object detection models"); + 
logger.error("There was an error while saving the uploaded object detection models"); + } + } + + private record DeleteObjectDetectionModelRequest(String modelPath) {} + + public static void onDeleteObjectDetectionModelRequest(Context ctx) { + logger.info("Deleting object detection model"); + Path modelPath; + + try { + DeleteObjectDetectionModelRequest request = + JacksonUtils.deserialize(ctx.body(), DeleteObjectDetectionModelRequest.class); + + modelPath = Path.of(request.modelPath.substring(7)); + + if (modelPath == null) { + ctx.status(400); + ctx.result("The provided model path was malformed"); + logger.error("The provided model path was malformed"); + return; + } + + if (!modelPath.toFile().exists()) { + ctx.status(400); + ctx.result("The provided model path does not exist"); + logger.error("The provided model path does not exist"); + return; + } + + if (!modelPath.toFile().delete()) { + ctx.status(500); + ctx.result("Unable to delete the model file"); + logger.error("Unable to delete the model file"); + return; + } + + if (!ConfigManager.getInstance() + .getConfig() + .neuralNetworkPropertyManager() + .removeModel(modelPath)) { + ctx.status(400); + ctx.result("The model's information was not found in the config"); + logger.error("The model's information was not found in the config"); + return; + } + + NeuralNetworkModelManager.getInstance().discoverModels(); + + ctx.status(200).result("Successfully deleted object detection model"); + + } catch (Exception e) { + ctx.status(500); + ctx.result("Error deleting object detection model: " + e.getMessage()); + logger.error("Error deleting object detection model", e); + } + + DataChangeService.getInstance() + .publishEvent( + new OutgoingUIEvent<>( + "fullsettings", + UIPhotonConfiguration.programStateToUi(ConfigManager.getInstance().getConfig()))); + } + + private record RenameObjectDetectionModelRequest(String modelPath, String newName) {} + + public static void onRenameObjectDetectionModelRequest(Context ctx) { + 
try { + RenameObjectDetectionModelRequest request = + JacksonUtils.deserialize(ctx.body(), RenameObjectDetectionModelRequest.class); + + Path modelPath = Path.of(request.modelPath); + + if (modelPath == null) { + ctx.status(400); + ctx.result("The provided model path was malformed"); + logger.error("The provided model path was malformed"); + return; + } + + if (!modelPath.toFile().exists()) { + ctx.status(400); + ctx.result("The provided model path does not exist"); + logger.error("The model path: " + modelPath + " does not exist"); + return; + } + + if (request.newName == null || request.newName.isEmpty()) { + ctx.status(400); + ctx.result("The provided new name was malformed"); + logger.error("The provided new name was malformed"); + return; + } + + if (!ConfigManager.getInstance() + .getConfig() + .neuralNetworkPropertyManager() + .renameModel(modelPath, request.newName)) { + ctx.status(400); + ctx.result("The model's information was not found in the config"); + logger.error("The model's information was not found in the config"); + return; + } + + NeuralNetworkModelManager.getInstance().discoverModels(); + ctx.status(200).result("Successfully renamed object detection model"); + } catch (Exception e) { + ctx.status(500); + ctx.result("Error renaming object detection model: " + e.getMessage()); + logger.error("Error renaming object detection model", e); + return; + } + } + + public static void onNukeObjectDetectionModelsRequest(Context ctx) { + logger.info("Attempting to clear object detection models"); + try { + NeuralNetworkModelManager.getInstance().clearModels(); + NeuralNetworkModelManager.getInstance().extractModels(); + ctx.status(200).result("Successfully cleared and reset object detection models"); + } catch (Exception e) { + ctx.status(500); + ctx.result("Error clearing object detection models: " + e.getMessage()); + logger.error("Error clearing object detection models", e); + } + } + public static void onDeviceRestartRequest(Context ctx) { 
ctx.status(HardwareManager.getInstance().restartDevice() ? 204 : 500); } diff --git a/photon-server/src/main/java/org/photonvision/server/Server.java b/photon-server/src/main/java/org/photonvision/server/Server.java index ef24a6da2..c50fab873 100644 --- a/photon-server/src/main/java/org/photonvision/server/Server.java +++ b/photon-server/src/main/java/org/photonvision/server/Server.java @@ -127,9 +127,6 @@ public class Server { // Utilities app.post("/api/utils/offlineUpdate", RequestHandler::onOfflineUpdateRequest); - app.post( - "/api/utils/importObjectDetectionModel", - RequestHandler::onImportObjectDetectionModelRequest); app.get("/api/utils/photonvision-journalctl.txt", RequestHandler::onLogExportRequest); app.post("/api/utils/restartProgram", RequestHandler::onProgramRestartRequest); app.post("/api/utils/restartDevice", RequestHandler::onDeviceRestartRequest); @@ -147,6 +144,18 @@ public class Server { app.post("/api/calibration/end", RequestHandler::onCalibrationEndRequest); app.post("/api/calibration/importFromData", RequestHandler::onDataCalibrationImportRequest); + // Object detection + app.post("/api/objectdetection/import", RequestHandler::onImportObjectDetectionModelRequest); + app.post( + "/api/objectdetection/bulkimport", RequestHandler::onBulkImportObjectDetectionModelRequest); + app.get("/api/objectdetection/export", RequestHandler::onExportObjectDetectionModelsRequest); + app.get( + "/api/objectdetection/exportIndividual", + RequestHandler::onExportIndividualObjectDetectionModelRequest); + app.post("/api/objectdetection/delete", RequestHandler::onDeleteObjectDetectionModelRequest); + app.post("/api/objectdetection/rename", RequestHandler::onRenameObjectDetectionModelRequest); + app.post("/api/objectdetection/nuke", RequestHandler::onNukeObjectDetectionModelsRequest); + app.start(port); } diff --git a/photon-server/src/main/resources/models/algaeV1-640-640-yolov8n-labels.txt 
b/photon-server/src/main/resources/models/algaeV1-640-640-yolov8n-labels.txt deleted file mode 100644 index 51335db7e..000000000 --- a/photon-server/src/main/resources/models/algaeV1-640-640-yolov8n-labels.txt +++ /dev/null @@ -1 +0,0 @@ -algae