Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions agentstack/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,7 @@ def init_project_builder(

template_data = None
if template is not None:
if template.startswith("https://"):
try:
template_data = TemplateConfig.from_url(template)
except Exception as e:
raise Exception(f"Failed to fetch template data from {template}.\n{e}")
else:
try:
template_data = TemplateConfig.from_template_name(template)
except Exception as e:
raise Exception(f"Failed to load template {template}.\n{e}")
template_data = TemplateConfig.from_user_input(template)

if template_data:
project_details = {
Expand Down Expand Up @@ -115,7 +106,6 @@ def init_project_builder(


def welcome_message():
#os.system("cls" if os.name == "nt" else "clear")
title = text2art("AgentStack", font="smisome1")
tagline = "The easiest way to build a robust agent application!"
border = "-" * len(tagline)
Expand Down Expand Up @@ -400,7 +390,7 @@ def insert_template(
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env.example',
f'{template_path}/{"{{cookiecutter.project_metadata.project_slug}}"}/.env',
)

cookiecutter(str(template_path), no_input=True, extra_context=None)

# TODO: inits a git repo in the directory the command was run in
Expand Down
43 changes: 26 additions & 17 deletions agentstack/proj_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def to_v3(self) -> 'TemplateConfig':
framework=self.framework,
method=self.method,
manager_agent=None,
agents=[TemplateConfig.Agent(**agent.dict()) for agent in self.agents],
tasks=[TemplateConfig.Task(**task.dict()) for task in self.tasks],
tools=[TemplateConfig.Tool(**tool.dict()) for tool in self.tools],
agents=[TemplateConfig.Agent(**agent.model_dump()) for agent in self.agents],
tasks=[TemplateConfig.Task(**task.model_dump()) for task in self.tasks],
tools=[TemplateConfig.Tool(**tool.model_dump()) for tool in self.tools],
inputs=self.inputs,
)

Expand Down Expand Up @@ -144,17 +144,22 @@ def write_to_file(self, filename: Path):
f.write(json.dumps(model_dump, indent=4))

@classmethod
def from_template_name(cls, name: str) -> 'TemplateConfig':
# if url
if name.startswith('https://'):
return cls.from_url(name)

# if .json file
if name.endswith('.json'):
path = os.getcwd() / Path(name)
def from_user_input(cls, identifier: str):
"""
Load a template from a user-provided identifier.
Three cases will be tried: A URL, a file path, or a template name.
"""
if identifier.startswith('https://'):
return cls.from_url(identifier)

if identifier.endswith('.json'):
path = Path() / identifier
return cls.from_file(path)

# if named template
return cls.from_template_name(identifier)

@classmethod
def from_template_name(cls, name: str) -> 'TemplateConfig':
path = get_package_path() / f'templates/proj_templates/{name}.json'
if not name in get_all_template_names():
raise ValidationError(f"Template {name} not bundled with agentstack.")
Expand All @@ -164,8 +169,11 @@ def from_template_name(cls, name: str) -> 'TemplateConfig':
def from_file(cls, path: Path) -> 'TemplateConfig':
if not os.path.exists(path):
raise ValidationError(f"Template {path} not found.")
with open(path, 'r') as f:
return cls.from_json(json.load(f))
try:
with open(path, 'r') as f:
return cls.from_json(json.load(f))
except json.JSONDecodeError as e:
raise ValidationError(f"Error decoding template JSON.\n{e}")

@classmethod
def from_url(cls, url: str) -> 'TemplateConfig':
Expand All @@ -174,7 +182,10 @@ def from_url(cls, url: str) -> 'TemplateConfig':
response = requests.get(url)
if response.status_code != 200:
raise ValidationError(f"Failed to fetch template from {url}")
return cls.from_json(response.json())
try:
return cls.from_json(response.json())
except json.JSONDecodeError as e:
raise ValidationError(f"Error decoding template JSON.\n{e}")

@classmethod
def from_json(cls, data: dict) -> 'TemplateConfig':
Expand All @@ -193,8 +204,6 @@ def from_json(cls, data: dict) -> 'TemplateConfig':
for error in e.errors():
err_msg += f"{' '.join([str(loc) for loc in error['loc']])}: {error['msg']}\n"
raise ValidationError(err_msg)
except json.JSONDecodeError as e:
raise ValidationError(f"Error decoding template JSON.\n{e}")


def get_all_template_paths() -> list[Path]:
Expand Down
58 changes: 57 additions & 1 deletion tests/test_agents_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import importlib.resources
from pathlib import Path
from agentstack import conf
from agentstack.agents import AgentConfig, AGENTS_FILENAME
from agentstack.agents import AgentConfig, AGENTS_FILENAME, get_all_agent_names, get_all_agents
from agentstack.exceptions import ValidationError

BASE_PATH = Path(__file__).parent

Expand Down Expand Up @@ -83,3 +84,58 @@ def test_write_none_values(self):
llm:
"""
)

def test_yaml_error(self):
# Create an invalid YAML file
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
f.write("""
agent_name:
role: "This is a valid line"
invalid_yaml: "This line is missing a colon"
nested_key: "This will cause a YAML error"
""")

# Attempt to load the config, which should raise a ValidationError
with self.assertRaises(ValidationError) as context:
AgentConfig("agent_name")

def test_pydantic_validation_error(self):
# Create a YAML file with an invalid field type
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
f.write("""
agent_name:
role: "This is a valid role"
goal: "This is a valid goal"
backstory: "This is a valid backstory"
llm: 123 # This should be a string, not an integer
""")

# Attempt to load the config, which should raise a ValidationError
with self.assertRaises(ValidationError) as context:
AgentConfig("agent_name")

def test_get_all_agent_names(self):
shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME)

agent_names = get_all_agent_names()
self.assertEqual(set(agent_names), {"agent_name", "second_agent_name"})
self.assertEqual(agent_names, ["agent_name", "second_agent_name"])

def test_get_all_agent_names_missing_file(self):
if os.path.exists(self.project_dir / AGENTS_FILENAME):
os.remove(self.project_dir / AGENTS_FILENAME)
non_existent_file_agent_names = get_all_agent_names()
self.assertEqual(non_existent_file_agent_names, [])

def test_get_all_agent_names_empty_file(self):
with open(self.project_dir / AGENTS_FILENAME, 'w') as f:
f.write("")

empty_agent_names = get_all_agent_names()
self.assertEqual(empty_agent_names, [])

def test_get_all_agents(self):
shutil.copy(BASE_PATH / "fixtures/agents_max.yaml", self.project_dir / AGENTS_FILENAME)

for agent in get_all_agents():
self.assertIsInstance(agent, AgentConfig)
62 changes: 61 additions & 1 deletion tests/test_inputs_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import unittest
from pathlib import Path
from agentstack import conf
from agentstack.inputs import InputsConfig
from agentstack.inputs import InputsConfig, get_inputs, add_input_for_run
from agentstack.exceptions import ValidationError

BASE_PATH = Path(__file__).parent

Expand All @@ -30,3 +31,62 @@ def test_maximal_input_config(self):
assert config['input_name'] == "This in an input"
assert config['input_name_2'] == "This is another input"
assert config.to_dict() == {'input_name': "This in an input", 'input_name_2': "This is another input"}

def test_yaml_error(self):
# Create an invalid YAML file
with open(self.project_dir / "src/config/inputs.yaml", 'w') as f:
f.write("""
input_name: "This is a valid line"
invalid_yaml: "This line is missing a colon"
nested_key: "This will cause a YAML error"
""")

# Attempt to load the config, which should raise a ValidationError
with self.assertRaises(ValidationError) as context:
InputsConfig()

def test_create_inputs_file_if_not_exists(self):
# Ensure the inputs file doesn't exist
inputs_file = self.project_dir / "src/config/inputs.yaml"
if inputs_file.exists():
inputs_file.unlink()

# Create an InputsConfig instance and set a value
with InputsConfig() as config:
config['test_key'] = 'test_value'

# Check that the file was created
self.assertTrue(inputs_file.exists())

def test_inputs_config_contains(self):
# Create an InputsConfig instance and set some values
with InputsConfig() as config:
config['existing_key'] = 'some_value'
config['another_key'] = 'another_value'

# Test the __contains__ method
self.assertTrue('existing_key' in config)
self.assertTrue('another_key' in config)
self.assertFalse('non_existing_key' in config)

def test_get_inputs(self):
# Set up some initial inputs
with InputsConfig() as config:
config['saved_key'] = 'saved_value'

# Test get_inputs without run inputs
inputs = get_inputs()
self.assertEqual(inputs['saved_key'], 'saved_value')

# Add a run input
add_input_for_run('run_key', 'run_value')

# Test get_inputs with run inputs
inputs = get_inputs()
self.assertEqual(inputs['saved_key'], 'saved_value')
self.assertEqual(inputs['run_key'], 'run_value')

# Test that run inputs override saved inputs
add_input_for_run('saved_key', 'overridden_value')
inputs = get_inputs()
self.assertEqual(inputs['saved_key'], 'overridden_value')
56 changes: 55 additions & 1 deletion tests/test_tasks_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import importlib.resources
from pathlib import Path
from agentstack import conf
from agentstack.tasks import TaskConfig, TASKS_FILENAME
from agentstack.tasks import TaskConfig, TASKS_FILENAME, get_all_task_names, get_all_tasks
from agentstack.exceptions import ValidationError

BASE_PATH = Path(__file__).parent

Expand Down Expand Up @@ -76,3 +77,56 @@ def test_write_none_values(self):
agent: >
"""
)

def test_yaml_error(self):
# Create an invalid YAML file
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
f.write("""
task_name:
description: "This is a valid line"
invalid_yaml: "This line is missing a colon"
nested_key: "This will cause a YAML error"
""")

# Attempt to load the config, which should raise a ValidationError
with self.assertRaises(ValidationError) as context:
TaskConfig("task_name")

def test_pydantic_validation_error(self):
# Create a YAML file with an invalid field type
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
f.write("""
task_name:
description: "This is a valid description"
expected_output: "This is a valid expected output"
agent: 123 # This should be a string, not an integer
""")

# Attempt to load the config, which should raise a ValidationError
with self.assertRaises(ValidationError) as context:
TaskConfig("task_name")

def test_get_all_task_names(self):
shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME)

task_names = get_all_task_names()
self.assertEqual(set(task_names), {"task_name", "task_name_two"})
self.assertEqual(task_names, ["task_name", "task_name_two"])

def test_get_all_task_names_missing_file(self):
if os.path.exists(self.project_dir / TASKS_FILENAME):
os.remove(self.project_dir / TASKS_FILENAME)
non_existent_file_task_names = get_all_task_names()
self.assertEqual(non_existent_file_task_names, [])

def test_get_all_task_names_empty_file(self):
with open(self.project_dir / TASKS_FILENAME, 'w') as f:
f.write("")

empty_task_names = get_all_task_names()
self.assertEqual(empty_task_names, [])

def test_get_all_tasks(self):
shutil.copy(BASE_PATH / "fixtures/tasks_max.yaml", self.project_dir / TASKS_FILENAME)
for task in get_all_tasks():
self.assertIsInstance(task, TaskConfig)
Loading
Loading