Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Configuration for Cog ⚙️
# Reference: https://cog.run/yaml

build:
  # set to true if your model requires a GPU
  # (the tracker runs YOLO + optional ReID, so a GPU is expected)
  gpu: True

  # a list of ubuntu apt packages to install
  # libgl1-mesa-glx / libglib2.0-0 are runtime dependencies of opencv-python
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"

  # python version in the form '3.11' or '3.11.4'
  python_version: "3.11"

  # a list of packages in the format <package-name>==<version>
  # NOTE(review): versions are unpinned — builds are not reproducible; consider pinning
  python_packages:
    - "solovision"
    - "git+https://github.com/GetSoloTech/solo-ultralytics"

  # commands run after the environment is setup
  #run:

# predict.py defines how predictions are run on your model
predict: "predict.py:Predictor"
158 changes: 158 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Prediction interface for Cog ⚙️
# https://cog.run/python

import os
import shutil
import subprocess
import tempfile
from typing import Iterator, Any, List

import cv2
import numpy as np
import torch
from PIL import Image

from cog import BasePredictor, Input, Path
from solovision import ByteTracker
from ultralytics import YOLO


class Predictor(BasePredictor):
    """Cog predictor that runs YOLO person detection plus ByteTracker
    multi-object tracking over a video and returns an annotated MP4."""

    # Detector and (optional) re-identification weights bundled with the image.
    MODEL_WEIGHTS = "checkpoints/yolov8l.pt"
    REID_WEIGHTS = "checkpoints/osnet_x1_0_msmt17.pt"

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        print("Starting model setup...")
        # Prefer the first CUDA device when available; otherwise fall back to CPU.
        self.device = "cuda:0" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
        print(f"Using device: {self.device}")
        self.model = YOLO(self.MODEL_WEIGHTS)

        # Ensures output directory exists
        os.makedirs("outputs", exist_ok=True)

    def plot_detections(self, image: np.ndarray, tracks) -> np.ndarray:
        """Draw one labelled bounding box per track onto ``image`` (in place).

        Args:
            image: BGR frame as a numpy array (a PIL image is tolerated only
                when computing the line width — drawing assumes numpy).
            tracks: iterable of track rows where columns 0-3 hold the xyxy
                box and column 4 the integer track id.

        Returns:
            The annotated image (same object that was passed in).
        """
        input_is_pil = isinstance(image, Image.Image)
        # Scale line/font sizes with the image resolution so labels stay legible.
        line_width = max(round(sum(image.size if input_is_pil else image.shape) / 2 * 0.001), 2)
        font_th = max(line_width - 1, 2)
        font_scale = line_width / 2
        label_padding = 7
        color = (37, 4, 11)  # BGR format

        for track in tracks:
            # Columns 0-3: xyxy box; column 4: track id. Convert once —
            # the original converted the box to int twice.
            x_min, y_min, x_max, y_max = (int(v) for v in track[:4])
            track_id = int(track[4])

            # Draw bounding box
            cv2.rectangle(image, (x_min, y_min), (x_max, y_max), color, thickness=font_th, lineType=cv2.LINE_AA)

            # Label is the track id; clamp its anchor inside the frame.
            label = f"{track_id}"
            label_size, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, fontScale=font_scale, thickness=font_th)
            label_x = max(x_min, 0)
            label_y = max(y_min - label_size[1] - label_padding, 0)

            # Draw label background with padding (filled rectangle).
            cv2.rectangle(
                image,
                (label_x, label_y),
                (label_x + label_size[0] + 2 * label_padding, label_y + label_size[1] + 2 * label_padding),
                color,
                thickness=-1,
                lineType=cv2.LINE_AA,
            )

            # Draw label text in white on the filled background.
            cv2.putText(
                image,
                label,
                (label_x + label_padding, label_y + label_size[1] + label_padding),
                cv2.FONT_HERSHEY_SIMPLEX,
                font_scale,
                (255, 255, 255),
                thickness=font_th,
                lineType=cv2.LINE_AA,
            )

        return image

    def predict(
        self,
        video: Path = Input(description="Video Input"),
        conf: float = Input(
            description="Confidence Threshold", default=0.2
        ),
        iou: float = Input(
            description="NMS Threshold", default=0.75
        ),
        match_thresh: float = Input(
            description="Id Matching Threshold", default=0.8
        ),
        track_buffer: int = Input(
            description="Lost Tracks are held up to this time(300 Frames) before getting deleted", default=300
        ),
        with_reid: bool = Input(
            description="Use REID for feature matching", default=True
        ),
        appearance_thresh: float = Input(
            description="REID Appearance Matching Distance(Works when with_reid = True)", default=0.35
        )
    ) -> Path:
        """Run inference on the model.

        Detects people (class 0) in every frame, associates them across
        frames with ByteTracker, writes annotated frames to a temp dir and
        encodes them into ``outputs/output_video.mp4`` via ffmpeg.

        Raises:
            subprocess.CalledProcessError: if the ffmpeg encode fails.
        """
        # Fresh tracker per prediction so track ids never leak across requests.
        self.tracker = ByteTracker(
            with_reid=with_reid,
            reid_weights=Path(self.REID_WEIGHTS),
            device=self.device,
            half=False,
            track_buffer=track_buffer,
            match_thresh=match_thresh,
            appearance_thresh=appearance_thresh,
        )

        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        temp_dir = tempfile.mkdtemp()

        print(f"- Confidence threshold: {conf}")
        print(f"- IOU threshold: {iou}")

        params = {
            'source': str(video),
            'conf': conf,
            'iou': iou,
            'stream': True,
            'classes': 0
        }

        # Probe the input once for the output frame rate.
        # Fix: CAP_PROP_FPS is 0 on broken/absent metadata, which would have
        # produced an invalid "-r 0" ffmpeg invocation — fall back to 30 fps.
        cap = cv2.VideoCapture(str(video))
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        cap.release()

        try:
            # model predictions (stream=True yields results frame by frame)
            results = self.model.predict(**params)
            print("[*] Processing frames...")
            for frame_idx, result in enumerate(results, start=1):
                dets = result.boxes.data.cpu().numpy()
                frame = result.orig_img
                tracks = self.tracker.update(dets, frame)

                # Draw tracks on frame if they exist
                if len(tracks) > 0:
                    frame = self.plot_detections(frame, tracks)

                cv2.imwrite(os.path.join(temp_dir, f"frame_{frame_idx:05d}.png"), frame)

            print("[*] Generating outputs...")
            # Generate output video.
            # Fix: os.system on an interpolated string broke on paths containing
            # spaces/shell metacharacters and silently ignored encoder failures;
            # subprocess.run with an argument list is injection-safe and
            # check=True surfaces ffmpeg errors instead of returning a bad file.
            output_path = output_dir / "output_video.mp4"
            frames_pattern = os.path.join(temp_dir, "frame_%05d.png")
            subprocess.run(
                [
                    "ffmpeg", "-y",
                    "-r", str(fps),
                    "-i", frames_pattern,
                    "-c:v", "libx264",
                    "-pix_fmt", "yuv420p",
                    str(output_path),
                ],
                check=True,
            )
            return output_path
        finally:
            # Fix: the per-request frame dump was never removed — a disk leak
            # that grows with every prediction.
            shutil.rmtree(temp_dir, ignore_errors=True)
