Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions sixdrepnet/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,64 @@ def predict(self, img):
r = euler[:, 2].cpu().detach().numpy()

return p,y,r


@torch.no_grad()
def predict_batch(self, images):
"""
Predicts head pose for a batch of face images and returns Euler angles.

Parameters
----------
imgs : list of numpy.ndarray
List of face crops (BGR images) to be predicted. Each element should be a face image.

Returns
-------
pitchs : numpy.ndarray
Predicted pitch angles for each image in the batch.

yaws : numpy.ndarray
Predicted yaw angles for each image in the batch.

rolls : numpy.ndarray
Predicted roll angles for each image in the batch.
"""
batch = []

for img in images:

img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

img = Image.fromarray(img)

img = self.transformations(img)

batch.append(img)

# Shape => [B, C, H, W]
batch_tensor = torch.stack(batch)

if self.gpu != -1:
batch_tensor = batch_tensor.cuda(
self.gpu,
non_blocking=True
)

pred = self.model(batch_tensor)

euler = (
utils.compute_euler_angles_from_rotation_matrices(pred)
* 180 / np.pi
)

euler = euler.cpu().numpy()

pitch = euler[:, 0].astype(np.float32)
yaw = euler[:, 1].astype(np.float32)
roll = euler[:, 2].astype(np.float32)

return pitch, yaw, roll


def draw_axis(self, img, yaw, pitch, roll, tdx=None, tdy=None, size = 100):
Expand Down
82 changes: 82 additions & 0 deletions test_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from sixdrepnet import SixDRepNet
import cv2

# =====================================================
# Create model
# =====================================================

model = SixDRepNet()

# =====================================================
# Open input video
# =====================================================

video_path = "/path/to/video.mp4"

cap = cv2.VideoCapture(video_path)

fps = cap.get(cv2.CAP_PROP_FPS)

width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

# =====================================================
# Read all frames
# =====================================================

frames = []

while True:

ret, frame = cap.read()

if not ret:
break

frames.append(frame)

cap.release()

print(f"Loaded {len(frames)} frames")

# =====================================================
# Batch prediction
# =====================================================

pitchs, yaws, rolls = model.predict_batch(frames)

print("Inference done")

# =====================================================
# Create output video
# =====================================================

output_path = "output.mp4"

fourcc = cv2.VideoWriter_fourcc(*'mp4v')

writer = cv2.VideoWriter(
output_path,
fourcc,
fps,
(width, height)
)

# =====================================================
# Draw predictions
# =====================================================

for i, frame in enumerate(frames):

model.draw_axis(
frame,
yaws[i],
pitchs[i],
rolls[i]
)

writer.write(frame)

writer.release()

print(f"Saved result video to: {output_path}")
File renamed without changes.