mirror of
https://github.com/PhotonVision/photonvision
synced 2026-06-23 01:21:40 +00:00
RKNN conversion tool (#2024)
This commit is contained in:
174
scripts/rknn-convert-tool/create_onnx.py
Normal file
174
scripts/rknn-convert-tool/create_onnx.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import argparse
|
||||
import os.path
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
# This will work for all models that don't use anchors (e.g. all YOLO models except YOLOv5/v7)
|
||||
# This includes YOLOv5u
|
||||
yolo_non_anchor_repo = "https://github.com/airockchip/ultralytics_yolo11"
|
||||
|
||||
# For original YOLOv5 models
|
||||
yolov5_repo = "https://github.com/airockchip/yolov5"
|
||||
|
||||
valid_yolo_versions = ["yolov5", "yolov8", "yolov11"]
|
||||
comma_sep_yolo_versions = ", ".join(valid_yolo_versions)
|
||||
|
||||
ultralytics_folder_name_yolov5 = "airockchip_yolo_pkg_yolov5"
|
||||
ultralytics_default_folder_name = "airockchip_yolo_pkg"
|
||||
|
||||
|
||||
bad_model_msg = """
|
||||
This is usually due to passing in the wrong model version.
|
||||
Please make sure you have the right model version and try again.
|
||||
"""
|
||||
|
||||
|
||||
# idk how else to make Google Colab display this nicely
|
||||
class IncorrectModelError(Exception):
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
def print_bad_model_msg(cause):
|
||||
print(f"{cause}{bad_model_msg}")
|
||||
|
||||
|
||||
def check_git_installed():
|
||||
try:
|
||||
subprocess.run(["git", "--version"]).check_returncode()
|
||||
except:
|
||||
print("Git is not installed or not found in your PATH.")
|
||||
print("Please install Git from https://git-scm.com/downloads and try again.")
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def check_or_clone_rockchip_repo(repo_url, repo_name=ultralytics_default_folder_name):
|
||||
if os.path.exists(repo_name):
|
||||
print(
|
||||
f'Existing Rockchip repo "{repo_name}" detected, skipping installation...'
|
||||
)
|
||||
else:
|
||||
print(f'Cloning Rockchip repo to "{repo_name}"')
|
||||
try:
|
||||
subprocess.run(["git", "clone", repo_url, repo_name]).check_returncode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Failed to clone Rockchip repo, see error output below")
|
||||
print(e.output)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def run_pip_install_or_else_exit(args):
|
||||
print("Running pip install...")
|
||||
|
||||
try:
|
||||
subprocess.run(["pip", "install"] + args).check_returncode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Pip install rockchip repo failed, see error output")
|
||||
print(e.output)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def run_onnx_conversion_yolov5(model_path):
|
||||
check_or_clone_rockchip_repo(yolov5_repo, ultralytics_folder_name_yolov5)
|
||||
run_pip_install_or_else_exit(
|
||||
[
|
||||
"-r",
|
||||
os.path.join(ultralytics_folder_name_yolov5, "requirements.txt"),
|
||||
"torch<2.6.0",
|
||||
"onnx",
|
||||
]
|
||||
)
|
||||
|
||||
model_abspath = os.path.abspath(model_path)
|
||||
|
||||
try:
|
||||
subprocess.run(
|
||||
[
|
||||
"python",
|
||||
f"{ultralytics_folder_name_yolov5}/export.py",
|
||||
"--weights",
|
||||
model_abspath,
|
||||
"--rknpu",
|
||||
"--include",
|
||||
"onnx",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
).check_returncode()
|
||||
except subprocess.CalledProcessError as e:
|
||||
print("Failed to run YOLOv5 export, see output below")
|
||||
output_string = (e.stdout or "") + (e.stderr or "")
|
||||
print(output_string)
|
||||
|
||||
if "ModuleNotFoundError" in output_string and "ultralytics" in output_string:
|
||||
print_bad_model_msg(
|
||||
"It seems the YOLOv5 repo could not find an ultralytics installation."
|
||||
)
|
||||
elif (
|
||||
"AttributeError" in output_string
|
||||
and "_register_detect_seperate" in output_string
|
||||
):
|
||||
print_bad_model_msg("It seems that you received a model attribute error.")
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def run_onnx_conversion_no_anchor(model_path):
|
||||
check_or_clone_rockchip_repo(yolo_non_anchor_repo)
|
||||
run_pip_install_or_else_exit(["-e", ultralytics_default_folder_name, "onnx"])
|
||||
|
||||
sys.path.insert(0, os.path.abspath(ultralytics_default_folder_name))
|
||||
model_abs_path = os.path.abspath(model_path)
|
||||
|
||||
from ultralytics import YOLO
|
||||
|
||||
try:
|
||||
model = YOLO(model_abs_path)
|
||||
model.export(format="rknn")
|
||||
except TypeError as e:
|
||||
if "originally trained" in str(e):
|
||||
print_bad_model_msg(
|
||||
"Ultralytics has detected that this model is a YOLOv5 model."
|
||||
)
|
||||
else:
|
||||
print(e)
|
||||
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate valid ONNX file for yolo model"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--model_path",
|
||||
required=True,
|
||||
help=(f"Path to YOLO model"),
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
"--version",
|
||||
required=True,
|
||||
choices=valid_yolo_versions,
|
||||
help=(f"Model version, must be one of: {comma_sep_yolo_versions}"),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
check_git_installed()
|
||||
|
||||
try:
|
||||
if args.version.lower() == "yolov5":
|
||||
run_onnx_conversion_yolov5(args.model_path)
|
||||
else:
|
||||
run_onnx_conversion_no_anchor(args.model_path)
|
||||
|
||||
print(
|
||||
"Model export finished. Please use the generated ONNX file to convert to RKNN."
|
||||
)
|
||||
except SystemExit:
|
||||
print("Model export failed. Please see output above.")
|
||||
Reference in New Issue
Block a user