Skip to content

feat: adding agents back to the experimental repo #326

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
17 changes: 17 additions & 0 deletions haystack_experimental/components/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import sys
from typing import TYPE_CHECKING

from lazy_imports import LazyImporter

_import_structure = {"agent": ["Agent"], "state": ["State"]}

if TYPE_CHECKING:
from .agent import Agent
from .state import State

else:
sys.modules[__name__] = LazyImporter(name=__name__, module_file=__file__, import_structure=_import_structure)
458 changes: 458 additions & 0 deletions haystack_experimental/components/agents/agent.py

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions haystack_experimental/components/agents/state/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from .state import State
from .state_utils import merge_lists, replace_values

__all__ = ["State", "merge_lists", "replace_values"]
179 changes: 179 additions & 0 deletions haystack_experimental/components/agents/state/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional

from haystack.dataclasses import ChatMessage
from haystack.utils import _deserialize_value_with_schema, _serialize_value_with_schema
from haystack.utils.callable_serialization import deserialize_callable, serialize_callable
from haystack.utils.type_serialization import deserialize_type, serialize_type

from .state_utils import _is_list_type, _is_valid_type, merge_lists, replace_values


def _schema_to_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a schema dictionary to a serializable format.

Converts each parameter's type and optional handler function into a serializable
format using type and callable serialization utilities.

:param schema: Dictionary mapping parameter names to their type and handler configs
:returns: Dictionary with serialized type and handler information
"""
serialized_schema = {}
for param, config in schema.items():
serialized_schema[param] = {"type": serialize_type(config["type"])}
if config.get("handler"):
serialized_schema[param]["handler"] = serialize_callable(config["handler"])

return serialized_schema


def _schema_from_dict(schema: Dict[str, Any]) -> Dict[str, Any]:
"""
Convert a serialized schema dictionary back to its original format.

Deserializes the type and optional handler function for each parameter from their
serialized format back into Python types and callables.

:param schema: Dictionary containing serialized schema information
:returns: Dictionary with deserialized type and handler configurations
"""
deserialized_schema = {}
for param, config in schema.items():
deserialized_schema[param] = {"type": deserialize_type(config["type"])}

if config.get("handler"):
deserialized_schema[param]["handler"] = deserialize_callable(config["handler"])

return deserialized_schema


def _validate_schema(schema: Dict[str, Any]) -> None:
"""
Validate that a schema dictionary meets all required constraints.

Checks that each parameter definition has a valid type field and that any handler
specified is a callable function.

:param schema: Dictionary mapping parameter names to their type and handler configs
:raises ValueError: If schema validation fails due to missing or invalid fields
"""
for param, definition in schema.items():
if "type" not in definition:
raise ValueError(f"StateSchema: Key '{param}' is missing a 'type' entry.")
if not _is_valid_type(definition["type"]):
raise ValueError(f"StateSchema: 'type' for key '{param}' must be a Python type, got {definition['type']}")
if definition.get("handler") is not None and not callable(definition["handler"]):
raise ValueError(f"StateSchema: 'handler' for key '{param}' must be callable or None")
if param == "messages" and definition["type"] != List[ChatMessage]:
raise ValueError(f"StateSchema: 'messages' must be of type List[ChatMessage], got {definition['type']}")


class State:
"""
A class that wraps a StateSchema and maintains an internal _data dictionary.

Each schema entry has:
"parameter_name": {
"type": SomeType,
"handler": Optional[Callable[[Any, Any], Any]]
}
"""

def __init__(self, schema: Dict[str, Any], data: Optional[Dict[str, Any]] = None):
"""
Initialize a State object with a schema and optional data.

:param schema: Dictionary mapping parameter names to their type and handler configs.
Type must be a valid Python type, and handler must be a callable function or None.
If handler is None, the default handler for the type will be used. The default handlers are:
- For list types: `haystack.agents.state.state_utils.merge_lists`
- For all other types: `haystack.agents.state.state_utils.replace_values`
:param data: Optional dictionary of initial data to populate the state
"""
_validate_schema(schema)
self.schema = deepcopy(schema)
if self.schema.get("messages") is None:
self.schema["messages"] = {"type": List[ChatMessage], "handler": merge_lists}
self._data = data or {}

# Set default handlers if not provided in schema
for definition in self.schema.values():
# Skip if handler is already defined and not None
if definition.get("handler") is not None:
continue
# Set default handler based on type
if _is_list_type(definition["type"]):
definition["handler"] = merge_lists
else:
definition["handler"] = replace_values

def get(self, key: str, default: Any = None) -> Any:
"""
Retrieve a value from the state by key.

