Skip to content

Commit 1b25460

Browse files
committed
Raise error for unexpected dtype for Bit constructor [skip ci]
1 parent ac9fd53 commit 1b25460

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

pgvector/bit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ def __init__(self, value):
77
if isinstance(value, str):
88
self._value = self.from_text(value)._value
99
else:
10-
# TODO raise if dtype not bool or uint8
11-
if isinstance(value, np.ndarray) and value.dtype == np.uint8:
12-
value = np.unpackbits(value)
10+
if isinstance(value, np.ndarray):
11+
if value.dtype == np.uint8:
12+
value = np.unpackbits(value).astype(bool)
13+
elif value.dtype != np.bool:
14+
raise ValueError('expected dtype to be bool or uint8')
1315
else:
1416
value = np.asarray(value, dtype=bool)
1517

tests/test_bit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ def test_ndarray_uint8(self):
1717
arr = np.array([254, 7, 0], dtype=np.uint8)
1818
assert Bit(arr).to_text() == '111111100000011100000000'
1919

20+
def test_ndarray_uint16(self):
21+
arr = np.array([254, 7, 0], dtype=np.uint16)
22+
with pytest.raises(ValueError) as error:
23+
Bit(arr)
24+
assert str(error.value) == 'expected dtype to be bool or uint8'
25+
2026
def test_ndarray_same_object(self):
2127
arr = np.array([True, False, True])
2228
assert Bit(arr).to_list() == [True, False, True]

0 commit comments

Comments
 (0)