Skip to content

Commit 0616b11

Browse files
authored
Merge pull request #1442 from roboflow/florence-2-lora
inference-exp: Florence 2 LoRA
2 parents 1e139e3 + 1e496c9 commit 0616b11

File tree

7 files changed

+425
-30
lines changed

7 files changed

+425
-30
lines changed

inference_experimental/inference_exp/models/auto_loaders/models_registry.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,10 @@
121121
module_name="inference_exp.models.paligemma.paligemma_hf",
122122
class_name="PaliGemmaHF",
123123
),
124+
("florence-2", VLM_TASK, BackendType.HF): LazyClass(
125+
module_name="inference_exp.models.florence2.florence2_hf",
126+
class_name="Florence2HF",
127+
),
124128
("clip", EMBEDDING_TASK, BackendType.TORCH): LazyClass(
125129
module_name="inference_exp.models.clip.clip_pytorch",
126130
class_name="ClipTorch",

inference_experimental/inference_exp/models/florence2/florence2_hf.py

Lines changed: 35 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from typing import List, Literal, Optional, Tuple, Union
2+
import os
23

34
import cv2
45
import numpy as np
56
import torch
7+
from peft import LoraConfig, PeftModel
68
from inference_exp import Detections, InstanceDetections
79
from inference_exp.configuration import DEFAULT_DEVICE
810
from inference_exp.entities import ImageDimensions
@@ -18,9 +20,9 @@
1820
"very_detailed": "<MORE_DETAILED_CAPTION>",
1921
}
2022
LABEL_MODE2TASK = {
21-
"roi": "<REGION_PROPOSAL>",
22-
"class": "<OD>",
23-
"caption": "<DENSE_REGION_CAPTION>",
23+
"rois": "<REGION_PROPOSAL>",
24+
"classes": "<OD>",
25+
"captions": "<DENSE_REGION_CAPTION>",
2426
}
2527
LOC_BINS = 1000
2628