:param key: Key to look up in the state
:param default: Value to return if key is not found
:returns: Value associated with key or default if not found
"""
return deepcopy(self._data.get(key, default))

def set(self, key: str, value: Any, handler_override: Optional[Callable[[Any, Any], Any]] = None) -> None:
"""
Set or merge a value in the state according to schema rules.

Value is merged or overwritten according to these rules:
- if handler_override is given, use that
- else use the handler defined in the schema for 'key'

:param key: Key to store the value under
:param value: Value to store or merge
:param handler_override: Optional function to override the default merge behavior
"""
# If key not in schema, we throw an error
definition = self.schema.get(key, None)
if definition is None:
raise ValueError(f"State: Key '{key}' not found in schema. Schema: {self.schema}")

# Get current value from state and apply handler
current_value = self._data.get(key, None)
handler = handler_override or definition["handler"]
self._data[key] = handler(current_value, value)

@property
def data(self):
"""
All current data of the state.
"""
return self._data

def has(self, key: str) -> bool:
"""
Check if a key exists in the state.

:param key: Key to check for existence
:returns: True if key exists in state, False otherwise
"""
return key in self._data

def to_dict(self):
"""
Convert the State object to a dictionary.
"""
serialized = {}
serialized["schema"] = _schema_to_dict(self.schema)
serialized["data"] = _serialize_value_with_schema(self._data)
return serialized

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "State":
"""
Convert a dictionary back to a State object.
"""
schema = _schema_from_dict(data.get("schema", {}))
deserialized_data = _deserialize_value_with_schema(data.get("data", {}))
return State(schema, deserialized_data)
77 changes: 77 additions & 0 deletions haystack_experimental/components/agents/state/state_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# SPDX-FileCopyrightText: 2022-present deepset GmbH <[email protected]>
#
# SPDX-License-Identifier: Apache-2.0

import inspect
from typing import Any, List, TypeVar, Union, get_origin

T = TypeVar("T")


def _is_valid_type(obj: Any) -> bool:
"""
Check if an object is a valid type annotation.

Valid types include:
- Normal classes (str, dict, CustomClass)
- Generic types (List[str], Dict[str, int])
- Union types (Union[str, int], Optional[str])

:param obj: The object to check
:return: True if the object is a valid type annotation, False otherwise

Example usage:
>>> _is_valid_type(str)
True
>>> _is_valid_type(List[int])
True
>>> _is_valid_type(Union[str, int])
True
>>> _is_valid_type(42)
False
"""
# Handle Union types (including Optional)
if hasattr(obj, "__origin__") and obj.__origin__ == Union:
return True

# Handle normal classes and generic types
return inspect.isclass(obj) or type(obj).__name__ in {"_GenericAlias", "GenericAlias"}


def _is_list_type(type_hint: Any) -> bool:
"""
Check if a type hint represents a list type.

:param type_hint: The type hint to check
:return: True if the type hint represents a list, False otherwise
"""
return type_hint == list or (hasattr(type_hint, "__origin__") and get_origin(type_hint) == list)


def merge_lists(current: Union[List[T], T, None], new: Union[List[T], T]) -> List[T]:
"""
Merges two values into a single list.

If either `current` or `new` is not already a list, it is converted into one.
The function ensures that both inputs are treated as lists and concatenates them.

If `current` is None, it is treated as an empty list.

:param current: The existing value(s), either a single item or a list.
:param new: The new value(s) to merge, either a single item or a list.
:return: A list containing elements from both `current` and `new`.
"""
current_list = [] if current is None else current if isinstance(current, list) else [current]
new_list = new if isinstance(new, list) else [new]
return current_list + new_list


def replace_values(current: Any, new: Any) -> Any:
"""
Replace the `current` value with the `new` value.

:param current: The existing value
:param new: The new value to replace
:return: The new value
"""
return new
Loading