2 changes: 1 addition & 1 deletion packages/remotes/pyproject.toml
@@ -42,7 +42,7 @@ requires-python = ">=3.11"
dev = [
    "ether0",
    "ether0.remotes[serve]",
    "tensorboard>=2.19", # Indirect dependency we pin to keep recent
    "tensorboard>=2.18", # Indirect dependency we pin to keep recent
@jamesbraza (Collaborator, Author) commented on Jun 10, 2025:

I loosened the pin here to allow for package resolution, as lighteval==0.10.0 requires numpy v1: huggingface/lighteval#416

]
serve = ["uvicorn"]

8 changes: 7 additions & 1 deletion pyproject.toml
@@ -55,7 +55,7 @@ baselines = [
    "ipython",
]
dev = [
    "ether0[add-tokens,typing]",
    "ether0[add-tokens,lighteval,typing]",
    "huggingface-hub[cli]", # For login inside of CI
    "ipython>=8", # Pin to keep recent
    "mypy>=1.8", # For addition of mutable-override
@@ -69,6 +69,10 @@ dev = [
    "refurb>=2", # Pin to keep recent
    "typeguard",
]
lighteval = [
    "aenum",
    "lighteval[litellm]>=0.10", # Pin to keep recent
]
typing = [
    "types-regex",
]
@@ -169,8 +173,10 @@ warn_unused_ignores = true
ignore_missing_imports = true
# Per-module configuration options
module = [
    "aenum", # SEE: https://github.com/ethanfurman/aenum/issues/10
    "datasets.*", # SEE: https://github.com/huggingface/datasets/issues/3841
    "huggingface_hub.*", # SEE: https://github.com/huggingface/huggingface_hub/issues/1662
    "lighteval.*", # SEE: https://github.com/huggingface/lighteval/issues/749
    "molbloom", # SEE: https://github.com/whitead/molbloom/issues/29
    "molsol", # SEE: https://github.com/maykcaldas/molsol/issues/6
    "onmt.*",
182 changes: 182 additions & 0 deletions src/ether0/lighteval_tasks.py
@@ -0,0 +1,182 @@
import logging
import statistics
from collections.abc import Collection, Iterable
from typing import Any

from datasets import load_dataset

try:
    from aenum import extend_enum
    from lighteval.metrics.metrics import (
        MetricCategory,
        Metrics,
        MetricUseCase,
        SampleLevelMetric,
    )
    from lighteval.tasks.lighteval_task import LightevalTaskConfig
    from lighteval.tasks.requests import Doc
except ImportError as exc:
    raise ImportError(
        "To use ether0's LightEval tasks, please install the 'lighteval' extra via:"
        " `pip install ether0[lighteval]`."
    ) from exc

from ether0.data import get_problem_category
from ether0.model_prompts import LOOSE_XML_ANSWER_USER_PROMPT, ProblemPrompt
from ether0.models import make_problem_types_filter
from ether0.rewards import accuracy_reward, format_reward

logger = logging.getLogger(__name__)

ETHER0_ACCURACY_METRIC_NAME = "ether0_accuracy"
ETHER0_FORMAT_METRIC_NAME = "ether0_format"


def evaluate_ether0_accuracy(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None, # noqa: ARG001
) -> float:
    if len(predictions) != 1:
        raise NotImplementedError(
            "Didn't handle anything besides one prediction"
            f" for doc {formatted_doc}, got {predictions}."
        )
    return accuracy_reward(
        completions=predictions,
        solution=[formatted_doc.specific["solution"]],
        reasoning=formatted_doc.specific["reasoning"],
        soft=formatted_doc.specific["soft"],
        test=formatted_doc.specific["test"],
    )[0]


def evaluate_ether0_format(
    predictions: list[str],
    formatted_doc: Doc,
    golds: list[str] | None = None, # noqa: ARG001
) -> float:
    if len(predictions) != 1:
        raise NotImplementedError(
            "Didn't handle anything besides one prediction"
            f" for doc {formatted_doc}, got {predictions}."
        )
    if formatted_doc.specific["test"]:
        logger.warning("ether0's format reward is only applicable at training time.")
    return format_reward(
        completions=predictions,
        reasoning=formatted_doc.specific["reasoning"],
    )[0]


for metric_name, metric_eval_fn in (
    (ETHER0_ACCURACY_METRIC_NAME, evaluate_ether0_accuracy),
    (ETHER0_FORMAT_METRIC_NAME, evaluate_ether0_format),
):
    if ( # Work around https://github.com/huggingface/lighteval/issues/805
        metric_name not in Metrics.__members__
    ):
        extend_enum(
            Metrics,
            metric_name,
            SampleLevelMetric(
                metric_name=metric_name,
                higher_is_better=True,
                category=MetricCategory.GENERATIVE,
                use_case=MetricUseCase.ACCURACY,
                sample_level_fn=metric_eval_fn,
                corpus_level_fn=statistics.mean,
            ),
        )


KEYS_TO_STORE_IN_DOC = {"id", "solution"}


def make_ether0_task(
    name: str,
    soft: bool,
    test: bool,
    reasoning: bool,
    problem_types: str | Collection[str] | None = None,
    metric_names: Iterable[str] | None = None,
    **kwargs,
) -> LightevalTaskConfig:
    """Create LightEval task for the ether0-benchmark dataset."""
    reward_fn_kwargs = {"soft": soft, "test": test, "reasoning": reasoning}
    if not test:
        prob_prompt = ProblemPrompt.THINK_ANSWER if reasoning else ProblemPrompt.ANSWER
        prompt_prefix: str = prob_prompt.get_prompt()
    else:
        prompt_prefix = LOOSE_XML_ANSWER_USER_PROMPT

    def row_to_doc(row: dict[str, Any], task_name: str) -> Doc:
        """Convert an ether0-benchmark dataset row to a LightEval Doc."""
        return Doc(
            query="\n\n".join((prompt_prefix, row["problem"])),
            task_name=task_name,
            choices=[""], # Placeholder for non-QA tasks
            gold_index=0, # Points to above placeholder
            specific={k: row[k] for k in KEYS_TO_STORE_IN_DOC} | reward_fn_kwargs,
        )

    if metric_names is None:
        metric_names = (
            (ETHER0_ACCURACY_METRIC_NAME, ETHER0_FORMAT_METRIC_NAME)
            if not test
            else (ETHER0_ACCURACY_METRIC_NAME,)
        )
    return LightevalTaskConfig(
        name=name,
        prompt_function=row_to_doc,
        suite=["community"],
        hf_repo="futurehouse/ether0-benchmark",
        hf_subset="default",
        hf_filter=(
            make_problem_types_filter(problem_types, type_col="problem_type")
            if problem_types is not None
            else None
        ),
        hf_avail_splits=["test"],
        evaluation_splits=["test"],
        metric=[getattr(Metrics, metric_name) for metric_name in metric_names],
        **kwargs,
    )


# TASKS_TABLE is required by LightEval for --custom-tasks CLI arg
TASKS_TABLE = [ # Add general tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
]
TASKS_TABLE.extend([ # Add problem type-specific tasks
    make_ether0_task(
        f"ether0:{nickname}{':soft' if is_soft else ''}:{prob_cat}",
        soft=is_soft,
        test=kwargs["test"],
        reasoning=kwargs["reasoning"],
        problem_types=f"re:^{prob_cat}.*$",
    )
    for is_soft in (False, True)
    for nickname, kwargs in (
        ("loose", {"test": True, "reasoning": False}),
        ("strict:no_reasoning", {"test": False, "reasoning": False}),
        ("strict", {"test": False, "reasoning": True}),
    )
    for prob_cat in {
        get_problem_category(pt)
        for pt in load_dataset("futurehouse/ether0-benchmark", split="test")[
            "problem_type"
        ]
    }
])
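
A quick end-to-end check of this module is to ask LightEval to list the registered tasks. The sketch below is illustrative only: it assumes the 'lighteval' extra is installed and mirrors the call made in tests/test_lighteval_tasks.py further down (which additionally patches around huggingface/lighteval#805).

# Minimal sketch, assuming `pip install ether0[lighteval]`: list the tasks that
# TASKS_TABLE registers, mirroring tests/test_lighteval_tasks.py below.
from lighteval.main_tasks import list as lighteval_list

import ether0.lighteval_tasks

# Prints the community-suite tasks, including entries such as
# "community|ether0:loose:molecule-completion".
lighteval_list(custom_tasks=ether0.lighteval_tasks.__file__)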
52 changes: 35 additions & 17 deletions src/ether0/models.py
@@ -1,5 +1,5 @@
import re
from collections.abc import Collection
from collections.abc import Callable, Collection, Mapping
from enum import StrEnum, auto
from typing import Any

@@ -104,39 +104,30 @@ class QAExample(BaseModel):
)


def filter_problem_types(
    dataset: TDataset, problem_types: str | Collection[str] | None
) -> TDataset:
    """Filter a dataset by problem types.
def make_problem_types_filter(
    problem_types: str | Collection[str], type_col: str
) -> Callable[[Mapping[str, Any]], bool]:
    """Make a filtration function to filter a dataset by problem types.

    Args:
        dataset: The dataset to filter. Can be a single Dataset or a DatasetDict.
        problem_types: A string or collection of strings specifying the problem
            types to filter by.
            - If None, the original dataset is returned.
            - If a string or a collection of strings:
                - Strings starting with "re:" are treated as regex patterns.
                  If a regex filter is provided, then it must be the only filter.
                - Strings starting with "!" are treated as problem types to exclude.
                - Other strings are treated as exact problem types to include.
                - Mixing inclusion and exclusion rules (e.g. ["type_a", "!type_b"])
                  is not allowed.
        type_col: The column name in the dataset that contains the problem type.

    Returns:
        The filtered dataset.
        Callable that returns True to keep a row, otherwise False to filter it out.
    """
    if problem_types is None:
        return dataset
    if isinstance(problem_types, str): # Assume single problem type as a string
        problem_types = [problem_types]
    problem_types = {pt.strip() for pt in problem_types}

    columns = (
        next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
    ).column_names
    # ether0-benchmark uses 'problem_type'; some variants may use 'type'
    type_col = "problem_type" if "problem_type" in columns else "type"

    if any(pt.startswith("re:") for pt in problem_types):
        # A regex was passed in
        if len(problem_types) != 1:
@@ -170,4 +161,31 @@ def filter_func(x):
    def filter_func(x):
        return x[type_col] not in invalid_problem_types

    return dataset.filter(filter_func, desc="Filtering problem types")
    return filter_func


def filter_problem_types(
    dataset: TDataset, problem_types: str | Collection[str] | None
) -> TDataset:
    """Filter a dataset by problem types.

    Args:
        dataset: The dataset to filter. Can be a single Dataset or a DatasetDict.
        problem_types: See make_problem_types_filter.__doc__.

    Returns:
        The filtered dataset.
    """
    if problem_types is None:
        return dataset

    columns = (
        next(iter(dataset.values())) if isinstance(dataset, DatasetDict) else dataset
    ).column_names
    # ether0-benchmark uses 'problem_type'; some variants may use 'type'
    type_col = "problem_type" if "problem_type" in columns else "type"

    return dataset.filter(
        make_problem_types_filter(problem_types, type_col),
        desc="Filtering problem types",
    )
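
To make the documented filter semantics concrete, here is a hedged sketch of make_problem_types_filter on synthetic rows; "some-other-type" is a hypothetical placeholder, while "functional-group" is borrowed from the tests below.

# Hedged sketch of the documented filter semantics on synthetic row dicts;
# "some-other-type" is a hypothetical placeholder, not a real ether0 problem type.
from ether0.models import make_problem_types_filter

# Exact inclusion: keep rows whose problem type is one of the listed types.
keep_exact = make_problem_types_filter(["functional-group"], type_col="problem_type")
assert keep_exact({"problem_type": "functional-group"})
assert not keep_exact({"problem_type": "some-other-type"})

# Exclusion: a "!" prefix drops the named type and keeps everything else.
keep_rest = make_problem_types_filter(["!functional-group"], type_col="problem_type")
assert not keep_rest({"problem_type": "functional-group"})
assert keep_rest({"problem_type": "some-other-type"})

# Regex: an "re:" pattern must be the only filter and keeps matching rows.
keep_regex = make_problem_types_filter("re:^functional-group.*$", type_col="problem_type")
assert keep_regex({"problem_type": "functional-group"})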
80 changes: 80 additions & 0 deletions tests/test_lighteval_tasks.py
@@ -0,0 +1,80 @@
from unittest.mock import patch

from lighteval.main_tasks import list as lighteval_list
from lighteval.metrics.metrics import Metrics, SampleLevelMetric
from lighteval.tasks.requests import Doc

import ether0.lighteval_tasks


def test_task_list(capsys) -> None:
    """Integration test designed to test TASKS_TABLE and custom task creation."""
    with patch( # Work around https://github.com/huggingface/lighteval/issues/805
        "lighteval.tasks.registry.create_custom_tasks_module",
        side_effect=[ether0.lighteval_tasks],
    ):
        lighteval_list(custom_tasks=ether0.lighteval_tasks.__file__)
    captured = capsys.readouterr()
    assert not captured.err
    tasks = [row for row in captured.out.splitlines() if "ether0" in row]
    assert len(tasks) > 1, "Expected some ether0 tasks"
    assert any(
        "functional-group" in row for row in tasks
    ), "Expected specific tasks to be listed"
    # TODO: after https://github.com/huggingface/lighteval/issues/806,
    # remove the .litellm_cache directory created by this test importing from LightEval


def test_accuracy_metric() -> None:
    accuracy_metric = getattr(
        Metrics, ether0.lighteval_tasks.ETHER0_ACCURACY_METRIC_NAME
    ).value
    assert isinstance(accuracy_metric, SampleLevelMetric)

    # NOTE: these inputs were taken from a gpt-4o baseline run
    doc_json = {
        "query": (
            "When answering, be sure to place the final answer as SMILES notation into"
            " XML tags <answer></answer>. An example is <answer>CCO</answer>.\n\nWhat"
            " is a valid completion of this molecule:\nO=C(OCC1=CC=CC=C1)N1CCCC1C(=O"
        ),
        "choices": [""],
        "gold_index": 0,
        "original_query": "",
        "specific": {
            "solution": (
                "valid_mol_eval!:!O=C(OCC1=CC=CC=C1)N1CCCC1C(=O!:!molecule-completion"
            ),
            "id": "e8b8bb34-731a-46e1-93a2-b6330a705148",
            "soft": False,
            "test": True,
            "reasoning": False,
        },
        "task_name": "community|ether0:loose:molecule-completion",
        "instruction": "",
        "ctx": [{
            "role": "user",
            "content": (
                "When answering, be sure to place the final answer as SMILES notation"
                " into XML tags <answer></answer>. An example is"
                " <answer>CCO</answer>.\n\nWhat is a valid completion of this"
                " molecule:\nO=C(OCC1=CC=CC=C1)N1CCCC1C(=O"
            ),
        }],
        "num_asked_few_shots": 0,
        "num_effective_few_shots": 0,
    }
    assert (
        accuracy_metric.sample_level_fn(
            predictions=[
                "The given fragment of the molecule O=C(OCC1=CC=CC=C1)N1CCCC1C(=O suggests"
                " a structure that indicates an amide linkage with a substituted"
                " cyclohexanone. A plausible completion of this structure is a standard"
                " cyclohexanone amide. Therefore, a valid SMILES notation for the completed"
                " structure is:\n\n<answer>O=C(OCC1=CC=CC=C1)N1CCCC1C(=O)C2CCCCC2</answer>"
            ],
            formatted_doc=Doc(**doc_json),
            golds=[""],
        )
        == 1.0
    )