diff --git a/cog.yaml b/cog.yaml new file mode 100644 index 0000000..0f3f3b7 --- /dev/null +++ b/cog.yaml @@ -0,0 +1,25 @@ +# Configuration for Cog ⚙️ +# Reference: https://cog.run/yaml + +build: + # set to true if your model requires a GPU + gpu: True + + # a list of ubuntu apt packages to install + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + + # python version in the form '3.11' or '3.11.4' + python_version: "3.11" + + # a list of packages in the format <package-name>==<version> + python_packages: + - "solovision" + - "git+https://github.com/GetSoloTech/solo-ultralytics" + + # commands run after the environment is setup + #run: + +# predict.py defines how predictions are run on your model +predict: "predict.py:Predictor" diff --git a/predict.py b/predict.py new file mode 100644 index 0000000..eae09d4 --- /dev/null +++ b/predict.py @@ -0,0 +1,158 @@ +# Prediction interface for Cog ⚙️ +# https://cog.run/python + +import cv2 +import os +import torch +import tempfile +import numpy as np + +from PIL import Image +from typing import Iterator, Any, List +from ultralytics import YOLO +from solovision import ByteTracker +from cog import BasePredictor, Input, Path + + +class Predictor(BasePredictor): + MODEL_WEIGHTS = "checkpoints/yolov8l.pt" + REID_WEIGHTS = "checkpoints/osnet_x1_0_msmt17.pt" + + def setup(self) -> None: + """Load the model into memory to make running multiple predictions efficient""" + print("Starting model setup...") + self.device = "cuda:0" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu" + print(f"Using device: {self.device}") + self.model = YOLO(self.MODEL_WEIGHTS) + + # Ensures output directory exists + os.makedirs("outputs", exist_ok=True) + + def plot_detections(self, image: np.ndarray, tracks) -> np.ndarray: + input_is_pil = isinstance(image, Image.Image) + line_width = max(round(sum(image.size if input_is_pil else image.shape) / 2 * 0.001), 2) + font_th = max(line_width - 1, 2) + font_scale = line_width / 2 + label_padding = 7 
+ color = (37, 4, 11) # BGR format + + for track in tracks: + bbox = track[:4].astype(int) + track_id = int(track[4]) + # Convert bbox to integers + x_min, y_min, x_max, y_max = map(int, bbox) + + # Draw bounding box + cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, thickness=font_th, lineType=cv2.LINE_AA) + + # Prepare label with confidence score and ID + label = f"{track_id}" + label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_th) + label_x = max(x_min, 0) + label_y = max(y_min - label_size[1] - label_padding, 0) + + # Draw label background with padding + cv2.rectangle( + image, + (label_x, label_y), + (label_x + label_size[0] + 2 * label_padding, label_y + label_size[1] + 2 * label_padding), + color, + thickness=-1, + lineType=cv2.LINE_AA, + ) + + # Draw label text + cv2.putText( + image, + label, + (label_x + label_padding, label_y + label_size[1] + label_padding), + cv2.FONT_HERSHEY_SIMPLEX, + font_scale, + (255, 255, 255), + thickness=font_th, + lineType=cv2.LINE_AA, + ) + + return image + + def predict( + self, + video: Path = Input(description="Video Input"), + conf: float = Input( + description="Confidence Threshold", default=0.2 + ), + iou: float = Input( + description="NMS Threshold", default=0.75 + ), + match_thresh: float = Input( + description="Id Matching Threshold", default=0.8 + ), + track_buffer: int = Input( + description="Lost Tracks are held up to this time(300 Frames) before getting deleted", default=300 + ), + with_reid: bool = Input( + description="Use REID for feature matching", default=True + ), + appearance_thresh: float = Input( + description="REID Appearance Matching Distance(Works when with_reid = True)", default=0.35 + ) + ) -> Path: + + self.tracker = ByteTracker( + with_reid=with_reid, + reid_weights=Path(self.REID_WEIGHTS), + device=self.device, + half=False, + track_buffer=track_buffer, + match_thresh=match_thresh, + appearance_thresh=appearance_thresh, + ) + + 
output_dir = Path("outputs") + output_dir.mkdir(exist_ok=True) + temp_dir = tempfile.mkdtemp() + + """Run inference on the model""" + print(f"- Confidence threshold: {conf}") + print(f"- IOU threshold: {iou}") + + params = { + 'source': str(video), + 'conf': conf, + 'iou': iou, + 'stream': True, + 'classes': 0 + } + # Get video properties for output + cap = cv2.VideoCapture(str(video)) + width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + cap.release() + + # model predictions + results = self.model.predict(**params) + frame_idx = 0 + print("[*] Processing frames...") + for result in results: + frame_idx += 1 + dets = result.boxes.data.cpu().numpy() + frame = result.orig_img + tracks = self.tracker.update(dets, frame) + + # Draw tracks on frame if they exist + if len(tracks) > 0: + frame = self.plot_detections(frame, tracks) + + frame_path = os.path.join(temp_dir, f"frame_{frame_idx:05d}.png") + cv2.imwrite(frame_path, frame) + + print("[*] Generating outputs...") + # Generate output video + video_name = os.path.basename(str(video)) + output_path = output_dir/ "output_video.mp4" + frames_pattern = os.path.join(temp_dir, "frame_%05d.png") + ffmpeg_cmd = f"ffmpeg -y -r {fps} -i {frames_pattern} -c:v libx264 -pix_fmt yuv420p {output_path}" + os.system(ffmpeg_cmd) + + return output_path \ No newline at end of file diff --git a/solovision/configs/bytetrack.yaml b/solovision/configs/bytetrack.yaml index 20a7505..a2ef361 100644 --- a/solovision/configs/bytetrack.yaml +++ b/solovision/configs/bytetrack.yaml @@ -18,6 +18,11 @@ track_buffer: default: 300 # from the default parameters range: [20, 81] +min_hits: + type: randint + default: 6 # from the default parameters + range: [1, 30] + match_thresh: type: uniform default: 0.8 # from the default parameters diff --git a/solovision/track.py b/solovision/track.py deleted file mode 100644 index 7f9f77c..0000000 --- a/solovision/track.py +++ 
/dev/null @@ -1,98 +0,0 @@ -import argparse -import cv2 -import numpy as np -from functools import partial -from pathlib import Path - -import torch -from solovision.tracker_zoo import create_tracker -from solovision.utils import ROOT, WEIGHTS, TRACKER_CONFIGS -from solovision.utils.checks import RequirementsChecker - -from solovision.detectors import get_yolo_inferer -from ultralytics import YOLO - - -checker = RequirementsChecker() -checker.check_packages(('ultralytics @ git+https://github.com/zeeshaan28/solo-ultralytics.git', )) # install - - -def on_predict_start(predictor, persist=False): - """ - Initialize trackers for object tracking during prediction. - - Args: - predictor (object): The predictor object to initialize trackers for. - persist (bool, optional): Whether to persist the trackers if they already exist. Defaults to False. - """ - - tracking_config = TRACKER_CONFIGS / 'bytetrack.yaml' - trackers = [] - for i in range(predictor.dataset.bs): - tracker = create_tracker( - tracking_config, - predictor.custom_args.reid_model, - predictor.device, - predictor.custom_args.half, - predictor.custom_args.per_class - ) - # motion only modeles do not have - if hasattr(tracker, 'model'): - tracker.model.warmup() - trackers.append(tracker) - - predictor.trackers = trackers - - -@torch.no_grad() -def run(args): - - ul_models = ['yolov8', 'yolov9', 'yolov10', 'yolo11', 'rtdetr', 'sam'] - - yolo = YOLO( - args.yolo_model if any(yolo in str(args.yolo_model) for yolo in ul_models) else 'yolov8n.pt', - ) - - results = yolo.track( - source=args.source, - conf=args.conf, - iou=args.iou, - agnostic_nms=args.agnostic_nms, - show=True, - stream=True, - save_crop= args.save_crops, - device=args.device, - show_conf=args.show_conf, - save_txt=args.save_txt, - show_labels=args.show_labels, - save=args.save, - verbose=args.verbose, - exist_ok=args.exist_ok, - project=args.project, - name=args.name, - classes=args.classes, - imgsz=args.imgsz, - vid_stride=args.vid_stride, - 
line_width=args.line_width - ) - - yolo.add_callback('on_predict_start', partial(on_predict_start, persist=True)) - - if not any(yolo in str(args.yolo_model) for yolo in ul_models): - # replace yolov8 model - m = get_yolo_inferer(args.yolo_model) - model = m( - model=args.yolo_model, - device=yolo.predictor.device, - args=yolo.predictor.args - ) - yolo.predictor.model = model - - # store custom args in predictor - yolo.predictor.custom_args = args - - for r in results: - if args.show is True: - if cv2.waitKey(1) & 0xFF in (ord(' '), ord('q')): - break - diff --git a/solovision/trackers/basetracker.py b/solovision/trackers/basetracker.py index cde8cd1..a4571b3 100644 --- a/solovision/trackers/basetracker.py +++ b/solovision/trackers/basetracker.py @@ -1,12 +1,9 @@ import numpy as np import cv2 as cv -import hashlib -import colorsys from abc import ABC, abstractmethod from solovision.utils import logger as LOGGER from solovision.utils.iou import AssociationFunction - class BaseTracker(ABC): def __init__( self, @@ -154,5 +151,3 @@ def check_inputs(self, dets, img): assert ( dets.shape[1] == 6 ), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6" - - \ No newline at end of file diff --git a/solovision/trackers/bytetrack/bytetracker.py b/solovision/trackers/bytetrack/bytetracker.py index b6d714f..f2617a8 100644 --- a/solovision/trackers/bytetrack/bytetracker.py +++ b/solovision/trackers/bytetrack/bytetracker.py @@ -1,6 +1,7 @@ import torch import numpy as np from pathlib import Path +from typing import Optional from solovision.motion.kalman_filters.xywh_kf import KalmanFilterXYWH from solovision.appearance.reid_auto_backend import ReidAutoBackend @@ -14,7 +15,6 @@ from solovision.motion.cmc import get_cmc_method - class ByteTracker(BaseTracker): """ ByteTracker: State-of-the-art Multi Object Tracking Algorithm. @@ -35,17 +35,17 @@ class ByteTracker(BaseTracker): fuse_first_associate (bool, optional): Fuse appearance and motion in the first association step. 
with_reid (bool, optional): Use ReID features for association. """ - def __init__( self, - reid_weights: Path, + reid_weights: Optional[Path], device: torch.device, half: bool, per_class: bool = False, track_high_thresh: float = 0.5, track_low_thresh: float = 0.1, - new_track_thresh: float = 0.6, + new_track_thresh: float = 0.5, track_buffer: int = 30, + min_hits: int = 6, match_thresh: float = 0.8, proximity_thresh: float = 0.5, appearance_thresh: float = 0.25, @@ -64,6 +64,7 @@ def __init__( self.track_low_thresh = track_low_thresh self.new_track_thresh = new_track_thresh self.match_thresh = match_thresh + self.min_hits = min_hits self.buffer_size = int(frame_rate / 30.0 * track_buffer) self.max_time_lost = self.buffer_size @@ -182,7 +183,7 @@ def _first_association(self, dets, dets_first, active_tracks, unconfirmed, img, track = strack_pool[itracked] det = detections[idet] if track.state == TrackState.Tracked: - track.update(detections[idet], self.frame_count) + track.update(detections[idet], self.frame_count, self.min_hits) activated_stracks.append(track) else: track.re_activate(det, self.frame_count, new_id=False) @@ -209,7 +210,7 @@ def _second_association(self, dets_second, activated_stracks, lost_stracks, refi track = r_tracked_stracks[itracked] det = detections_second[idet] if track.state == TrackState.Tracked: - track.update(det, self.frame_count) + track.update(det, self.frame_count, self.min_hits) activated_stracks.append(track) else: track.re_activate(det, self.frame_count, new_id=False) @@ -258,7 +259,7 @@ def _handle_unconfirmed_tracks(self, u_detection, detections, activated_stracks, # Update matched unconfirmed tracks for itracked, idet in matches: - unconfirmed[itracked].update(detections[idet], self.frame_count) + unconfirmed[itracked].update(detections[idet], self.frame_count, self.min_hits) activated_stracks.append(unconfirmed[itracked]) # Mark unmatched unconfirmed tracks as removed diff --git a/solovision/trackers/bytetrack/strack.py 
b/solovision/trackers/bytetrack/strack.py index c5c8d66..3816a90 100644 --- a/solovision/trackers/bytetrack/strack.py +++ b/solovision/trackers/bytetrack/strack.py @@ -119,11 +119,9 @@ def re_activate(self, new_track, frame_id, new_id=False): self.update_features(new_track.curr_feat) self.tracklet_len = 0 self.state = TrackState.Tracked - self.frame_id = frame_id if new_id: self.activation_id = self.activation() - if self.activation_id is not None: self.is_activated = True self.is_activated = True @@ -132,7 +130,7 @@ def re_activate(self, new_track, frame_id, new_id=False): self.det_ind = new_track.det_ind self.update_cls(new_track.cls, new_track.conf) - def update(self, new_track, frame_id, min_hits = 6): + def update(self, new_track, frame_id, min_hits): """Update the current track with a matched detection.""" self.frame_id = frame_id self.tracklet_len += 1