参考文档:https://github.com/ultralytics/ultralytics/blob/main/README.zh-CN.md
代码仓库的名字不再沿用 yolovx 而是使用 ultralytics, 而这个名字正是创建该项目的公司的名字,之所以如此,一方面是该公司想要创建一个 CV 的通用仓库,使其能够支持大部分的 CV 任务,如 物体检测与跟踪、实例分割、图像分类和姿态估计等,以区分于之前只是检测任务使用的 yolovx。第二方面,估计有想提升公司知名度的意味,故命名为自己公司的名字。
一,基本核心代码
from ultralytics import YOLO model = YOLO("xxxx.pt")# 或者(用于训练): model = YOLO("yolov8x.yaml") image = "xxx.jpg"# 或者图片文件夹 model.predict(image, save=True)# 返回image的预测结果 # 训练:model.train(data="数据集路径.yaml", epochs=200, batch=16) # 训练数据集类型看yolov8训练流程
二,代码扩展案例
1,训练
from ultralytics import YOLO model = YOLO("./weights/yolov8n.pt") data = "./dataset/car/mydata.yaml" model.train(data=data, epochs=100, batch=1)
2,预测
2.1单图
from ultralytics import YOLO save_path = './' image_path = './dataset/fire_smoke/000010.jpg' model = YOLO('./weights/best.pt') # 单图预测 results = model.predict(image_path) for r in result[0]: if r.boxes.cls.item()==0.0: print('有火') elif r.boxes.cls.item()==1.0: print('有烟')
2.2多图文件夹
from ultralytics import YOLO from pathlib import Path save_path = './' images_path = './dataset/fire_smoke/images' model = YOLO('./weights/best.pt') for path in Path(images_path).glob('*.*'): results = model.predict(str(path)) for result results: for r in result: if r.boxes.cls.item()==0.0: print('有火') elif r.boxes.cls.item()==1.0: print('有烟')
2.3 图片路径文件
from ultralytics import YOLO from pathlib import Path save_path = './' txt_path = './xxx.txt'# txt内容是图片路径 model = YOLO('./weights/best.pt') # 图集推理 with open(txt_path, 'r', encoding='utf-8') as f: lines = f.readlines() for path in lines: results = model.predict(path[:-1]) for r in result[0]: if r.boxes.cls.item()==0.0: print('有火') elif r.boxes.cls.item()==1.0: print('有烟')
2.4网络摄像头
from ultralytics import YOLO import cv2 save_path = './' video_url = 0# 网络摄像头路径,0表示本机摄像头 model = YOLO('./weights/best.pt') cap = cv2.VideoCapture(video_url) while True: ret, frame = cap.read() results = model.predict(frame, save=True) img = cv2.imread('./predict/image0.jpg') cv2.imshow('img', img) if cv2.waitKey(1) == ord('q'): break for r in results[0]: # if r.boxes.conf.item()>5.0:# 置信度阈值 if r.boxes.cls.item()==0.0: print('有火') elif r.boxes.cls.item()==1.0: print('有烟')