@@ -35,15 +37,36 @@ def from_pretrained(
3537
**kwargs,
3638
) -> "Florence2HF":
3739
torch_dtype = torch.float16 if device.type == "cuda" else torch.float32
38-
model = AutoModelForCausalLM.from_pretrained(
39-
model_name_or_path,
40-
torch_dtype=torch_dtype,
41-
trust_remote_code=True,
42-
).to(device)
43-
processor = AutoProcessor.from_pretrained(
44-
model_name_or_path,
45-
trust_remote_code=True,
46-
)
40+
41+
adapter_config_path = os.path.join(model_name_or_path, "adapter_config.json")
42+
if os.path.exists(adapter_config_path):
43+
base_model_path = os.path.join(model_name_or_path, "base")
44+
model = AutoModelForCausalLM.from_pretrained(
45+
base_model_path,
46+
torch_dtype=torch_dtype,
47+
trust_remote_code=True,
48+
local_files_only=True,
49+
)
50+
model = PeftModel.from_pretrained(model, model_name_or_path)
51+
model.merge_and_unload()
52+
model.to(device)
53+
54+
processor = AutoProcessor.from_pretrained(
55+
base_model_path, trust_remote_code=True, local_files_only=True
56+
)
57+
else:
58+
model = AutoModelForCausalLM.from_pretrained(
59+
model_name_or_path,
60+
torch_dtype=torch_dtype,
61+
trust_remote_code=True,
62+
local_files_only=True,
63+
).to(device)
64+
processor = AutoProcessor.from_pretrained(
65+
model_name_or_path,
66+
trust_remote_code=True,
67+
local_files_only=True,
68+
)
69+
4770
return cls(
4871
model=model, processor=processor, device=device, torch_dtype=torch_dtype
4972
)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import os.path
2+
import zipfile
3+
4+
import cv2
5+
import numpy as np
6+
import pytest
7+
import requests
8+
import torch
9+
import torchvision.io
10+
from filelock import FileLock
11+
from PIL import Image
12+
13+
ASSETS_DIR = os.path.abspath(
14+
os.path.join(os.path.dirname(__file__), "models", "assets")
15+
)
16+
DOG_IMAGE_PATH = os.path.join(ASSETS_DIR, "dog.jpeg")
17+
DOG_IMAGE_URL = "https://media.roboflow.com/dog.jpeg"
18+
19+
20+
def _download_if_not_exists(file_path: str, url: str, lock_timeout: int = 120) -> None:
21+
os.makedirs(os.path.dirname(file_path), exist_ok=True)
22+
lock_path = f"{file_path}.lock"
23+
with FileLock(lock_file=lock_path, timeout=lock_timeout):
24+
if os.path.exists(file_path):
25+
return None
26+
with requests.get(url, stream=True) as response:
27+
response.raise_for_status()
28+
with open(file_path, "wb") as f:
29+
for chunk in response.iter_content(chunk_size=8192):
30+
if chunk:
31+
f.write(chunk)
32+
33+
34+
@pytest.fixture(scope="function")
35+
def dog_image_numpy() -> np.ndarray:
36+
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
37+
image = cv2.imread(DOG_IMAGE_PATH)
38+
assert image is not None, "Could not load test image"
39+
return image
40+
41+
42+
@pytest.fixture(scope="function")
43+
def dog_image_torch() -> torch.Tensor:
44+
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
45+
return torchvision.io.read_image(DOG_IMAGE_PATH)
46+
47+
48+
@pytest.fixture(scope="function")
49+
def dog_image_pil() -> Image.Image:
50+
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
51+
return Image.open(DOG_IMAGE_PATH)
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
from inference_exp import AutoModel
5+
6+
7+
@pytest.mark.e2e_model_inference
8+
def test_florence2_base_model(dog_image_numpy: np.ndarray):
9+
# GIVEN
10+
model = AutoModel.from_pretrained("florence-2-base")
11+
12+
# WHEN
13+
captions = model.caption_image(dog_image_numpy)
14+
15+
# THEN
16+
assert isinstance(captions, list)
17+
assert len(captions) == 1
18+
assert isinstance(captions[0], str)
19+
assert captions[0] == "A man carrying a blue dog on his back."
20+
21+
22+
@pytest.mark.e2e_model_inference
23+
def test_florence2_lora_model(
24+
dog_image_numpy: np.ndarray, dog_image_torch: torch.Tensor
25+
):
26+
# GIVEN
27+
model = AutoModel.from_pretrained("florence-2-lora-test")
28+
29+
# WHEN
30+
captions = model.caption_image(dog_image_numpy)
31+
32+
# THEN
33+
assert isinstance(captions, list)
34+
assert len(captions) == 1
35+
assert isinstance(captions[0], str)
36+
assert captions[0] == "Disease"

inference_experimental/tests/integration_tests/models/conftest.py

Lines changed: 42 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os.path
2+
import zipfile
23

34
import cv2
45
import numpy as np
@@ -11,13 +12,18 @@
1112

1213
ASSETS_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "assets"))
1314
MODELS_DIR = os.path.join(ASSETS_DIR, "models")
14-
DOG_IMAGE_PATH = os.path.join(ASSETS_DIR, "dog.jpeg")
15-
DOG_IMAGE_URL = "https://media.roboflow.com/dog.jpeg"
1615
CLIP_RN50_TORCH_URL = "https://storage.googleapis.com/roboflow-tests-assets/clip_packages/RN50/torch/model.pt"
1716
CLIP_RN50_ONNX_VISUAL = "https://storage.googleapis.com/roboflow-tests-assets/clip_packages/RN50/onnx/visual.onnx"
1817
CLIP_RN50_ONNX_TEXTUAL = "https://storage.googleapis.com/roboflow-tests-assets/clip_packages/RN50/onnx/textual.onnx"
1918
PE_MODEL_URL = "https://storage.googleapis.com/roboflow-tests-assets/perception-encoder/pe-core-b16-224/model.pt"
2019
PE_CONFIG_URL = "https://storage.googleapis.com/roboflow-tests-assets/perception-encoder/pe-core-b16-224/config.json"
20+
FLORENCE2_BASE_FT_URL = (
21+
"https://storage.googleapis.com/roboflow-tests-assets/florence2/base-ft.zip"
22+
)
23+
FLORENCE2_LARGE_FT_URL = (
24+
"https://storage.googleapis.com/roboflow-tests-assets/florence2/large-ft.zip"
25+
)
26+
OCR_TEST_IMAGE_PATH = os.path.join(ASSETS_DIR, "ocr_test_image.png")
2127

