Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions roboflow/cli/handlers/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,15 @@ def _infer(args): # noqa: ANN001
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
from roboflow.models.vlm import VLMModel

model_class_map = {
"object-detection": ObjectDetectionModel,
"classification": ClassificationModel,
"instance-segmentation": InstanceSegmentationModel,
"semantic-segmentation": SemanticSegmentationModel,
"keypoint-detection": KeypointDetectionModel,
"text-image-pairs": VLMModel,
}

model_cls = model_class_map.get(project_type)
Expand All @@ -97,15 +99,25 @@ def _infer(args): # noqa: ANN001
kwargs["overlap"] = int(args.overlap * 100)

try:
group = model.predict(args.file, **kwargs)
result = model.predict(args.file, **kwargs)
except Exception as exc:
output_error(args, f"Inference failed: {exc}")
return

# VLM models return raw dict response; pass through as-is.
if isinstance(result, dict):
if getattr(args, "json", False):
output(args, result)
else:
import json as _json

output(args, None, text=_json.dumps(result, indent=2))
return

# Serialize predictions for JSON output
if getattr(args, "json", False):
predictions = []
for pred in group:
for pred in result:
if hasattr(pred, "json"):
predictions.append(pred.json())
elif hasattr(pred, "__dict__"):
Expand All @@ -114,4 +126,4 @@ def _infer(args): # noqa: ANN001
predictions.append(str(pred))
output(args, predictions)
else:
output(args, None, text=str(group))
output(args, None, text=str(result))
1 change: 1 addition & 0 deletions roboflow/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def get_conditional_configuration_variable(key, default):
TYPE_INSTANCE_SEGMENTATION = "instance-segmentation"
TYPE_SEMANTIC_SEGMENTATION = "semantic-segmentation"
TYPE_KEYPOINT_DETECTION = "keypoint-detection"
TYPE_TEXT_IMAGE_PAIRS = "text-image-pairs"

