Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions docs/components/embedders/models/sambanova.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
---
title: SambaNova
---

To use SambaNova embedding models, set the `SAMBANOVA_API_KEY` environment variable. You can obtain the SambaNova API key from the [SambaNova Platform](https://cloud.sambanova.ai/apis).

### Usage

<Note> The `embedding_model_dims` parameter for `vector_store` should be set to `4096` for SambaNova embedder. </Note>

```python
import os
from mem0 import Memory

os.environ["SAMBANOVA_API_KEY"] = "your_api_key"
os.environ["OPENAI_API_KEY"] = "your_api_key" # For LLM

config = {
"embedder": {
"provider": "sambanova",
"config": {
"model": "gpt-oss-120b"
}
}
}

m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
{"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="john")
```

### Config

Here are the parameters available for configuring SambaNova embedder:

| Parameter | Description | Default Value |
| --- | --- | --- |
| `model` | The name of the embedding model to use | `gpt-oss-120b` |
| `embedding_dims` | Dimensions of the embedding model | `768` |
| `api_key` | The SambaNova API key | `None` |
1 change: 1 addition & 0 deletions docs/components/embedders/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ See the list of supported embedders below.
<Card title="LM Studio" href="/components/embedders/models/lmstudio"></Card>
<Card title="Langchain" href="/components/embedders/models/langchain"></Card>
<Card title="AWS Bedrock" href="/components/embedders/models/aws_bedrock"></Card>
<Card title="SambaNova" href="/components/embedders/models/sambanova"></Card>
</CardGroup>

## Usage
Expand Down
39 changes: 39 additions & 0 deletions docs/components/llms/models/sambanova.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
---
title: SambaNova
---

To use SambaNova LLM models, you have to set the `SAMBANOVA_API_KEY` environment variable. You can obtain the SambaNova API key from their [Account settings page](https://cloud.sambanova.ai/apis).

## Usage

```python
import os
from mem0 import Memory

os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
os.environ["SAMBANOVA_API_KEY"] = "your-api-key"

config = {
"llm": {
"provider": "sambanova",
"config": {
"model": "gpt-oss-120b",
"temperature": 0.2,
"max_tokens": 2000,
}
}
}

m = Memory.from_config(config)
messages = [
{"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
{"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
{"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
{"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
]
m.add(messages, user_id="alice", metadata={"category": "movies"})
```

## Config

All available parameters for the `sambanova` config are present in [Master List of All Params in Config](../config).
1 change: 1 addition & 0 deletions docs/components/llms/overview.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ See the list of supported LLMs below.
<Card title="Sarvam AI" href="/components/llms/models/sarvam" />
<Card title="LM Studio" href="/components/llms/models/lmstudio" />
<Card title="Langchain" href="/components/llms/models/langchain" />
<Card title="SambaNova" href="/components/llms/models/sambanova" />
</CardGroup>

## Structured vs Unstructured Outputs
Expand Down
1 change: 1 addition & 0 deletions mem0/embeddings/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def validate_config(cls, v, values):
"langchain",
"aws_bedrock",
"fastembed",
"sambanova",
]:
return v
else:
Expand Down
30 changes: 30 additions & 0 deletions mem0/embeddings/sambanova.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import os
from typing import Literal, Optional

from sambanova import SambaNova

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.base import EmbeddingBase


class SambaNovaEmbedding(EmbeddingBase):
def __init__(self, config: Optional[BaseEmbedderConfig] = None):
super().__init__(config)

self.config.model = self.config.model or "E5-Mistral-7B-Instruct"
api_key = self.config.api_key or os.getenv("SAMBANOVA_API_KEY")
self.config.embedding_dims = self.config.embedding_dims or 4096
self.client = SambaNova(api_key=api_key)

def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
"""
Get the embedding for the given text using SambaNova.

Args:
text (str): The text to embed.
memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
Returns:
list: The embedding vector.
"""

return self.client.embeddings.create(model=self.config.model, input=text).data[0].embedding
1 change: 1 addition & 0 deletions mem0/llms/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def validate_config(cls, v, values):
"lmstudio",
"vllm",
"langchain",
"sambanova",
):
return v
else:
Expand Down
88 changes: 88 additions & 0 deletions mem0/llms/sambanova.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import json
import os
from typing import Dict, List, Optional

try:
from sambanova import SambaNova
except ImportError:
raise ImportError("The 'sambanova' library is required. Please install it using 'pip install sambanova'.")

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.base import LLMBase
from mem0.memory.utils import extract_json


class SambaNovaLLM(LLMBase):
def __init__(self, config: Optional[BaseLlmConfig] = None):
super().__init__(config)

if not self.config.model:
self.config.model = "gpt-oss-120b"

api_key = self.config.api_key or os.getenv("SAMBANOVA_API_KEY")
self.client = SambaNova(api_key=api_key)

def _parse_response(self, response, tools):
"""
Process the response based on whether tools are used or not.

Args:
response: The raw response from API.
tools: The list of tools provided in the request.

Returns:
str or dict: The processed response.
"""
if tools:
processed_response = {
"content": response.choices[0].message.content,
"tool_calls": [],
}

