[WIP] Integration with DeepLabCut 3.0 - PyTorch Engine #121

Open: wants to merge 65 commits into base: main (viewing changes from 1 commit)

65 commits:
b11743c
Remove TensorFlow references
dikraMasrour Aug 9, 2024
c8f5839
Add comments on TF questions for changing dlc-live pipeline to Pytorch
AnnaStuckert Feb 24, 2025
030c42c
Vanilla pytorch inference done; Commenting out TensorFlow references …
dikraMasrour Feb 24, 2025
d75339a
change testing directory for Anna and allow to run on CPU
AnnaStuckert Feb 24, 2025
4de5829
Fix display, clean code, adapt frame processing, onnx inference
dikraMasrour Feb 24, 2025
40a9dd4
add screenshots
n-poulsen Feb 24, 2025
2fc48f1
Fix CPU inference crash + GPU (cuda) & CPU support for Pytorch and ON…
dikraMasrour Feb 24, 2025
125a072
Video analysis feature
AnnaStuckert Feb 24, 2025
9d2350e
Improvements on benchmark_pytorch.py
AnnaStuckert Feb 24, 2025
13b8fae
Implement TensorRT optimisation on ONNX models and FP16 precision inf…
dikraMasrour Feb 24, 2025
8b41df6
Avs live feed
AnnaStuckert Feb 24, 2025
d1e7df9
Tutorial notebook in progress
dikraMasrour Feb 24, 2025
2fb39fc
bug fixing h5 saving for live video feed
AnnaStuckert Feb 24, 2025
98b3a13
add code to save numbers in csv and h5 as numbers, not tensor(number)
AnnaStuckert Feb 24, 2025
e6a914a
add timestamp suffix to videos and csv/h5 files
AnnaStuckert Feb 24, 2025
29c8122
fix live inference and display, black and isort
dikraMasrour Feb 24, 2025
9508a79
cleaning out unused files
AnnaStuckert Feb 24, 2025
d61f892
update docstrings, clean dlclive script
dikraMasrour Feb 24, 2025
f527d38
Continued DeepLabCut-Live implementation for DeepLabCut 3.0
n-poulsen Sep 18, 2024
4c07e04
working on README
n-poulsen Feb 27, 2025
82daf43
improved docs
n-poulsen Feb 27, 2025
40005ae
improved docs for PyTorch code
n-poulsen Feb 27, 2025
c6a1f69
improved readme
n-poulsen Feb 27, 2025
d738a2c
fix default top down dynamic cropping parameters
n-poulsen Feb 27, 2025
9524092
Update README.md
MMathisLab May 5, 2025
6db3901
Update .gitignore
maximpavliv May 28, 2025
b7a527e
CI/CD update python version
maximpavliv May 30, 2025
8c2f1af
CI/CD update actions versions
maximpavliv May 30, 2025
15c1364
CI/CD update trigger events
maximpavliv May 30, 2025
4444e3d
CI/CD update MacOS version
maximpavliv May 30, 2025
0fefc95
dlclibrary set version to >=0.0.6
maximpavliv Jun 4, 2025
85ad7cb
Poetry lock
maximpavliv Jun 4, 2025
c479b11
Poetry lock
maximpavliv Jun 4, 2025
acd7e3e
Pyproject.toml update tensorflow installation
maximpavliv Jun 4, 2025
d70b7e8
Poetry lock
maximpavliv Jun 4, 2025
114c55a
Install specific tensorflow-io-gcs-filesystem for windows
maximpavliv Jun 4, 2025
6ed7faf
Poetry lock
maximpavliv Jun 4, 2025
594dee8
Update deprecated section name
maximpavliv Jun 4, 2025
b7e3ae2
Poetry lock
maximpavliv Jun 4, 2025
fb3424c
Merge branch 'maxim/fix_cicd' into dlclive3
maximpavliv Jun 4, 2025
751abea
Poetry lock
maximpavliv Jun 4, 2025
d27e8e1
CI/CD install Tensorflow
maximpavliv May 30, 2025
55def2e
Fix missing DLCLive precision attribute
maximpavliv May 30, 2025
87ac706
benchmark_pytorch fix imports order resulting in crash
maximpavliv Jun 3, 2025
d8f8eb0
Correct arg name
maximpavliv Jun 4, 2025
02c1a27
Benchmark pytorch: remove snapshot argument
maximpavliv Jun 5, 2025
fb1314e
Fix incorrect bbox unpacking
maximpavliv Jun 5, 2025
a00af6c
Benchmark pytorch: fix read config
maximpavliv Jun 5, 2025
58a094a
Benchmark pytorch: rename method
maximpavliv Jun 5, 2025
77c2f43
Benchmark pytorch: fix loop
maximpavliv Jun 5, 2025
2b7c520
Benchmark pytorch - update display params
maximpavliv Jun 5, 2025
45c85a5
Formatting
maximpavliv Jun 5, 2025
5b10eb1
Benchmark pytorch: change save_dir default value
maximpavliv Jun 5, 2025
7227c98
Benchmark pytorch: extract setup_video_writer()
maximpavliv Jun 5, 2025
ad62af4
Benchmark pytorch: extract draw_pose_and_write()
maximpavliv Jun 5, 2025
014b96b
Add detector_transform and pose_transform
maximpavliv Jun 11, 2025
e03fb1d
Display: fix multi-animal frame display
maximpavliv Jun 11, 2025
045de53
Benchmark pytorch: add single_animal arg
maximpavliv Jun 11, 2025
b493bef
Benchmark pytorch: dont setup videowriter if not needed
maximpavliv Jun 11, 2025
dfe5817
Benchmark pytorch: fix save_poses_to_files()
maximpavliv Jun 12, 2025
c6ee1ba
Benchmark pytorch: docstring
maximpavliv Jun 13, 2025
9b5d8f1
Benchmark pytorch: formatting
maximpavliv Jun 13, 2025
a29415d
Benchmark pytorch: introduce n_frames and progress bar
maximpavliv Jun 13, 2025
3660dcf
Benchmark pytorch: remove try-except block
maximpavliv Jun 13, 2025
f8fb374
Delete poetry.lock
MMathisLab Jun 15, 2025
Avs live feed
AnnaStuckert authored and n-poulsen committed Feb 24, 2025
commit 8b41df62019597e28d4e27a31906f3998dd48e80
317 changes: 317 additions & 0 deletions dlclive/LiveVideoInference.py
@@ -0,0 +1,317 @@
import csv
import os
import platform
import subprocess
import sys
import time

import colorcet as cc
import cv2
import h5py
import numpy as np
import torch
from PIL import ImageColor
from pip._internal.operations import freeze

from dlclive import VERSION, DLCLive


def get_system_info() -> dict:
    """Return summary info for the system running the benchmark.

    Returns
    -------
    dict
        Dictionary containing the following system information:
        * ``host_name`` (str): name of machine
        * ``op_sys`` (str): operating system
        * ``python`` (str): path to python (which conda/virtual environment)
        * ``device`` (tuple): (device type (``'GPU'`` or ``'CPU'``), device information)
        * ``freeze`` (list): list of installed packages and versions
        * ``python_version`` (str): python version
        * ``git_hash`` (str, None): if installed from a git repository, hash of the HEAD commit
        * ``dlclive_version`` (str): dlclive version from :data:`dlclive.VERSION`
    """
    # Get OS and host name
    op_sys = platform.platform()
    host_name = platform.node().replace(" ", "")

    # Get Python executable path
    if platform.system() == "Windows":
        host_python = sys.executable.split(os.path.sep)[-2]
    else:
        host_python = sys.executable.split(os.path.sep)[-3]

    # Try to get git hash if possible
    git_hash = None
    dlc_basedir = os.path.dirname(os.path.dirname(__file__))
    try:
        git_hash = (
            subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=dlc_basedir)
            .decode("utf-8")
            .strip()
        )
    except subprocess.CalledProcessError:
        # Not installed from a git repo, e.g. installed from PyPI
        pass

    # Get device info (GPU or CPU)
    if torch.cuda.is_available():
        dev_type = "GPU"
        dev = [torch.cuda.get_device_name(torch.cuda.current_device())]
    else:
        from cpuinfo import get_cpu_info

        dev_type = "CPU"
        dev = [get_cpu_info()["brand_raw"]]

    return {
        "host_name": host_name,
        "op_sys": op_sys,
        "python": host_python,
        "device_type": dev_type,
        "device": dev,
        "freeze": list(freeze.freeze()),
        "python_version": sys.version,
        "git_hash": git_hash,
        "dlclive_version": VERSION,
    }


def analyze_live_video(
    model_path: str,
    model_type: str,
    device: str,
    camera: int = 0,
    experiment_name: str = "Test",
    precision: str = "FP32",
    snapshot: str = None,
    display=True,
    pcutoff=0.5,
    display_radius=5,
    resize=None,
    cropping=None,  # Cropping is passed through to DLCLive
    dynamic=(False, 0.5, 10),
    save_poses=False,
    save_dir="model_predictions",
    draw_keypoint_names=False,
    cmap="bmy",
    get_sys_info=True,
):
    """
    Analyze a live video feed to track keypoints using an imported DeepLabCut
    model, visualize the keypoints on the video, and optionally save the
    keypoint data and the labelled video.

    Parameters
    ----------
    model_path : str
        Path to the exported DeepLabCut model.
    model_type : str
        Which model runtime to use (e.g. PyTorch or ONNX).
    device : str
        Device to run inference on (e.g. ``"cpu"`` or ``"cuda"``).
    camera : int, optional, default=0
        Index of the camera to record the live video from (0 is the default webcam).
    experiment_name : str, optional, default="Test"
        Prefix used to label the generated pose and video files.
    precision : str, optional, default="FP32"
        Inference precision.
    snapshot : str, optional, default=None
        Model snapshot to load.
    display : bool, optional, default=True
        Whether to show the labelled frames in a window while recording.
    pcutoff : float, optional, default=0.5
        The probability cutoff value below which keypoints are not visualized.
    display_radius : int, optional, default=5
        The radius of the circles drawn to represent keypoints on the video frames.
    resize : tuple of int (width, height) or None, optional, default=None
        The size to which the frames should be resized. If None, the frames are not resized.
    cropping : list of int, optional, default=None
        Cropping parameters in pixel coordinates: [x1, x2, y1, y2].
    dynamic : tuple, optional, default=(False, 0.5, 10)
        Dynamic cropping parameters (enabled, pcutoff, margin).
    save_poses : bool, optional, default=False
        Whether to save the detected poses to CSV and HDF5 files.
    save_dir : str, optional, default="model_predictions"
        The directory where the output video and pose data will be saved.
    draw_keypoint_names : bool, optional, default=False
        Whether to draw the names of the keypoints on the video frames.
    cmap : str, optional, default="bmy"
        The colormap from the colorcet library to use for keypoint visualization.
    get_sys_info : bool, optional, default=True
        Whether to print system information after the run.

    Returns
    -------
    poses : list of dict
        A list of dictionaries, each containing a frame number and the corresponding pose data.
    times : list of float
        Per-frame processing times in seconds.
    """
    # Create the DLCLive object with cropping
    dlc_live = DLCLive(
        path=model_path,
        model_type=model_type,
        device=device,
        display=display,
        resize=resize,
        cropping=cropping,  # Pass the cropping parameter
        dynamic=dynamic,
        precision=precision,
        snapshot=snapshot,
    )

    # Ensure save directory exists
    os.makedirs(name=save_dir, exist_ok=True)

    # Open the camera
    cap = cv2.VideoCapture(camera)
    if not cap.isOpened():
        print(f"Error: Could not open camera {camera}")
        return

    # Collect poses and per-frame processing times
    poses, times = [], []
    frame_index = 0

    # Retrieve bodypart names and number of keypoints
    bodyparts = dlc_live.cfg["metadata"]["bodyparts"]
    num_keypoints = len(bodyparts)

    # Set colors and convert to RGB
    cmap_colors = getattr(cc, cmap)
    colors = [
        ImageColor.getrgb(color)
        for color in cmap_colors[:: int(len(cmap_colors) / num_keypoints)]
    ]

    # Define output video path
    output_video_path = os.path.join(
        save_dir, f"{experiment_name}_DLCLIVE_LABELLED.mp4"
    )

    # Get video writer setup
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    fps = cap.get(cv2.CAP_PROP_FPS)
    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    vwriter = cv2.VideoWriter(
        filename=output_video_path,
        fourcc=fourcc,
        fps=fps,
        frameSize=(frame_width, frame_height),
    )

    while True:
        start_time = time.time()

        ret, frame = cap.read()
        if not ret:
            break

        try:
            if frame_index == 0:
                pose = dlc_live.init_inference(frame)  # load DLC model
            else:
                pose = dlc_live.get_pose(frame)
        except Exception as e:
            print(f"Error analyzing frame {frame_index}: {e}")
            frame_index += 1  # keep frame numbering aligned with the capture
            continue

        end_time = time.time()
        processing_time = end_time - start_time
        times.append(processing_time)
        print(f"Frame {frame_index} processing time: {processing_time:.4f} seconds")

        poses.append({"frame": frame_index, "pose": pose})

        # Visualize keypoints
        this_pose = pose[0]["poses"][0][0]
        for j in range(this_pose.shape[0]):
            if this_pose[j, 2] > pcutoff:
                x, y = map(int, this_pose[j, :2])
                cv2.circle(
                    frame,
                    center=(x, y),
                    radius=display_radius,
                    color=colors[j],
                    thickness=-1,
                )

                if draw_keypoint_names:
                    cv2.putText(
                        frame,
                        text=bodyparts[j],
                        org=(x + 10, y),
                        fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                        fontScale=0.5,
                        color=colors[j],
                        thickness=1,
                        lineType=cv2.LINE_AA,
                    )

        vwriter.write(image=frame)
        frame_index += 1

        # Display the frame
        if display:
            cv2.imshow("DLCLive", frame)

            # Add key press check for quitting
            if cv2.waitKey(1) & 0xFF == ord("q"):
                break

    cap.release()
    vwriter.release()
    cv2.destroyAllWindows()

    if get_sys_info:
        print(get_system_info())

    if save_poses:
        save_poses_to_files(experiment_name, save_dir, bodyparts, poses)

    return poses, times


def save_poses_to_files(experiment_name, save_dir, bodyparts, poses):
    """
    Save the keypoint poses detected in the video to CSV and HDF5 files.

    Parameters
    ----------
    experiment_name : str
        Name of the experiment, used as a prefix for saving files.
    save_dir : str
        The directory where the pose data files will be saved.
    bodyparts : list of str
        A list of body part names corresponding to the keypoints.
    poses : list of dict
        A list of dictionaries, each containing a frame number and the corresponding pose data.

    Returns
    -------
    None
    """
    base_filename = os.path.splitext(os.path.basename(experiment_name))[0]
    csv_save_path = os.path.join(save_dir, f"{base_filename}_poses.csv")
    h5_save_path = os.path.join(save_dir, f"{base_filename}_poses.h5")

    # Save to CSV
    with open(csv_save_path, mode="w", newline="") as file:
        writer = csv.writer(file)
        header = ["frame"] + [
            f"{bp}_{axis}" for bp in bodyparts for axis in ["x", "y", "confidence"]
        ]
        writer.writerow(header)
        for entry in poses:
            frame_num = entry["frame"]
            pose_data = entry["pose"][0]["poses"][0][0]
            # Convert tensor data to numeric values
            row = [frame_num] + [
                item.item() if isinstance(item, torch.Tensor) else item
                for kp in pose_data
                for item in kp
            ]
            writer.writerow(row)

    # Save to HDF5
    with h5py.File(h5_save_path, "w") as hf:
        hf.create_dataset(name="frames", data=[entry["frame"] for entry in poses])
        for i, bp in enumerate(bodyparts):
            hf.create_dataset(
                name=f"{bp}_x",
                data=[
                    (
                        entry["pose"][0]["poses"][0][0][i, 0].item()
                        if isinstance(
                            entry["pose"][0]["poses"][0][0][i, 0], torch.Tensor
                        )
                        else entry["pose"][0]["poses"][0][0][i, 0]
                    )
                    for entry in poses
                ],
            )
            hf.create_dataset(
6 changes: 4 additions & 2 deletions dlclive/benchmark_pytorch.py
@@ -123,7 +123,8 @@ def analyze_video(
     model_path: str,
     model_type: str,
     device: str,
-    precision:str,
+    precision: str = "FP32",
+    snapshot: str = None,
     display=True,
     pcutoff=0.5,
     display_radius=5,
@@ -176,7 +177,8 @@ def analyze_video(
         resize=resize,
         cropping=cropping,  # Pass the cropping parameter
         dynamic=dynamic,
-        precision=precision
+        precision=precision,
+        snapshot=snapshot
     )

# Ensure save directory exists
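The hunks above give the new `precision` and `snapshot` parameters defaults, so existing callers of `analyze_video` keep working without changes. A minimal sketch of that pattern, using a hypothetical `analyze` stand-in rather than the real `analyze_video` signature:

```python
def analyze(model_path, model_type, device, precision="FP32", snapshot=None):
    """Options added after the fact get defaults, so call sites written
    before the options existed remain valid unchanged."""
    return {"path": model_path, "precision": precision, "snapshot": snapshot}


# A pre-existing call site: still valid, silently picks up the defaults.
old_call = analyze("exported-model/", "pytorch", "cpu")

# A new call site opting in to the added options.
new_call = analyze(
    "exported-model/", "pytorch", "cuda", precision="FP16", snapshot="snapshot-050.pt"
)
```

The path and snapshot names here are placeholders; the point is only that defaulted keyword arguments make the signature change backward compatible.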
2 changes: 1 addition & 1 deletion dlclive/dlclive.py
@@ -367,7 +367,7 @@ def get_pose(self, frame=None, **kwargs):
         with torch.no_grad():
             start = time.time()
             outputs = self.pose_model(frame)
-            torch.cuda.synchronize()
+            # torch.cuda.synchronize()
             end = time.time()
             inf_time = end - start
             print(f"PyTorch inference took {inf_time} sec")
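Commenting out `torch.cuda.synchronize()` in `dlclive.py` avoids the crash on CPU-only machines, but CUDA kernels launch asynchronously, so on GPU the timing then measures only the launch rather than the inference itself. One way to keep timings honest on both devices is an optional sync callback; this pure-Python helper is a sketch of that idea, an assumption rather than anything in this PR:

```python
import time


def timed(fn, *args, sync=None, **kwargs):
    """Time fn(*args, **kwargs), returning (result, seconds).

    If sync is given (e.g. torch.cuda.synchronize when running on CUDA),
    it is called before starting and before stopping the clock so that
    pending asynchronous work is included in the measurement.
    """
    if sync is not None:
        sync()  # drain any queued device work before starting the clock
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    if sync is not None:
        sync()  # wait for launched kernels to actually finish
    return result, time.perf_counter() - start


# CPU usage: no sync callback needed.
out, seconds = timed(sum, [1, 2, 3])
```

On CUDA this could be called as `timed(self.pose_model, frame, sync=torch.cuda.synchronize)`, falling back to `sync=None` on CPU, instead of toggling the synchronize call by commenting it in and out.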