TASK_DET = "det"
TASK_SEG = "seg"
Expand Down
12 changes: 12 additions & 0 deletions roboflow/core/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TYPE_KEYPOINT_DETECTION,
TYPE_OBJECT_DETECTION,
TYPE_SEMANTIC_SEGMENTATION,
TYPE_TEXT_IMAGE_PAIRS,
UNIVERSE_URL,
)
from roboflow.core.dataset import Dataset
Expand All @@ -30,6 +31,7 @@
from roboflow.models.keypoint_detection import KeypointDetectionModel
from roboflow.models.object_detection import ObjectDetectionModel
from roboflow.models.semantic_segmentation import SemanticSegmentationModel
from roboflow.models.vlm import VLMModel
from roboflow.util.annotations import amend_data_yaml
from roboflow.util.general import extract_zip, write_line
from roboflow.util.model_processor import process, validate_model_type_for_project
Expand Down Expand Up @@ -133,6 +135,16 @@ def __init__(
self.model = SemanticSegmentationModel(self.__api_key, self.id)
elif self.type == TYPE_KEYPOINT_DETECTION:
self.model = KeypointDetectionModel(self.__api_key, self.id, version=version_without_workspace)
elif self.type == TYPE_TEXT_IMAGE_PAIRS:
self.model = VLMModel(
self.__api_key,
self.id,
self.name,
version_without_workspace,
local=local,
colors=self.colors,
preprocessing=self.preprocessing,
)
else:
self.model = None

Expand Down
95 changes: 95 additions & 0 deletions roboflow/models/vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
"""Vision-language (text-image-pairs) hosted inference.

Wraps the serverless endpoint for VLM-style projects (e.g. PaliGemma).
Unlike detection/classification models, the response shape is free-form:
captions, VQA answers, OCR text, or tokenized detections depending on the
underlying model. `predict` returns the raw serverless JSON unchanged so
callers can interpret the payload for their specific model.
"""

from __future__ import annotations

import base64
import io
import os
import urllib.parse
from typing import Any, Optional

import requests
from PIL import Image

from roboflow.models.inference import InferenceModel
from roboflow.util.image_utils import check_image_url


class VLMModel(InferenceModel):
    """Run inference on a hosted text-image-pairs (VLM) model.

    Unlike detection/classification models, the response shape is free-form
    (captions, VQA answers, OCR text, or tokenized detections depending on
    the underlying model), so :meth:`predict` returns the raw serverless
    JSON unchanged for callers to interpret.
    """

    def __init__(
        self,
        api_key: str,
        id: str,
        name: Optional[str] = None,
        version: Optional[str] = None,
        local: Optional[str] = None,
        colors: Optional[dict] = None,
        preprocessing: Optional[dict] = None,
    ) -> None:
        """Initialize the hosted VLM wrapper.

        Args:
            api_key: Roboflow API key used to authorize requests.
            id: model id, e.g. ``"workspace/project"`` or
                ``"workspace/project/version"``.
            name: human-readable project name (stored for interface parity
                with the other model classes; not used for requests).
            version: model version; when omitted, the version segment of
                ``id`` (if present) is used to build the endpoint URL.
            local: base URL of a local/self-hosted inference server;
                defaults to the hosted serverless endpoint.
            colors: class color map (stored for interface parity).
            preprocessing: preprocessing config (stored for interface parity).
        """
        super().__init__(api_key, id, version=version)
        self.__api_key = api_key
        self.id = id
        self.name = name
        self.version = version
        self.base_url = local if local else "https://serverless.roboflow.com/"
        self.colors = {} if colors is None else colors
        self.preprocessing = {} if preprocessing is None else preprocessing

    def _endpoint(self) -> str:
        """Build the inference URL as ``<base>/<project>/<version>``.

        Raises:
            ValueError: if ``id`` has no project segment, or no version can
                be determined from either ``self.version`` or ``id`` (the
                previous behavior silently produced a ``.../None`` URL).
        """
        parts = self.id.split("/")
        if len(parts) < 2:
            raise ValueError(f"Cannot derive project name from model id {self.id!r}")
        without_workspace = parts[1]
        version = self.version
        if not version and len(parts) > 2:
            version = parts[2]
        if not version:
            raise ValueError(f"No model version available for id {self.id!r}")
        base = self.base_url if self.base_url.endswith("/") else self.base_url + "/"
        return f"{base}{without_workspace}/{version}"

    def predict(self, image_path: str, **kwargs: Any) -> dict:  # type: ignore[override]
        """Run inference and return the raw serverless response.

        Args:
            image_path: local path or http(s) URL to an image.
            **kwargs: extra query params forwarded to the endpoint.

        Returns:
            The raw JSON response as a dict. Shape depends on the underlying
            VLM (e.g. `{"response": {">": "..."}}` for PaliGemma).

        Raises:
            Exception: if the URL is unreachable, the local file is missing,
                or the endpoint returns a non-200 status.
        """
        is_url = urllib.parse.urlparse(image_path).scheme in ("http", "https")

        params: dict[str, Any] = {"api_key": self.__api_key}
        params.update(kwargs)

        if is_url:
            if not check_image_url(image_path):
                raise Exception(f"Image URL is not reachable: {image_path}")
            params["image"] = image_path
            url = f"{self._endpoint()}?{urllib.parse.urlencode(params)}"
            # Bound the request so a stalled server cannot hang the caller.
            resp = requests.get(url, timeout=120)
        else:
            if not os.path.exists(image_path):
                raise Exception(f"Image does not exist at {image_path}!")
            # Re-encode as JPEG so arbitrary local formats upload
            # consistently; convert("RGB") drops alpha, which JPEG
            # cannot represent.
            image = Image.open(image_path).convert("RGB")
            buffered = io.BytesIO()
            image.save(buffered, quality=90, format="JPEG")
            img_b64 = base64.b64encode(buffered.getvalue()).decode("ascii")
            url = f"{self._endpoint()}?{urllib.parse.urlencode(params)}"
            resp = requests.post(
                url,
                data=img_b64,
                headers={"Content-Type": "application/x-www-form-urlencoded"},
                timeout=120,
            )

        if resp.status_code != 200:
            raise Exception(resp.text)

        return resp.json()
58 changes: 58 additions & 0 deletions tests/cli/test_infer_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,5 +149,63 @@ def test_infer_confidence_converted_to_percentage(self, mock_model_cls: MagicMoc
mock_model.predict.assert_called_once_with("test.jpg", confidence=70, overlap=30)


class TestInferVLM(unittest.TestCase):
    """The text-image-pairs (VLM) path passes the raw dict response through."""

    @staticmethod
    def _make_args(**kwargs: object) -> types.SimpleNamespace:
        # Baseline CLI namespace for a VLM inference call; individual tests
        # override only the fields they care about.
        base: dict = {
            "json": False,
            "api_key": "test-key",
            "workspace": "test-ws",
            "model": "test-project/1",
            "file": "https://example.com/img.jpg",
            "confidence": 0.5,
            "overlap": 0.5,
            "type": "text-image-pairs",
        }
        base.update(kwargs)
        return types.SimpleNamespace(**base)

    @staticmethod
    def _run_capturing_stdout(args: types.SimpleNamespace) -> str:
        """Invoke the handler with stdout redirected; return captured text."""
        from roboflow.cli.handlers.infer import _infer

        captured = io.StringIO()
        previous, sys.stdout = sys.stdout, captured
        try:
            _infer(args)
        finally:
            sys.stdout = previous
        return captured.getvalue()

    @patch("roboflow.models.vlm.VLMModel")
    def test_infer_vlm_json_passthrough(self, mock_model_cls: MagicMock) -> None:
        raw = {"inference_id": "abc", "response": {">": "caption text"}}
        fake_model = MagicMock()
        fake_model.predict.return_value = raw
        mock_model_cls.return_value = fake_model

        printed = self._run_capturing_stdout(self._make_args(json=True))

        self.assertEqual(json.loads(printed), raw)

    @patch("roboflow.models.vlm.VLMModel")
    def test_infer_vlm_skips_confidence_overlap(self, mock_model_cls: MagicMock) -> None:
        fake_model = MagicMock()
        fake_model.predict.return_value = {"ok": True}
        mock_model_cls.return_value = fake_model

        self._run_capturing_stdout(self._make_args(confidence=0.7, overlap=0.3))

        # VLM inference must not forward confidence/overlap query params.
        fake_model.predict.assert_called_once_with("https://example.com/img.jpg")


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    unittest.main()
82 changes: 82 additions & 0 deletions tests/models/test_vlm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Unit tests for roboflow.models.vlm.VLMModel."""

from __future__ import annotations

import unittest
from unittest.mock import MagicMock, patch

from roboflow.models.vlm import VLMModel


class TestVLMModel(unittest.TestCase):
    """Behavioral tests for the hosted text-image-pairs model wrapper."""

    @staticmethod
    def _build_model() -> VLMModel:
        # A fully-specified model: explicit version, workspace-qualified id.
        return VLMModel(api_key="k", id="ws/proj/3", name="proj", version="3")

    @patch("roboflow.models.vlm.check_image_url", return_value=True)
    @patch("roboflow.models.vlm.requests.get")
    def test_predict_url_returns_raw_dict(self, mock_get: MagicMock, _chk: MagicMock) -> None:
        payload = {"response": {">": "box<loc_1><loc_2><loc_3><loc_4>"}}
        response = MagicMock(status_code=200)
        response.json = lambda: payload
        mock_get.return_value = response

        outcome = self._build_model().predict("https://example.com/img.jpg")

        self.assertEqual(outcome, payload)
        requested_url = mock_get.call_args[0][0]
        for fragment in (
            "https://serverless.roboflow.com/proj/3",
            "api_key=k",
            "image=",
        ):
            self.assertIn(fragment, requested_url)

    @patch("roboflow.models.vlm.check_image_url", return_value=True)
    @patch("roboflow.models.vlm.requests.get")
    def test_predict_forwards_extra_kwargs_as_query(self, mock_get: MagicMock, _chk: MagicMock) -> None:
        mock_get.return_value = MagicMock(status_code=200, json=lambda: {"ok": True})

        self._build_model().predict("https://example.com/img.jpg", prompt="caption")

        self.assertIn("prompt=caption", mock_get.call_args[0][0])

    @patch("roboflow.models.vlm.check_image_url", return_value=True)
    @patch("roboflow.models.vlm.requests.get")
    def test_predict_non_200_raises(self, mock_get: MagicMock, _chk: MagicMock) -> None:
        mock_get.return_value = MagicMock(status_code=401, text="unauthorized")
        model = self._build_model()

        with self.assertRaises(Exception) as caught:
            model.predict("https://example.com/img.jpg")

        self.assertIn("unauthorized", str(caught.exception))

    @patch("roboflow.models.vlm.os.path.exists", return_value=True)
    @patch("roboflow.models.vlm.Image.open")
    @patch("roboflow.models.vlm.requests.post")
    def test_predict_local_path_posts_base64(
        self, mock_post: MagicMock, mock_open: MagicMock, _exists: MagicMock
    ) -> None:
        fake_image = MagicMock()
        fake_image.convert.return_value = fake_image
        # Simulate PIL writing JPEG bytes into the provided buffer.
        fake_image.save.side_effect = lambda buf, **_kw: buf.write(b"fakejpeg")
        mock_open.return_value = fake_image
        mock_post.return_value = MagicMock(status_code=200, json=lambda: {"ok": True})

        self.assertEqual(self._build_model().predict("/tmp/x.jpg"), {"ok": True})

        _, post_kwargs = mock_post.call_args
        self.assertEqual(post_kwargs["headers"], {"Content-Type": "application/x-www-form-urlencoded"})
        self.assertIsInstance(post_kwargs["data"], str)

    def test_predict_missing_local_file_raises(self) -> None:
        with self.assertRaises(Exception) as caught:
            self._build_model().predict("/definitely/not/a/real/path.jpg")
        self.assertIn("does not exist", str(caught.exception))

    def test_endpoint_uses_id_parts_when_version_unset(self) -> None:
        model = VLMModel(api_key="k", id="ws/proj/7")
        model.version = None
        self.assertEqual(model._endpoint(), "https://serverless.roboflow.com/proj/7")


# Allow running this test module directly (outside a pytest invocation).
if __name__ == "__main__":
    unittest.main()
Loading