if response.choices[0].message.tool_calls:
for tool_call in response.choices[0].message.tool_calls:
processed_response["tool_calls"].append(
{
"name": tool_call.function.name,
"arguments": json.loads(extract_json(tool_call.function.arguments)),
}
)

return processed_response
else:
return response.choices[0].message.content

def generate_response(
self,
messages: List[Dict[str, str]],
response_format=None,
tools: Optional[List[Dict]] = None,
tool_choice: str = "auto",
):
"""
Generate a response based on the given messages using SambaNova.

Args:
messages (list): List of message dicts containing 'role' and 'content'.
response_format (str or object, optional): Format of the response. Defaults to "text".
tools (list, optional): List of tools that the model can call. Defaults to None.
tool_choice (str, optional): Tool choice method. Defaults to "auto".

Returns:
str: The generated response.
"""
params = {
"model": self.config.model,
"messages": messages,
"temperature": self.config.temperature,
"max_tokens": self.config.max_tokens,
"top_p": self.config.top_p,
}
if response_format:
params["response_format"] = response_format
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice

response = self.client.chat.completions.create(**params)
return self._parse_response(response, tools)
10 changes: 7 additions & 3 deletions mem0/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
from mem0.configs.llms.vllm import VllmConfig
from mem0.configs.rerankers.base import BaseRerankerConfig
from mem0.configs.rerankers.cohere import CohereRerankerConfig
from mem0.configs.rerankers.sentence_transformer import SentenceTransformerRerankerConfig
from mem0.configs.rerankers.zero_entropy import ZeroEntropyRerankerConfig
from mem0.configs.rerankers.llm import LLMRerankerConfig
from mem0.configs.rerankers.huggingface import HuggingFaceRerankerConfig
from mem0.configs.rerankers.llm import LLMRerankerConfig
from mem0.configs.rerankers.sentence_transformer import (
SentenceTransformerRerankerConfig,
)
from mem0.configs.rerankers.zero_entropy import ZeroEntropyRerankerConfig
from mem0.embeddings.mock import MockEmbeddings


Expand Down Expand Up @@ -50,6 +52,7 @@ class LlmFactory:
"lmstudio": ("mem0.llms.lmstudio.LMStudioLLM", LMStudioConfig),
"vllm": ("mem0.llms.vllm.VllmLLM", VllmConfig),
"langchain": ("mem0.llms.langchain.LangchainLLM", BaseLlmConfig),
"sambanova": ("mem0.llms.sambanova.SambaNovaLLM", BaseLlmConfig),
}

@classmethod
Expand Down Expand Up @@ -146,6 +149,7 @@ class EmbedderFactory:
"langchain": "mem0.embeddings.langchain.LangchainEmbedding",
"aws_bedrock": "mem0.embeddings.aws_bedrock.AWSBedrockEmbedding",
"fastembed": "mem0.embeddings.fastembed.FastEmbedEmbedding",
"sambanova": "mem0.embeddings.sambanova.SambaNovaEmbedding",
}

@classmethod
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ llms = [
"vertexai>=0.1.0",
"google-generativeai>=0.3.0",
"google-genai>=1.0.0",
"sambanova>=1.2.0",
]
extras = [
"boto3>=1.34.0",
Expand Down
60 changes: 60 additions & 0 deletions tests/embeddings/test_sambanova_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from unittest.mock import Mock, patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig
from mem0.embeddings.sambanova import SambaNovaEmbedding


@pytest.fixture
def mock_sambanova_client():
with patch("mem0.embeddings.sambanova.SambaNova") as mock_sambanova:
mock_client = Mock()
mock_sambanova.return_value = mock_client
yield mock_client


def test_embed_default_model(mock_sambanova_client):
config = BaseEmbedderConfig()
embedder = SambaNovaEmbedding(config)
mock_response = Mock()
mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3])]
mock_sambanova_client.embeddings.create.return_value = mock_response

result = embedder.embed("Hello world")

mock_sambanova_client.embeddings.create.assert_called_once_with(
input="Hello world", model="E5-Mistral-7B-Instruct"
)
assert result == [0.1, 0.2, 0.3]


def test_embed_without_api_key_env_var(mock_sambanova_client):
config = BaseEmbedderConfig(api_key="test_key")
embedder = SambaNovaEmbedding(config)
mock_response = Mock()
mock_response.data = [Mock(embedding=[1.0, 1.1, 1.2])]
mock_sambanova_client.embeddings.create.return_value = mock_response

result = embedder.embed("Testing API key")

mock_sambanova_client.embeddings.create.assert_called_once_with(
input="Testing API key", model="E5-Mistral-7B-Instruct"
)
assert result == [1.0, 1.1, 1.2]


def test_embed_uses_environment_api_key(mock_sambanova_client, monkeypatch):
monkeypatch.setenv("SAMBANOVA_API_KEY", "env_key")
config = BaseEmbedderConfig()
embedder = SambaNovaEmbedding(config)
mock_response = Mock()
mock_response.data = [Mock(embedding=[1.3, 1.4, 1.5])]
mock_sambanova_client.embeddings.create.return_value = mock_response

result = embedder.embed("Environment key test")

mock_sambanova_client.embeddings.create.assert_called_once_with(
input="Environment key test", model="E5-Mistral-7B-Instruct"
)
assert result == [1.3, 1.4, 1.5]
Loading