Skip to content

Commit 16b37ff

Browse files
sarangoTorantulino
sarango
authored andcommitted
Fix to LocalCache add method, created integration test for it
1 parent ae6adb4 commit 16b37ff

File tree

3 files changed

+50
-1
lines changed

3 files changed

+50
-1
lines changed

scripts/memory/local.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def add(self, text: str):
5454
vector = vector[np.newaxis, :]
5555
self.data.embeddings = np.concatenate(
5656
[
57-
vector,
5857
self.data.embeddings,
58+
vector,
5959
],
6060
axis=0,
6161
)

tests/integration/memory_tests.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import unittest
2+
import random
3+
import string
4+
import sys
5+
from pathlib import Path
6+
# Add the parent directory of the 'scripts' folder to the Python path
7+
sys.path.append(str(Path(__file__).resolve().parent.parent.parent / 'scripts'))
8+
from config import Config
9+
from memory.local import LocalCache
10+
11+
class TestLocalCache(unittest.TestCase):
12+
13+
def random_string(self, length):
14+
return ''.join(random.choice(string.ascii_letters) for _ in range(length))
15+
16+
def setUp(self):
17+
cfg = cfg = Config()
18+
self.cache = LocalCache(cfg)
19+
self.cache.clear()
20+
21+
# Add example texts to the cache
22+
self.example_texts = [
23+
'The quick brown fox jumps over the lazy dog',
24+
'I love machine learning and natural language processing',
25+
'The cake is a lie, but the pie is always true',
26+
'ChatGPT is an advanced AI model for conversation'
27+
]
28+
29+
for text in self.example_texts:
30+
self.cache.add(text)
31+
32+
# Add some random strings to test noise
33+
for _ in range(5):
34+
self.cache.add(self.random_string(10))
35+
36+
def test_get_relevant(self):
37+
query = "I'm interested in artificial intelligence and NLP"
38+
k = 3
39+
relevant_texts = self.cache.get_relevant(query, k)
40+
41+
print(f"Top {k} relevant texts for the query '{query}':")
42+
for i, text in enumerate(relevant_texts, start=1):
43+
print(f"{i}. {text}")
44+
45+
self.assertEqual(len(relevant_texts), k)
46+
self.assertIn(self.example_texts[1], relevant_texts)
47+
48+
if __name__ == '__main__':
49+
unittest.main()
File renamed without changes.

0 commit comments

Comments
 (0)