diff --git a/sixdrepnet/regressor.py b/sixdrepnet/regressor.py index 0c8f619..0db0a6f 100644 --- a/sixdrepnet/regressor.py +++ b/sixdrepnet/regressor.py @@ -78,6 +78,64 @@ def predict(self, img): r = euler[:, 2].cpu().detach().numpy() return p,y,r + + + @torch.no_grad() + def predict_batch(self, images): + """ + Predicts head pose for a batch of face images and returns Euler angles. + + Parameters + ---------- + imgs : list of numpy.ndarray + List of face crops (BGR images) to be predicted. Each element should be a face image. + + Returns + ------- + pitchs : numpy.ndarray + Predicted pitch angles for each image in the batch. + + yaws : numpy.ndarray + Predicted yaw angles for each image in the batch. + + rolls : numpy.ndarray + Predicted roll angles for each image in the batch. + """ + batch = [] + + for img in images: + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + img = Image.fromarray(img) + + img = self.transformations(img) + + batch.append(img) + + # Shape => [B, C, H, W] + batch_tensor = torch.stack(batch) + + if self.gpu != -1: + batch_tensor = batch_tensor.cuda( + self.gpu, + non_blocking=True + ) + + pred = self.model(batch_tensor) + + euler = ( + utils.compute_euler_angles_from_rotation_matrices(pred) + * 180 / np.pi + ) + + euler = euler.cpu().numpy() + + pitch = euler[:, 0].astype(np.float32) + yaw = euler[:, 1].astype(np.float32) + roll = euler[:, 2].astype(np.float32) + + return pitch, yaw, roll def draw_axis(self, img, yaw, pitch, roll, tdx=None, tdy=None, size = 100): diff --git a/test_batch.py b/test_batch.py new file mode 100644 index 0000000..e90c306 --- /dev/null +++ b/test_batch.py @@ -0,0 +1,82 @@ +from sixdrepnet import SixDRepNet +import cv2 + +# ===================================================== +# Create model +# ===================================================== + +model = SixDRepNet() + +# ===================================================== +# Open input video +# ===================================================== + +video_path = "/path/to/video.mp4" + +cap = cv2.VideoCapture(video_path) + +fps = cap.get(cv2.CAP_PROP_FPS) + +width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) +height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + +# ===================================================== +# Read all frames +# ===================================================== + +frames = [] + +while True: + + ret, frame = cap.read() + + if not ret: + break + + frames.append(frame) + +cap.release() + +print(f"Loaded {len(frames)} frames") + +# ===================================================== +# Batch prediction +# ===================================================== + +pitchs, yaws, rolls = model.predict_batch(frames) + +print("Inference done") + +# ===================================================== +# Create output video +# ===================================================== + +output_path = "output.mp4" + +fourcc = cv2.VideoWriter_fourcc(*'mp4v') + +writer = cv2.VideoWriter( + output_path, + fourcc, + fps, + (width, height) +) + +# ===================================================== +# Draw predictions +# ===================================================== + +for i, frame in enumerate(frames): + + model.draw_axis( + frame, + yaws[i], + pitchs[i], + rolls[i] + ) + + writer.write(frame) + +writer.release() + +print(f"Saved result video to: {output_path}") \ No newline at end of file diff --git a/test.py b/test_image.py similarity index 100% rename from test.py rename to test_image.py