Skip to content

Commit 7e786ef

Browse files
authored
feat: added support for images types (stanfordnlp#7872)
1 parent 49c31b6 commit 7e786ef

File tree

2 files changed

+155
-26
lines changed

2 files changed

+155
-26
lines changed

dspy/adapters/types/image.py

Lines changed: 45 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Any, Dict, List, Union
55
from urllib.parse import urlparse
66
import re
7+
import mimetypes
78

89
import pydantic
910
import requests
@@ -79,23 +80,23 @@ def is_url(string: str) -> bool:
7980

8081
def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_images: bool = False) -> str:
8182
"""
82-
Encode an image to a base64 data URI.
83+
Encode an image or file to a base64 data URI.
8384
8485
Args:
85-
image: The image to encode. Can be a PIL Image, file path, URL, or data URI.
86+
image: The image or file to encode. Can be a PIL Image, file path, URL, or data URI.
8687
download_images: Whether to download images from URLs.
8788
8889
Returns:
89-
str: The data URI of the image or the URL if download_images is False.
90+
str: The data URI of the file or the URL if download_images is False.
9091
9192
Raises:
92-
ValueError: If the image type is not supported.
93+
ValueError: If the file type is not supported.
9394
"""
9495
if isinstance(image, dict) and "url" in image:
9596
# NOTE: Not doing other validation for now
9697
return image["url"]
9798
elif isinstance(image, str):
98-
if image.startswith("data:image/"):
99+
if image.startswith("data:"):
99100
# Already a data URI
100101
return image
101102
elif os.path.isfile(image):
@@ -110,8 +111,8 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
110111
return image
111112
else:
112113
# Unsupported string format
113-
print(f"Unsupported image string: {image}")
114-
raise ValueError(f"Unsupported image string: {image}")
114+
print(f"Unsupported file string: {image}")
115+
raise ValueError(f"Unsupported file string: {image}")
115116
elif PIL_AVAILABLE and isinstance(image, PILImage.Image):
116117
# PIL Image
117118
return _encode_pil_image(image)
@@ -129,34 +130,52 @@ def encode_image(image: Union[str, bytes, "PILImage.Image", dict], download_imag
129130

130131

131132
def _encode_image_from_file(file_path: str) -> str:
132-
"""Encode an image from a file path to a base64 data URI."""
133-
with open(file_path, "rb") as image_file:
134-
image_data = image_file.read()
135-
file_extension = _get_file_extension(file_path)
136-
encoded_image = base64.b64encode(image_data).decode("utf-8")
137-
return f"data:image/{file_extension};base64,{encoded_image}"
133+
"""Encode a file from a file path to a base64 data URI."""
134+
with open(file_path, "rb") as file:
135+
file_data = file.read()
136+
137+
# Use mimetypes to guess directly from the file path
138+
mime_type, _ = mimetypes.guess_type(file_path)
139+
if mime_type is None:
140+
raise ValueError(f"Could not determine MIME type for file: {file_path}")
141+
142+
encoded_data = base64.b64encode(file_data).decode("utf-8")
143+
return f"data:{mime_type};base64,{encoded_data}"
138144

139145

140146
def _encode_image_from_url(image_url: str) -> str:
141-
"""Encode an image from a URL to a base64 data URI."""
147+
"""Encode a file from a URL to a base64 data URI."""
142148
response = requests.get(image_url)
143149
response.raise_for_status()
144150
content_type = response.headers.get("Content-Type", "")
145-
if content_type.startswith("image/"):
146-
file_extension = content_type.split("/")[-1]
151+
152+
# Use the content type from the response headers if available
153+
if content_type:
154+
mime_type = content_type
147155
else:
148-
# Fallback to file extension from URL or default to 'png'
149-
file_extension = _get_file_extension(image_url) or "png"
150-
encoded_image = base64.b64encode(response.content).decode("utf-8")
151-
return f"data:image/{file_extension};base64,{encoded_image}"
156+
# Try to guess MIME type from URL
157+
mime_type, _ = mimetypes.guess_type(image_url)
158+
if mime_type is None:
159+
raise ValueError(f"Could not determine MIME type for URL: {image_url}")
160+
161+
encoded_data = base64.b64encode(response.content).decode("utf-8")
162+
return f"data:{mime_type};base64,{encoded_data}"
163+
152164

153165
def _encode_pil_image(image: 'PILImage') -> str:
154166
"""Encode a PIL Image object to a base64 data URI."""
155167
buffered = io.BytesIO()
156-
file_extension = (image.format or "PNG").lower()
157-
image.save(buffered, format=file_extension)
158-
encoded_image = base64.b64encode(buffered.getvalue()).decode("utf-8")
159-
return f"data:image/{file_extension};base64,{encoded_image}"
168+
file_format = image.format or "PNG"
169+
image.save(buffered, format=file_format)
170+
171+
# Get the correct MIME type using the image format
172+
file_extension = file_format.lower()
173+
mime_type, _ = mimetypes.guess_type(f"file.{file_extension}")
174+
if mime_type is None:
175+
raise ValueError(f"Could not determine MIME type for image format: {file_format}")
176+
177+
encoded_data = base64.b64encode(buffered.getvalue()).decode("utf-8")
178+
return f"data:{mime_type};base64,{encoded_data}"
160179

161180

162181
def _get_file_extension(path_or_url: str) -> str:
@@ -166,11 +185,11 @@ def _get_file_extension(path_or_url: str) -> str:
166185

167186

168187
def is_image(obj) -> bool:
169-
"""Check if the object is an image or a valid image reference."""
188+
"""Check if the object is an image or a valid media file reference."""
170189
if PIL_AVAILABLE and isinstance(obj, PILImage.Image):
171190
return True
172191
if isinstance(obj, str):
173-
if obj.startswith("data:image/"):
192+
if obj.startswith("data:"):
174193
return True
175194
elif os.path.isfile(obj):
176195
return True

tests/signatures/test_adapter_image.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
from dspy.adapters.types.image import encode_image
1212
import tempfile
1313
import pydantic
14+
import os
15+
1416

1517
@pytest.fixture
1618
def sample_pil_image():
@@ -318,6 +320,114 @@ class OptionalImageSignature(dspy.Signature):
318320
assert result.output == "Hello"
319321
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 0
320322

323+
324+
def test_pdf_url_support():
325+
"""Test support for PDF files from URLs"""
326+
pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
327+
328+
# Create a dspy.Image object from the PDF URL with download=True
329+
pdf_image = dspy.Image.from_url(pdf_url, download=True)
330+
331+
# The data URI should contain application/pdf in the MIME type
332+
assert "data:application/pdf" in pdf_image.url
333+
assert ";base64," in pdf_image.url
334+
335+
# Test using it in a predictor
336+
class PDFSignature(dspy.Signature):
337+
document: dspy.Image = dspy.InputField(desc="A PDF document")
338+
summary: str = dspy.OutputField(desc="A summary of the PDF")
339+
340+
predictor, lm = setup_predictor(PDFSignature, {"summary": "This is a dummy PDF"})
341+
result = predictor(document=pdf_image)
342+
343+
assert result.summary == "This is a dummy PDF"
344+
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
345+
346+
# Ensure the URL was properly expanded in messages
347+
messages_str = str(lm.history[-1]["messages"])
348+
assert "application/pdf" in messages_str
349+
350+
351+
def test_different_mime_types():
352+
"""Test support for different file types and MIME type detection"""
353+
# Test with various file types
354+
file_urls = {
355+
"pdf": "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf",
356+
"image": "https://images.dog.ceo/breeds/dane-great/n02109047_8912.jpg",
357+
}
358+
359+
expected_mime_types = {
360+
"pdf": "application/pdf",
361+
"image": "image/jpeg",
362+
}
363+
364+
for file_type, url in file_urls.items():
365+
# Download and encode
366+
encoded = encode_image(url, download_images=True)
367+
368+
# Check for correct MIME type in the encoded data - using 'in' instead of startswith
369+
# to account for possible parameters in the MIME type
370+
assert f"data:{expected_mime_types[file_type]}" in encoded
371+
assert ";base64," in encoded
372+
373+
374+
def test_mime_type_from_response_headers():
375+
"""Test that MIME types from response headers are correctly used"""
376+
# This URL returns proper Content-Type header
377+
pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
378+
379+
# Make an actual request to get the content type from headers
380+
response = requests.get(pdf_url)
381+
expected_mime_type = response.headers.get("Content-Type", "")
382+
383+
# Should be application/pdf or similar
384+
assert "pdf" in expected_mime_type.lower()
385+
386+
# Encode with download to test MIME type from headers
387+
encoded = encode_image(pdf_url, download_images=True)
388+
389+
# The encoded data should contain the correct MIME type
390+
assert "application/pdf" in encoded
391+
assert ";base64," in encoded
392+
393+
394+
def test_pdf_from_file():
395+
"""Test handling a PDF file from disk"""
396+
# Download a PDF to a temporary file
397+
pdf_url = "https://www.w3.org/WAI/ER/tests/xhtml/testfiles/resources/pdf/dummy.pdf"
398+
response = requests.get(pdf_url)
399+
response.raise_for_status()
400+
401+
with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp_file:
402+
tmp_file.write(response.content)
403+
tmp_file_path = tmp_file.name
404+
405+
try:
406+
# Create a dspy.Image from the file
407+
pdf_image = dspy.Image.from_file(tmp_file_path)
408+
409+
# Check that the MIME type is correct
410+
assert "data:application/pdf" in pdf_image.url
411+
assert ";base64," in pdf_image.url
412+
413+
# Test the image in a predictor
414+
class FilePDFSignature(dspy.Signature):
415+
document: dspy.Image = dspy.InputField(desc="A PDF document from file")
416+
summary: str = dspy.OutputField(desc="A summary of the PDF")
417+
418+
predictor, lm = setup_predictor(FilePDFSignature, {"summary": "This is a PDF from file"})
419+
result = predictor(document=pdf_image)
420+
421+
assert result.summary == "This is a PDF from file"
422+
assert count_messages_with_image_url_pattern(lm.history[-1]["messages"]) == 1
423+
finally:
424+
# Clean up the temporary file
425+
try:
426+
os.unlink(tmp_file_path)
427+
except:
428+
pass
429+
430+
321431
def test_image_repr():
322432
"""Test string representation of Image objects"""
323433
url_image = dspy.Image.from_url("https://example.com/dog.jpg", download=False)

0 commit comments

Comments
 (0)