2228

2329
@pytest.fixture(scope="module")
@@ -59,25 +65,13 @@ def perception_encoder_path() -> str:
5965

6066

6167
@pytest.fixture(scope="function")
62-
def dog_image_numpy() -> np.ndarray:
63-
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
64-
image = cv2.imread(DOG_IMAGE_PATH)
65-
assert image is not None, "Could not load test image"
68+
def ocr_test_image_numpy() -> np.ndarray:
69+
"""Returns the OCR test image as a numpy array."""
70+
image = cv2.imread(OCR_TEST_IMAGE_PATH)
71+
assert image is not None, "Could not load OCR test image"
6672
return image
6773

6874

69-
@pytest.fixture(scope="function")
70-
def dog_image_torch() -> torch.Tensor:
71-
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
72-
return torchvision.io.read_image(DOG_IMAGE_PATH)
73-
74-
75-
@pytest.fixture(scope="function")
76-
def dog_image_pil() -> Image.Image:
77-
_download_if_not_exists(file_path=DOG_IMAGE_PATH, url=DOG_IMAGE_URL)
78-
return Image.open(DOG_IMAGE_PATH)
79-
80-
8175
def _download_if_not_exists(file_path: str, url: str, lock_timeout: int = 120) -> None:
8276
os.makedirs(os.path.dirname(file_path), exist_ok=True)
8377
lock_path = f"{file_path}.lock"
@@ -90,3 +84,33 @@ def _download_if_not_exists(file_path: str, url: str, lock_timeout: int = 120) -
9084
for chunk in response.iter_content(chunk_size=8192):
9185
if chunk:
9286
f.write(chunk)
87+
88+
89+
@pytest.fixture(scope="module")
90+
def florence2_base_ft_path() -> str:
91+
package_dir = os.path.join(MODELS_DIR, "florence2")
92+
unzipped_package_path = os.path.join(package_dir, "base-ft")
93+
os.makedirs(package_dir, exist_ok=True)
94+
zip_path = os.path.join(package_dir, "base-ft.zip")
95+
_download_if_not_exists(file_path=zip_path, url=FLORENCE2_BASE_FT_URL)
96+
lock_path = f"{unzipped_package_path}.lock"
97+
with FileLock(lock_path, timeout=120):
98+
if not os.path.exists(unzipped_package_path):
99+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
100+
zip_ref.extractall(package_dir)
101+
return unzipped_package_path
102+
103+
104+
@pytest.fixture(scope="module")
105+
def florence2_large_ft_path() -> str:
106+
package_dir = os.path.join(MODELS_DIR, "florence2")
107+
unzipped_package_path = os.path.join(package_dir, "large-ft")
108+
os.makedirs(package_dir, exist_ok=True)
109+
zip_path = os.path.join(package_dir, "large-ft.zip")
110+
_download_if_not_exists(file_path=zip_path, url=FLORENCE2_LARGE_FT_URL)
111+
lock_path = f"{unzipped_package_path}.lock"
112+
with FileLock(lock_path, timeout=120):
113+
if not os.path.exists(unzipped_package_path):
114+
with zipfile.ZipFile(zip_path, "r") as zip_ref:
115+
zip_ref.extractall(package_dir)
116+
return unzipped_package_path
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import numpy as np
2+
import pytest
3+
import torch
4+
5+
from inference_exp.models.florence2.florence2_hf import Florence2HF
6+
7+
8+
@pytest.fixture(scope="module")
9+
def florence2_model(florence2_base_ft_path: str) -> Florence2HF:
10+
return Florence2HF.from_pretrained(florence2_base_ft_path)
11+
12+
13+
@pytest.mark.slow
14+
def test_classify_image_region(
15+
florence2_model: Florence2HF, dog_image_numpy: np.ndarray
16+
):
17+
# given
18+
xyxy = [100, 100, 300, 300]
19+
# when
20+
result = florence2_model.classify_image_region(images=dog_image_numpy, xyxy=xyxy)
21+
# then
22+
assert result == ["human face"]
23+
24+
25+
@pytest.mark.slow
26+
def test_caption_image_region(
27+
florence2_model: Florence2HF, dog_image_numpy: np.ndarray
28+
):
29+
# given
30+
xyxy = [100, 100, 300, 300]
31+
# when
32+
result = florence2_model.caption_image_region(images=dog_image_numpy, xyxy=xyxy)
33+
# then
34+
assert result == ["human face"]
35+
36+
37+
@pytest.mark.slow
38+
def test_ocr_image_region(
39+
florence2_model: Florence2HF, ocr_test_image_numpy: np.ndarray
40+
):
41+
# TODO: figure out whether this is an implementation error; region OCR doesn't really seem to work, it just returns the text from the whole image
42+
# given
43+
xyxy = [0, 0, 100, 150]
44+
# when
45+
result = florence2_model.ocr_image_region(images=ocr_test_image_numpy, xyxy=xyxy)
46+
# then
47+
assert result == ["This is a test image for OCR."]
48+
49+
50+
@pytest.mark.slow
51+
def test_segment_region(florence2_model: Florence2HF, dog_image_numpy: np.ndarray):
52+
# given
53+
xyxy = [100, 100, 300, 300]
54+
# when
55+
result = florence2_model.segment_region(images=dog_image_numpy, xyxy=xyxy)
56+
# then
57+
assert isinstance(result, list)
58+
assert len(result) == 1
59+
assert result[0].xyxy.shape == (1, 4)
60+
assert torch.allclose(
61+
result[0].xyxy, torch.tensor([[100, 100, 302, 303]], dtype=torch.int32), atol=2
62+
)
63+
assert result[0].mask.shape == (1, 1280, 720)
64+
65+
66+
@pytest.mark.slow
67+
def test_segment_phrase(florence2_model: Florence2HF, dog_image_numpy: np.ndarray):
68+
# when
69+
result = florence2_model.segment_phrase(images=dog_image_numpy, phrase="dog")
70+
# then
71+
assert isinstance(result, list)
72+
assert len(result) == 1
73+
assert result[0].xyxy.shape == (1, 4)
74+
assert torch.allclose(
75+
result[0].xyxy, torch.tensor([[71, 249, 649, 926]], dtype=torch.int32), atol=5
76+
)
77+
assert result[0].mask.shape == (1, 1280, 720)
78+
79+
80+
@pytest.mark.slow
81+
def test_detect_objects(florence2_model: Florence2HF, dog_image_numpy: np.ndarray):
82+
# when
83+
result = florence2_model.detect_objects(images=dog_image_numpy)
84+
# then
85+
assert isinstance(result, list)
86+
assert len(result) == 1
87+
assert result[0].xyxy.shape == (4, 4)
88+
expected_bboxes_metadata = [
89+
{"class_name": "backpack"},
90+
{"class_name": "dog"},
91+
{"class_name": "hat"},
92+
{"class_name": "person"},
93+
]
94+
assert result[0].bboxes_metadata == expected_bboxes_metadata
95+
96+
97+
@pytest.mark.slow
98+
def test_caption_image(florence2_model: Florence2HF, dog_image_numpy: np.ndarray):
99+
# when
100+
result = florence2_model.caption_image(images=dog_image_numpy)
101+
# then
102+
assert result == ["A man carrying a blue dog on his back."]
103+
104+
105+
@pytest.mark.slow
106+
def test_parse_document(florence2_model: Florence2HF, ocr_test_image_numpy: np.ndarray):
107+
# when
108+
result = florence2_model.parse_document(images=ocr_test_image_numpy)
109+
# then
110+
assert isinstance(result, list)
111+
assert len(result) == 1
112+
assert result[0].xyxy.shape[0] >= 1
113+
assert result[0].xyxy.shape[1] == 4
114+
full_text = "".join(
115+
meta["class_name"] for meta in result[0].bboxes_metadata
116+
).lstrip("</s>")
117+
assert full_text == "This is a test image for OCR."
118+
119+
120+
@pytest.mark.slow
121+
def test_ocr_image(florence2_model: Florence2HF, ocr_test_image_numpy: np.ndarray):
122+
# when
123+
result = florence2_model.ocr_image(images=ocr_test_image_numpy)
124+
# then
125+
assert result == ["This is a test image for OCR."]

0 commit comments

Comments
 (0)