
TaskAdherence V2 prompt updates #41616


Open · wants to merge 11 commits into base: main
@@ -22,7 +22,8 @@
from ._evaluators._protected_material import ProtectedMaterialEvaluator
from ._evaluators._qa import QAEvaluator
from ._evaluators._response_completeness import ResponseCompletenessEvaluator
from ._evaluators._task_adherence import TaskAdherenceEvaluator
from ._evaluators._task_adherence import TaskAdherenceEvaluator as TaskAdherenceEvaluatorV2
from ._evaluators._task_adherence_old import TaskAdherenceEvaluator
from ._evaluators._relevance import RelevanceEvaluator
from ._evaluators._retrieval import RetrievalEvaluator
from ._evaluators._rouge import RougeScoreEvaluator, RougeType
@@ -68,6 +69,7 @@
"GroundednessProEvaluator",
"ResponseCompletenessEvaluator",
"TaskAdherenceEvaluator",
"TaskAdherenceEvaluatorV2",
"IntentResolutionEvaluator",
"RelevanceEvaluator",
"SimilarityEvaluator",
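With this change the package root exports both names: TaskAdherenceEvaluator now resolves to the legacy implementation in _task_adherence_old, while TaskAdherenceEvaluatorV2 is the updated evaluator. A minimal usage sketch, assuming the constructor still takes a model configuration the way the existing evaluator does (the endpoint, deployment, and key values are placeholders):

from azure.ai.evaluation import TaskAdherenceEvaluator, TaskAdherenceEvaluatorV2

# Placeholder Azure OpenAI model configuration; values are illustrative only.
model_config = {
    "azure_endpoint": "https://<your-resource>.openai.azure.com",
    "azure_deployment": "<your-deployment>",
    "api_key": "<your-api-key>",
}

legacy_evaluator = TaskAdherenceEvaluator(model_config=model_config)
v2_evaluator = TaskAdherenceEvaluatorV2(model_config=model_config)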
@@ -97,7 +97,7 @@ def _is_openai_model_config(val: object) -> TypeGuard[OpenAIModelConfiguration]:


def parse_model_config_type(
model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
) -> None:
if _is_aoi_model_config(model_config):
model_config["type"] = AZURE_OPENAI_TYPE
@@ -106,9 +106,9 @@ def parse_model_config_type(


def construct_prompty_model_config(
model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
default_api_version: str,
user_agent: str,
model_config: Union[AzureOpenAIModelConfiguration, OpenAIModelConfiguration],
default_api_version: str,
user_agent: str,
) -> dict:
parse_model_config_type(model_config)

@@ -126,6 +126,7 @@ def construct_prompty_model_config(

return prompty_model_config


def is_onedp_project(azure_ai_project: AzureAIProject) -> bool:
"""Check if the Azure AI project is an OneDP project.
Expand All @@ -138,6 +139,7 @@ def is_onedp_project(azure_ai_project: AzureAIProject) -> bool:
return True
return False


def validate_azure_ai_project(o: object) -> AzureAIProject:
fields = {"subscription_id": str, "resource_group_name": str, "project_name": str}

@@ -230,7 +232,7 @@ def _validate_typed_dict(o: object, t: Type[T_TypedDict]) -> T_TypedDict:
k
for k in annotations
if (is_total and get_origin(annotations[k]) is not NotRequired)
or (not is_total and get_origin(annotations[k]) is Required)
or (not is_total and get_origin(annotations[k]) is Required)
}

missing_keys = required_keys - o.keys()
@@ -291,7 +293,8 @@ def validate_annotation(v: object, annotation: Union[str, type, object]) -> bool

return cast(T_TypedDict, o)

def check_score_is_valid(score: Union[str, float], min_score = 1, max_score = 5) -> bool:

def check_score_is_valid(score: Union[str, float], min_score=1, max_score=5) -> bool:
"""Check if the score is valid, i.e. is convertable to number and is in the range [min_score, max_score].
:param score: The score to check.
@@ -310,6 +313,7 @@ def check_score_is_valid(score: Union[str, float], min_score = 1, max_score = 5)

return min_score <= numeric_score <= max_score


def parse_quality_evaluator_reason_score(llm_output: str, valid_score_range: str = "[1-5]") -> Tuple[float, str]:
"""Parse the output of prompt-based quality evaluators that return a score and reason.
@@ -422,11 +426,11 @@ def raise_exception(msg, target):
except ImportError as ex:
raise MissingRequiredPackage(
message="Please install 'azure-ai-inference' package to use SystemMessage, "
"UserMessage or AssistantMessage."
"UserMessage or AssistantMessage."
) from ex

if isinstance(message, ChatRequestMessage) and not isinstance(
message, (UserMessage, AssistantMessage, SystemMessage)
message, (UserMessage, AssistantMessage, SystemMessage)
):
raise_exception(
f"Messages must be a strongly typed class of ChatRequestMessage. Message number: {num}",
Expand All @@ -437,7 +441,7 @@ def raise_exception(msg, target):
if isinstance(message, UserMessage):
user_message_count += 1
if isinstance(message.content, list) and any(
isinstance(item, ImageContentItem) for item in message.content
isinstance(item, ImageContentItem) for item in message.content
):
image_found = True
continue
@@ -481,21 +485,26 @@ def raise_exception(msg, target):
ErrorTarget.CONTENT_SAFETY_CHAT_EVALUATOR,
)


def _extract_text_from_content(content):
text = []
for msg in content:
if 'text' in msg:
text.append(msg['text'])
return text

def _get_conversation_history(query):

def _get_conversation_history(query, include_system_messages=False):
all_user_queries = []
cur_user_query = []
all_agent_responses = []
cur_agent_response = []
system_message = None
for msg in query:
if not 'role' in msg:
continue
if include_system_messages and msg['role'] == 'system' and 'content' in msg:
system_message = msg.get('content', '')
if msg['role'] == 'user' and 'content' in msg:
if cur_agent_response != []:
all_agent_responses.append(cur_agent_response)
@@ -505,15 +514,15 @@ def _get_conversation_history(query):
cur_user_query.append(text_in_msg)

if msg['role'] == 'assistant' and 'content' in msg:
if cur_user_query !=[]:
if cur_user_query != []:
all_user_queries.append(cur_user_query)
cur_user_query = []
text_in_msg = _extract_text_from_content(msg['content'])
if text_in_msg:
cur_agent_response.append(text_in_msg)
if cur_user_query !=[]:
if cur_user_query != []:
all_user_queries.append(cur_user_query)
if cur_agent_response !=[]:
if cur_agent_response != []:
all_agent_responses.append(cur_agent_response)

if len(all_user_queries) != len(all_agent_responses) + 1:
@@ -524,31 +533,37 @@ def _get_conversation_history(query):
category=ErrorCategory.INVALID_VALUE,
blame=ErrorBlame.USER_ERROR,
)

return {
'user_queries' : all_user_queries,
'agent_responses' : all_agent_responses
}
'system_message': system_message,
'user_queries': all_user_queries,
'agent_responses': all_agent_responses
}


def _pretty_format_conversation_history(conversation_history):
"""Formats the conversation history for better readability."""
formatted_history = ""
for i, (user_query, agent_response) in enumerate(zip(conversation_history['user_queries'], conversation_history['agent_responses']+[None])):
formatted_history+=f"User turn {i+1}:\n"
if 'system_message' in conversation_history and conversation_history['system_message'] is not None:
formatted_history += "SYSTEM MESSAGE:\n"
formatted_history += " " + conversation_history['system_message'] + "\n\n"
for i, (user_query, agent_response) in enumerate(
zip(conversation_history['user_queries'], conversation_history['agent_responses'] + [None])):
formatted_history += f"User turn {i + 1}:\n"
for msg in user_query:
formatted_history+=" " + "\n ".join(msg)
formatted_history+="\n\n"
formatted_history += " " + "\n ".join(msg)
formatted_history += "\n\n"
if agent_response:
formatted_history+=f"Agent turn {i+1}:\n"
formatted_history += f"Agent turn {i + 1}:\n"
for msg in agent_response:
formatted_history+=" " + "\n ".join(msg)
formatted_history+="\n\n"
formatted_history += " " + "\n ".join(msg)
formatted_history += "\n\n"
return formatted_history
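For illustration, a sketch of what the two helpers above produce for a hypothetical conversation with a system message; the message shapes are assumptions about the converter output, and the system content is assumed to be a plain string, which is what the formatting code expects:

history = _get_conversation_history(
    [
        {"role": "system", "content": "You are a travel assistant."},
        {"role": "user", "content": [{"type": "text", "text": "Find me a flight to Paris."}]},
        {"role": "assistant", "content": [{"type": "text", "text": "Sure, here are some options."}]},
        {"role": "user", "content": [{"type": "text", "text": "Book the cheapest one."}]},
    ],
    include_system_messages=True,
)
print(_pretty_format_conversation_history(history))
# SYSTEM MESSAGE:
#   You are a travel assistant.
#
# User turn 1:
#   Find me a flight to Paris.
#
# Agent turn 1:
#   Sure, here are some options.
#
# User turn 2:
#   Book the cheapest one.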

def reformat_conversation_history(query, logger = None):

def reformat_conversation_history(query, logger=None, include_system_messages=False):
"""Reformats the conversation history to a more compact representation."""
try:
conversation_history = _get_conversation_history(query)
conversation_history = _get_conversation_history(query, include_system_messages=include_system_messages)
return _pretty_format_conversation_history(conversation_history)
except:
Copilot AI reviewed on Jun 17, 2025:
Avoid using a bare except which can hide unexpected errors; catch specific exceptions (e.g., ValueError, KeyError) and log the exception to aid debugging.
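One way to narrow the except clause along the lines of this suggestion (a sketch only, not the change made in this PR; EvaluationException is included because _get_conversation_history raises it when the turn counts do not line up):

try:
    conversation_history = _get_conversation_history(query, include_system_messages=include_system_messages)
    return _pretty_format_conversation_history(conversation_history)
except (KeyError, TypeError, ValueError, EvaluationException) as exc:
    # Log the specific failure instead of silently swallowing it.
    if logger:
        logger.warning(f"Conversation history could not be parsed ({exc}), falling back to original query: {query}")
    return query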

# If the conversation history cannot be parsed for whatever reason (e.g. the converter format changed), the original query is returned
@@ -562,25 +577,59 @@ def reformat_conversation_history(query, logger = None):
logger.warning(f"Conversation history could not be parsed, falling back to original query: {query}")
return query
Copilot AI reviewed on Jun 17, 2025:
The fallback returns the original message list rather than a string, which may break downstream prompts that expect a formatted string. Consider serializing or stringifying the original query for consistency.
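A possible way to keep the fallback type consistent with the formatted-string path, as suggested above (a hypothetical helper, not part of this PR):

import json

def _fallback_query_as_string(query):
    # Hypothetical helper: serialize the original messages so downstream prompts
    # always receive a string, even when parsing fails.
    try:
        return json.dumps(query, default=str)
    except TypeError:
        return str(query)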


def _get_agent_response(agent_response_msgs):
"""Extracts the text from the agent response content."""

def _get_agent_response(agent_response_msgs, include_tool_messages=False):
"""Extracts formatted agent response including text, and optionally tool calls/results."""
agent_response_text = []
tool_results = {}

# First pass: collect tool results
if include_tool_messages:
for msg in agent_response_msgs:
if msg.get("role") == "tool" and "tool_call_id" in msg:
for content in msg.get("content", []):
if content.get("type") == "tool_result":
result = content.get("tool_result")
tool_results[msg["tool_call_id"]] = f'[TOOL_RESULT] {result}'

# Second pass: parse assistant messages and tool calls
for msg in agent_response_msgs:
if 'role' in msg and msg['role'] == 'assistant' and 'content' in msg:
text = _extract_text_from_content(msg['content'])
if 'role' in msg and msg.get("role") == "assistant" and "content" in msg:
text = _extract_text_from_content(msg["content"])
if text:
agent_response_text.extend(text)
if include_tool_messages:
for content in msg.get("content", []):
# Todo: Verify if this is the correct way to handle tool calls
if content.get("type") == "tool_call":
if "tool_call" in content:
tc = content.get("tool_call", {})
func_name = tc.get("function", {}).get("name", "")
args = tc.get("function", {}).get("arguments", {})
tool_call_id = tc.get("id")
else:
tool_call_id = content.get("tool_call_id")
func_name = content.get("name", "")
args = content.get("arguments", {})
args_str = ", ".join(f'{k}="{v}"' for k, v in args.items())
call_line = f'[TOOL_CALL] {func_name}({args_str})'
agent_response_text.append(call_line)
if tool_call_id in tool_results:
agent_response_text.append(tool_results[tool_call_id])

return agent_response_text
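For reference, a sketch of what _get_agent_response extracts when tool messages are included; the flat tool-call item shape used here is one of the two shapes the parser accepts, and the converter's actual format is an assumption:

response = [
    {"role": "assistant", "content": [
        {"type": "text", "text": "Let me check the weather."},
        {"type": "tool_call", "tool_call_id": "call_1", "name": "get_weather", "arguments": {"city": "Paris"}},
    ]},
    {"role": "tool", "tool_call_id": "call_1", "content": [
        {"type": "tool_result", "tool_result": "18 C, sunny"},
    ]},
    {"role": "assistant", "content": [{"type": "text", "text": "It is 18 C and sunny in Paris."}]},
]
print("\n".join(_get_agent_response(response, include_tool_messages=True)))
# Let me check the weather.
# [TOOL_CALL] get_weather(city="Paris")
# [TOOL_RESULT] 18 C, sunny
# It is 18 C and sunny in Paris.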

def reformat_agent_response(response, logger = None):

def reformat_agent_response(response, logger=None, include_tool_messages=False):
try:
if response is None or response == []:
return ""
agent_response = _get_agent_response(response)
agent_response = _get_agent_response(response, include_tool_messages=include_tool_messages)
if agent_response == []:
# If no message could be extracted, likely the format changed, fallback to the original response in that case
if logger:
logger.warning(f"Empty agent response extracted, likely due to input schema change. Falling back to using the original response: {response}")
logger.warning(
f"Empty agent response extracted, likely due to input schema change. Falling back to using the original response: {response}")
return response
return "\n".join(agent_response)
except:
Expand All @@ -590,6 +639,18 @@ def reformat_agent_response(response, logger = None):
logger.warning(f"Agent response could not be parsed, falling back to original response: {response}")
return response


def reformat_tool_definitions(tool_definitions, logger=None):
output_lines = ["TOOL DEFINITIONS:"]
for tool in tool_definitions:
name = tool.get("name", "unnamed_tool")
desc = tool.get("description", "").strip()
params = tool.get("parameters", {}).get("properties", {})
param_names = ", ".join(params.keys()) if params else "no parameters"
output_lines.append(f"- {name}: {desc} (inputs: {param_names})")
return "\n".join(output_lines)
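For reference, a sketch of the compact form reformat_tool_definitions produces for a hypothetical tool definition (the shape with a nested parameters.properties object is inferred from the code above):

tools = [
    {
        "name": "get_weather",
        "description": "Look up the current weather for a city.",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    }
]
print(reformat_tool_definitions(tools))
# TOOL DEFINITIONS:
# - get_weather: Look up the current weather for a city. (inputs: city)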


def upload(path: str, container_client: ContainerClient, logger=None):
"""Upload files or directories to Azure Blob Storage using a container client.
@@ -3,16 +3,19 @@
# ---------------------------------------------------------
import os
import math
import logging
from typing import Dict, Union, List, Optional

from typing_extensions import overload, override

from azure.ai.evaluation._exceptions import EvaluationException, ErrorBlame, ErrorCategory, ErrorTarget
from azure.ai.evaluation._evaluators._common import PromptyEvaluatorBase
from azure.ai.evaluation._common.utils import parse_quality_evaluator_reason_score
from ..._common.utils import reformat_conversation_history, reformat_agent_response, reformat_tool_definitions
from azure.ai.evaluation._model_configurations import Message
from azure.ai.evaluation._common._experimental import experimental

logger = logging.getLogger(__name__)

@experimental
class TaskAdherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]):
"""The Task Adherence evaluator assesses how well an AI-generated response follows the assigned task based on:
@@ -56,7 +59,7 @@ class TaskAdherenceEvaluator(PromptyEvaluatorBase[Union[str, float]]):
"""

_PROMPTY_FILE = "task_adherence.prompty"
_RESULT_KEY = "task_adherence"
_RESULT_KEY = "task_adherence_v2"
_OPTIONAL_PARAMS = ["tool_definitions"]

_DEFAULT_TASK_ADHERENCE_SCORE = 3
@@ -142,21 +145,23 @@ async def _do_eval(self, eval_input: Dict) -> Dict[str, Union[float, str]]: # t
category=ErrorCategory.MISSING_FIELD,
target=ErrorTarget.TASK_ADHERENCE_EVALUATOR,
)

eval_input['query'] = reformat_conversation_history(eval_input["query"], logger, include_system_messages=True)
eval_input['response'] = reformat_agent_response(eval_input["response"], logger, include_tool_messages=True)
if "tool_definitions" in eval_input and eval_input["tool_definitions"] is not None:
eval_input['tool_definitions'] = reformat_tool_definitions(eval_input["tool_definitions"], logger)
llm_output = await self._flow(timeout=self._LLM_CALL_TIMEOUT, **eval_input)

score = math.nan
if llm_output:
score, reason = parse_quality_evaluator_reason_score(llm_output, valid_score_range="[1-5]")

score_result = 'pass' if score >= self.threshold else 'fail'

if isinstance(llm_output, dict):
score = float(llm_output.get("score", math.nan))
score_result = "pass" if score >= self.threshold else "fail"
reason = llm_output.get("explanation", "")
return {
f"{self._result_key}": score,
f"{self._result_key}_result": score_result,
f"{self._result_key}_threshold": self.threshold,
f"{self._result_key}_reason": reason,
f"{self._result_key}_additional_details": llm_output
}

if logger:
logger.warning("LLM output is not a dictionary, returning NaN for the score.")
return {self._result_key: math.nan}
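For reference, a sketch of the result the new _do_eval branch would return when the prompty flow yields a structured dict; the key names come from _RESULT_KEY above, the values are hypothetical, and the threshold of 3 assumes the default _DEFAULT_TASK_ADHERENCE_SCORE is in effect:

# Hypothetical structured output from the prompty flow:
llm_output = {"score": 4, "explanation": "The response follows the assigned task with minor deviations."}

# Expected evaluator result for that output:
# {
#     "task_adherence_v2": 4.0,
#     "task_adherence_v2_result": "pass",
#     "task_adherence_v2_threshold": 3,
#     "task_adherence_v2_reason": "The response follows the assigned task with minor deviations.",
#     "task_adherence_v2_additional_details": {"score": 4, "explanation": "..."},
# }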
