Skip to content

Commit bb3b32c

Browse files
committed
Improved tests [skip ci]
1 parent 9f825f2 commit bb3b32c

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

tests/test_peewee.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from math import sqrt
22
import numpy as np
33
from peewee import Model, PostgresqlDatabase, fn
4-
from pgvector import SparseVector
4+
from pgvector import HalfVector, SparseVector
55
from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField
66

77
db = PostgresqlDatabase('pgvector_python_test')
@@ -77,7 +77,7 @@ def test_vector_l1_distance(self):
7777
def test_halfvec(self):
7878
Item.create(id=1, half_embedding=[1, 2, 3])
7979
item = Item.get_by_id(1)
80-
assert item.half_embedding.to_list() == [1, 2, 3]
80+
assert item.half_embedding == HalfVector([1, 2, 3])
8181

8282
def test_halfvec_l2_distance(self):
8383
create_items()
@@ -129,7 +129,7 @@ def test_bit_jaccard_distance(self):
129129
def test_sparsevec(self):
130130
Item.create(id=1, sparse_embedding=[1, 2, 3])
131131
item = Item.get_by_id(1)
132-
assert item.sparse_embedding.to_list() == [1, 2, 3]
132+
assert item.sparse_embedding == SparseVector([1, 2, 3])
133133

134134
def test_sparsevec_l2_distance(self):
135135
create_items()
@@ -186,15 +186,15 @@ def test_halfvec_avg(self):
186186
Item.create(half_embedding=[1, 2, 3])
187187
Item.create(half_embedding=[4, 5, 6])
188188
avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
189-
assert avg.to_list() == [2.5, 3.5, 4.5]
189+
assert avg == HalfVector([2.5, 3.5, 4.5])
190190

191191
def test_halfvec_sum(self):
192192
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
193193
assert sum is None
194194
Item.create(half_embedding=[1, 2, 3])
195195
Item.create(half_embedding=[4, 5, 6])
196196
sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
197-
assert sum.to_list() == [5, 7, 9]
197+
assert sum == HalfVector([5, 7, 9])
198198

199199
def test_get_or_create(self):
200200
Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})

0 commit comments

Comments
 (0)