5 changes: 5 additions & 0 deletions solovision/configs/bytetrack.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@ track_buffer:
default: 300 # from the default parameters
range: [20, 81]

min_hits:
type: randint
default: 6 # from the default parameters
range: [1, 30]

match_thresh:
type: uniform
default: 0.8 # from the default parameters
Expand Down
98 changes: 0 additions & 98 deletions solovision/track.py

This file was deleted.

5 changes: 0 additions & 5 deletions solovision/trackers/basetracker.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import numpy as np
import cv2 as cv
import hashlib
import colorsys
from abc import ABC, abstractmethod
from solovision.utils import logger as LOGGER
from solovision.utils.iou import AssociationFunction


class BaseTracker(ABC):
def __init__(
self,
Expand Down Expand Up @@ -154,5 +151,3 @@ def check_inputs(self, dets, img):
assert (
dets.shape[1] == 6
), "Unsupported 'dets' 2nd dimension lenght, valid lenghts is 6"


15 changes: 8 additions & 7 deletions solovision/trackers/bytetrack/bytetracker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch
import numpy as np
from pathlib import Path
from typing import Optional

from solovision.motion.kalman_filters.xywh_kf import KalmanFilterXYWH
from solovision.appearance.reid_auto_backend import ReidAutoBackend
Expand All @@ -14,7 +15,6 @@
from solovision.motion.cmc import get_cmc_method



class ByteTracker(BaseTracker):
"""
ByteTracker: State-of-the-art Multi Object Tracking Algorithm.
Expand All @@ -35,17 +35,17 @@ class ByteTracker(BaseTracker):
fuse_first_associate (bool, optional): Fuse appearance and motion in the first association step.
with_reid (bool, optional): Use ReID features for association.
"""

def __init__(
self,
reid_weights: Path,
reid_weights: Optional[Path],
device: torch.device,
half: bool,
per_class: bool = False,
track_high_thresh: float = 0.5,
track_low_thresh: float = 0.1,
new_track_thresh: float = 0.6,
new_track_thresh: float = 0.5,
track_buffer: int = 30,
min_hits: int = 6,
match_thresh: float = 0.8,
proximity_thresh: float = 0.5,
appearance_thresh: float = 0.25,
Expand All @@ -64,6 +64,7 @@ def __init__(
self.track_low_thresh = track_low_thresh
self.new_track_thresh = new_track_thresh
self.match_thresh = match_thresh
self.min_hits = min_hits

self.buffer_size = int(frame_rate / 30.0 * track_buffer)
self.max_time_lost = self.buffer_size
Expand Down Expand Up @@ -182,7 +183,7 @@ def _first_association(self, dets, dets_first, active_tracks, unconfirmed, img,
track = strack_pool[itracked]
det = detections[idet]
if track.state == TrackState.Tracked:
track.update(detections[idet], self.frame_count)
track.update(detections[idet], self.frame_count, self.min_hits)
activated_stracks.append(track)
else:
track.re_activate(det, self.frame_count, new_id=False)
Expand All @@ -209,7 +210,7 @@ def _second_association(self, dets_second, activated_stracks, lost_stracks, refi
track = r_tracked_stracks[itracked]
det = detections_second[idet]
if track.state == TrackState.Tracked:
track.update(det, self.frame_count)
track.update(det, self.frame_count, self.min_hits)
activated_stracks.append(track)
else:
track.re_activate(det, self.frame_count, new_id=False)
Expand Down Expand Up @@ -258,7 +259,7 @@ def _handle_unconfirmed_tracks(self, u_detection, detections, activated_stracks,

# Update matched unconfirmed tracks
for itracked, idet in matches:
unconfirmed[itracked].update(detections[idet], self.frame_count)
unconfirmed[itracked].update(detections[idet], self.frame_count, self.min_hits)
activated_stracks.append(unconfirmed[itracked])

# Mark unmatched unconfirmed tracks as removed
Expand Down
Loading