|
1 | 1 | from math import sqrt
|
2 | 2 | import numpy as np
|
3 | 3 | from peewee import Model, PostgresqlDatabase, fn
|
4 |
| -from pgvector import SparseVector |
| 4 | +from pgvector import HalfVector, SparseVector |
5 | 5 | from pgvector.peewee import VectorField, HalfVectorField, FixedBitField, SparseVectorField
|
6 | 6 |
|
7 | 7 | db = PostgresqlDatabase('pgvector_python_test')
|
@@ -77,7 +77,7 @@ def test_vector_l1_distance(self):
|
77 | 77 | def test_halfvec(self):
|
78 | 78 | Item.create(id=1, half_embedding=[1, 2, 3])
|
79 | 79 | item = Item.get_by_id(1)
|
80 |
| - assert item.half_embedding.to_list() == [1, 2, 3] |
| 80 | + assert item.half_embedding == HalfVector([1, 2, 3]) |
81 | 81 |
|
82 | 82 | def test_halfvec_l2_distance(self):
|
83 | 83 | create_items()
|
@@ -129,7 +129,7 @@ def test_bit_jaccard_distance(self):
|
129 | 129 | def test_sparsevec(self):
|
130 | 130 | Item.create(id=1, sparse_embedding=[1, 2, 3])
|
131 | 131 | item = Item.get_by_id(1)
|
132 |
| - assert item.sparse_embedding.to_list() == [1, 2, 3] |
| 132 | + assert item.sparse_embedding == SparseVector([1, 2, 3]) |
133 | 133 |
|
134 | 134 | def test_sparsevec_l2_distance(self):
|
135 | 135 | create_items()
|
@@ -186,15 +186,15 @@ def test_halfvec_avg(self):
|
186 | 186 | Item.create(half_embedding=[1, 2, 3])
|
187 | 187 | Item.create(half_embedding=[4, 5, 6])
|
188 | 188 | avg = Item.select(fn.avg(Item.half_embedding).coerce(True)).scalar()
|
189 |
| - assert avg.to_list() == [2.5, 3.5, 4.5] |
| 189 | + assert avg == HalfVector([2.5, 3.5, 4.5]) |
190 | 190 |
|
191 | 191 | def test_halfvec_sum(self):
|
192 | 192 | sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
|
193 | 193 | assert sum is None
|
194 | 194 | Item.create(half_embedding=[1, 2, 3])
|
195 | 195 | Item.create(half_embedding=[4, 5, 6])
|
196 | 196 | sum = Item.select(fn.sum(Item.half_embedding).coerce(True)).scalar()
|
197 |
| - assert sum.to_list() == [5, 7, 9] |
| 197 | + assert sum == HalfVector([5, 7, 9]) |
198 | 198 |
|
199 | 199 | def test_get_or_create(self):
|
200 | 200 | Item.get_or_create(id=1, defaults={'embedding': [1, 2, 3]})
|
|
0 commit comments