Commit e432607

HuggingFace support and June updates (microsoft#2)
* Improve SQL select experience implementation
* Support pgvector
* Feature: add HuggingFace model/hyperparameter space suggestion
* Improve packaging
* Minor fix
* Generalize prompt template and knowledge
* Minor fix
1 parent: b3c0e6a

File tree: 18 files changed (+421, -301 lines)


.github/workflows/ci.yml

Lines changed: 4 additions & 4 deletions
```diff
@@ -2,9 +2,9 @@ name: Python CI
 
 on:
   push:
-    branches: [ demo_orphan ]
+    branches: [ main, dev ]
   pull_request_target:
-    branches: [ demo_orphan ]
+    branches: [ main, dev ]
 
 concurrency:
   group: ${{ format('ci-{0}', github.head_ref && format('pr-{0}', github.event.pull_request.number) || github.sha) }}
@@ -32,7 +32,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r requirements.txt
+          pip install -e .[dev]
 
       - name: Lint with flake8
         run: flake8
@@ -61,7 +61,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -r requirements.txt
+          pip install -e .[dev]
       - name: Test with pytest
         run: |
           pytest
```

README.md

Lines changed: 1 addition & 2 deletions
```diff
@@ -20,8 +20,7 @@ MLCopilot is a tool to help you find the best models/hyperparametes for your tas
 0. Clone this repo: `git clone REPO_URL; cd mlcopilot`
 1. Put assets/mlcopilot.db in your home directory: `cp assets/mlcopilot.db ~/.mlcopilot/mlcopilot.db`
 2. Install Python 3.8 or higher
-3. Build: `hatch build`. (May need to install [hatch](https://hatch.pypa.io/latest/install/) first)
-4. Install: `pip install ./dist/*.whl`
+3. Install: `pip install .`. If you want to develop, use `pip install -e .[dev]` instead.
 
 ### Run
```

assets/mlcopilot.db

1.9 MB (binary file, not shown)

mlcopilot/.env.template

Lines changed: 14 additions & 2 deletions
```diff
@@ -3,5 +3,17 @@
 OPENAI_API_KEY=your-openai-api-key
 
 ### DB
-## MLCOPILOT_DB_PATH - Path to database file (Example: ~/.mlcopilot/mlcopilot.db)
-MLCOPILOT_DB_PATH=~/.mlcopilot/mlcopilot.db
+## MLCOPILOT_DB_BACKEND - Database backend (Example: sqlite)
+MLCOPILOT_DB_BACKEND=sqlite
+## MLCOPILOT_DB_PATH - Path to database file (Example: ~/.mlcopilot/mlcopilot.db) - Only for sqlite
+MLCOPILOT_DB_PATH=~/.mlcopilot/mlcopilot.db
+## MLCOPILOT_DB_NAME - Database name (Example: mlcopilot)
+MLCOPILOT_DB_NAME=mlcopilot
+## MLCOPILOT_DB_HOST - Database host (Example: localhost)
+MLCOPILOT_DB_HOST=localhost
+## MLCOPILOT_DB_PORT - Database port (Example: 5432)
+MLCOPILOT_DB_PORT=5432
+## MLCOPILOT_DB_USER - Database user (Example: postgres)
+MLCOPILOT_DB_USER=postgres
+## MLCOPILOT_DB_PASSWORD - Database password (Example: '')
+MLCOPILOT_DB_PASSWORD=''
```
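
These keys mirror the new constants in mlcopilot/constants.py below. The diff does not include mlcopilot/orm.py, where the backend is actually wired up, so the following is only a minimal sketch of how a sqlite/postgres switch over these settings could look with peewee (the connect_db helper and its placement are assumptions, not the project's code):

```python
# Hypothetical sketch: select a peewee database from the new env-driven
# constants. The real wiring lives in mlcopilot/orm.py (not in this diff).
from peewee import Database, PostgresqlDatabase, SqliteDatabase

from mlcopilot.constants import (
    MLCOPILOT_DB_BACKEND,
    MLCOPILOT_DB_HOST,
    MLCOPILOT_DB_NAME,
    MLCOPILOT_DB_PASSWORD,
    MLCOPILOT_DB_PATH,
    MLCOPILOT_DB_PORT,
    MLCOPILOT_DB_USER,
)


def connect_db() -> Database:
    if MLCOPILOT_DB_BACKEND == "sqlite":
        # Only MLCOPILOT_DB_PATH matters for the sqlite backend.
        return SqliteDatabase(str(MLCOPILOT_DB_PATH))
    if MLCOPILOT_DB_BACKEND == "postgres":
        # The host/port/user/password settings apply here; pgvector
        # support requires a PostgreSQL backend.
        return PostgresqlDatabase(
            MLCOPILOT_DB_NAME,
            host=MLCOPILOT_DB_HOST,
            port=int(MLCOPILOT_DB_PORT),
            user=MLCOPILOT_DB_USER,
            password=MLCOPILOT_DB_PASSWORD,
        )
    raise ValueError(f"Unsupported MLCOPILOT_DB_BACKEND: {MLCOPILOT_DB_BACKEND}")
```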

mlcopilot/constants.py

Lines changed: 36 additions & 0 deletions
```diff
@@ -8,16 +8,38 @@
     "bin_map",
     "inverse_bin_map",
     "q_num",
+    "MLCOPILOT_DB_BACKEND",
+    "MLCOPILOT_DB_NAME",
+    "MLCOPILOT_DB_HOST",
+    "MLCOPILOT_DB_PORT",
+    "MLCOPILOT_DB_USER",
+    "MLCOPILOT_DB_PASSWORD",
+    "PROMPT_FORMATS",
+    "DEFAULT_PROMPT_PREFIX",
+    "DEFAULT_PROMPT_SUFFIX",
+    "TOKEN_LIMIT",
+    "TOKEN_COMPLETION_LIMIT",
+    "RELAX_TOKEN",
 ]
 
 TOP_K = 3
 EMBED_DIM = 1536
+TOKEN_LIMIT = 4096
+TOKEN_COMPLETION_LIMIT = 800
+RELAX_TOKEN = 500  # number of tokens of headroom reserved to avoid hitting the token limit
 
+MLCOPILOT_DB_BACKEND = os.environ.get("MLCOPILOT_DB_BACKEND", "sqlite")
 
 MLCOPILOT_DB_PATH = Path(
     os.environ.get("MLCOPILOT_DB_PATH", Path.home() / ".mlcopilot" / "mlcopilot.db")
 ).expanduser()
 
+MLCOPILOT_DB_NAME = os.environ.get("MLCOPILOT_DB_NAME", "mlcopilot")
+MLCOPILOT_DB_HOST = os.environ.get("MLCOPILOT_DB_HOST", "localhost")
+MLCOPILOT_DB_PORT = os.environ.get("MLCOPILOT_DB_PORT", 5432)
+MLCOPILOT_DB_USER = os.environ.get("MLCOPILOT_DB_USER", "postgres")
+MLCOPILOT_DB_PASSWORD = os.environ.get("MLCOPILOT_DB_PASSWORD", "")
+
 bin_map = {
     0.1: "very small",
     0.3: "small",
@@ -46,3 +68,17 @@
 )
 
 q_num = sorted(list(bin_map.keys()))
+
+PROMPT_FORMATS = {
+    "TOP_K",
+    "knowledge",
+    "space_desc",
+    "new_task_desc",
+}
+
+DEFAULT_PROMPT_PREFIX = """{space_desc}\nRecommend best configurations to train a model for a new task. Format strictly follows this template: ```Configuration 1: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}.
+Configuration 2: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}.
+Configuration 3: {{parameter_1_name}} is {{parameter_1_value}}. {{parameter_2_name}} is {{parameter_2_value}}...{{parameter_n_name}} is {{parameter_n_value}}.
+```\nHere are some tasks along with best hyper-parameter configurations to train a model on them.\n"""
+
+DEFAULT_PROMPT_SUFFIX = """\nGuidelines:{knowledge}\n\n\nBased on the examples(if provided) and guidelines(if provided) above, recommend {TOP_K} hyper-parameter configurations for a new classification dataset.\n\n{new_task_desc}"""
```
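
A note on the brace conventions in these templates: the names in PROMPT_FORMATS use single braces and are presumably substituted via str.format, while the doubled {{parameter_*}} placeholders are format-string escapes that survive substitution as literal single-brace text for the LLM to fill in. A quick illustration with a shortened template (the full prefix behaves the same way):

```python
# Doubled braces are str.format escapes: after formatting they become
# literal single braces, so only single-brace fields are substituted.
template = "{space_desc}\nConfiguration 1: {{parameter_1_name}} is {{parameter_1_value}}."
print(template.format(space_desc="Space: gradient boosting trees"))
# Output:
# Space: gradient boosting trees
# Configuration 1: {parameter_1_name} is {parameter_1_value}.
```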

mlcopilot/experience.py

Lines changed: 20 additions & 14 deletions
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from collections import OrderedDict
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 
 import langchain
 import numpy as np
@@ -330,13 +330,14 @@ def _get_best_relevant_solutions(space: Space, task_desc: str) -> ModelSelect:
         The best relevant solution.
     """
     SolutionAlias = Solution.alias()
+    order_key = Task.embedding.cosine_distance(task_desc)
     subquery = (
         SolutionAlias.select(
             SolutionAlias.demo,
             Task.task_id,
             Task.desc,
             Task.embedding,
-            fn.RANK()
+            fn.ROW_NUMBER()
             .over(
                 partition_by=[SolutionAlias.space, SolutionAlias.task],
                 order_by=[SolutionAlias.metric.desc()],
@@ -345,11 +346,13 @@ def _get_best_relevant_solutions(space: Space, task_desc: str) -> ModelSelect:
         )
         .where(SolutionAlias.space == space)
         .join(Task, on=(SolutionAlias.task == Task.task_id))
-        .order_by(fn.cosine_similarity(task_desc, Task.embedding).desc())
+        .order_by(order_key)
         .alias("subq")
     )
-    query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_(
-        subquery
+    query = (
+        Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc)
+        .from_(subquery)
+        .where(subquery.c.rnk <= TOP_K)
     )
     return query
 
@@ -375,7 +378,7 @@ def _get_best_solutions(space: Space) -> ModelSelect:
             Task.task_id,
             Task.desc,
             Task.embedding,
-            fn.RANK()
+            fn.ROW_NUMBER()
             .over(
                 partition_by=[SolutionAlias.space, SolutionAlias.task],
                 order_by=[SolutionAlias.metric.desc()],
@@ -386,13 +389,17 @@ def _get_best_solutions(space: Space) -> ModelSelect:
         .join(Task, on=(SolutionAlias.task == Task.task_id))
         .alias("subq")
     )
-    query = Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc).from_(
-        subquery
+    query = (
+        Solution.select(subquery.c.task_id, subquery.c.demo, subquery.c.desc)
+        .from_(subquery)
+        .where(subquery.c.rnk <= TOP_K)
     )
     return query
 
 
-def gen_experience(space: Space, task_desc: Optional[str] = None) -> List[str]:
+def gen_experience(
+    space: Space, task_desc: Optional[str] = None
+) -> Tuple[List[str], List[str]]:
     """
     Generate experience content from space and optional task description.
 
@@ -417,8 +424,7 @@ def gen_experience(space: Space, task_desc: Optional[str] = None) -> List[str]:
     for solution in query:
         if solution.task_id not in examples:
             examples[solution.task_id] = [solution.desc]
-        if len(examples[solution.task_id]) <= TOP_K:
-            examples[solution.task_id].append(
-                f"Configuration {len(examples[solution.task_id])}: {solution.demo}"
-            )
-    return ["\n".join(e) for e in examples.values()]
+        examples[solution.task_id].append(
+            f"Configuration {len(examples[solution.task_id])}: {solution.demo}"
+        )
+    return list(examples.keys()), ["\n".join(e) for e in examples.values()]
```
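
Two behavioral changes are worth spelling out. First, ordering by Task.embedding.cosine_distance(task_desc) pushes similarity search into the database (the pgvector distance operator) instead of the previous fn.cosine_similarity call. Second, the per-task cap that gen_experience used to apply in Python (the removed if len(...) <= TOP_K check) now happens in SQL via .where(subquery.c.rnk <= TOP_K); switching fn.RANK() to fn.ROW_NUMBER() matters here because RANK assigns tied metrics the same rank and could let more than TOP_K rows through the filter. A plain-Python sketch of the semantics the query now computes, over made-up rows rather than the project's ORM:

```python
from collections import defaultdict

TOP_K = 3

# Hypothetical (task_id, metric, demo) rows standing in for the Solution table.
solutions = [
    ("task_a", 0.91, "n_estimators is 500. max_depth is 8."),
    ("task_a", 0.91, "n_estimators is 300. max_depth is 8."),  # tie on metric
    ("task_a", 0.85, "n_estimators is 100. max_depth is 4."),
    ("task_a", 0.70, "n_estimators is 50. max_depth is 2."),
    ("task_b", 0.75, "learning_rate is 0.1."),
]

by_task = defaultdict(list)
for task_id, metric, demo in solutions:
    by_task[task_id].append((metric, demo))

# ROW_NUMBER() OVER (PARTITION BY task ORDER BY metric DESC) with
# rnk <= TOP_K keeps at most TOP_K rows per task; ties are broken by
# position instead of all being kept, as RANK() would allow.
best = {
    task_id: [demo for _, demo in sorted(rows, reverse=True)[:TOP_K]]
    for task_id, rows in by_task.items()
}
print(len(best["task_a"]))  # 3, despite four candidates and a tie
```

The new Tuple return type of gen_experience also hands back the retrieved task ids alongside the example strings, which callers such as post_validation in knowledge.py now unpack.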

mlcopilot/knowledge.py

Lines changed: 82 additions & 16 deletions
```diff
@@ -1,15 +1,17 @@
 import random
+import re
 from typing import Any, Callable, Dict, List, Optional
 
 import orjson
 from langchain.prompts import FewShotPromptTemplate, PromptTemplate
 from langchain.prompts.example_selector import LengthBasedExampleSelector
 
 from mlcopilot.constants import *
+from mlcopilot.constants import TOKEN_COMPLETION_LIMIT, TOKEN_LIMIT
 from mlcopilot.experience import gen_experience
 from mlcopilot.orm import Knowledge, Solution, Space, Task, database_proxy
 from mlcopilot.surrogate_utils import evaluate_configs
-from mlcopilot.utils import get_llm, parse_configs
+from mlcopilot.utils import get_llm, get_token_count_func, parse_configs
 
 prefix_sep = "__DUMM_SEP__"
 
@@ -28,6 +30,12 @@ def gen_knowledge_candidate(examples: List[str]) -> str:
     str
         The generated knowledge candidate.
     """
+    prefix_token = get_token_count_func()(
+        "Here are some tasks along with best hyper-parameter configurations to train a model on them.\n"
+    )
+    suffix_token = get_token_count_func()(
+        "\nQ: From the examples above, what patterns can we observe about the relationship between dataset characteristics and the best hyper-parameter configurations? (Answer MUST be concise, critical, point-by-point, line-by-line, and brief. Only include relevant observations without unnecessary elaboration.)\n\nA: 1."
+    )
     example_prompt = PromptTemplate(
         input_variables=["input"],
         template="{input}",
@@ -36,6 +44,12 @@ def gen_knowledge_candidate(examples: List[str]) -> str:
     example_selector = LengthBasedExampleSelector(
         examples=[{"input": example} for example in examples],
         example_prompt=example_prompt,
+        max_length=TOKEN_LIMIT
+        - prefix_token
+        - suffix_token
+        - TOKEN_COMPLETION_LIMIT
+        - RELAX_TOKEN,
+        get_text_length=get_token_count_func(),
     )
 
     dynamic_prompt = FewShotPromptTemplate(
@@ -76,6 +90,18 @@ def suggest_with_knowledge(
     List[Dict[str, Any]]
         The list of suggested configurations.
     """
+    prefix_token = get_token_count_func()(
+        "Here are some tasks along with best hyper-parameter configurations to train a model on them.\n"
+    )
+    suffix_token = get_token_count_func()(
+        "\nGuidelines:{knowledge}\n\n\nBased on the examples and guidelines above, recommend {TOP_K} hyper-parameter configurations for a new classification dataset.\n\n{output}".format(
+            knowledge=knowledge,
+            TOP_K=str(TOP_K),
+            output=(
+                valid_example[: valid_example.index("\nConfiguration 1:")] + "\n\n"
+            ),
+        )
+    )
     example_prompt = PromptTemplate(
         input_variables=["input"],
         template="{input}",
@@ -84,6 +110,12 @@ def suggest_with_knowledge(
     example_selector = LengthBasedExampleSelector(
         examples=[{"input": example} for example in examples],
         example_prompt=example_prompt,
+        max_length=TOKEN_LIMIT
+        - prefix_token
+        - suffix_token
+        - TOKEN_COMPLETION_LIMIT
+        - RELAX_TOKEN,
+        get_text_length=get_token_count_func(),
     )
 
     dynamic_prompt = FewShotPromptTemplate(
@@ -117,7 +149,7 @@ def suggest_with_knowledge(
 
 def post_validation(
     space: Space, surrogate_fn: Callable, config_names: List[str]
-) -> str:
+) -> List[str]:
     """
     Post validation to generate knowledge.
 
@@ -132,17 +164,17 @@
 
     Returns
     -------
-    str
-        The generated knowledge.
+    List[str]
+        The list of generated knowledge.
     """
-    knowledge = get_knowledge(space.space_id)
-    if knowledge is not None:
+    knowledges = get_knowledge(space)
+    if knowledges != "":
         print("Knowledge already exists.")
-        return knowledge
+        return knowledges
     quantile_infos = orjson.loads(space.quantile_info)
-    examples = gen_experience(space)
+    retrieved_tasks, examples = gen_experience(space)
     best_score = float("-inf")
-    knowledge = None
+    knowledges = None
     for _ in range(3):
         random.shuffle(examples)
         knowledge_candidate = gen_knowledge_candidate(examples)
@@ -168,15 +200,49 @@
             score += _score
         if best_score < score:
             best_score = score
-            knowledge = knowledge_candidate
-    assert knowledge is not None, "Knowledge is not generated."
+            knowledges = knowledge_candidate
+    assert knowledges is not None, "Knowledge is not generated."
 
-    return knowledge
+    knowledges = split_knowledge(knowledges)
+    return knowledges
 
 
-def get_knowledge(space: Space):
+def get_knowledge(space: Space, task=None):
     try:
-        knowledge = Knowledge.get(Knowledge.space_id == space.space_id).knowledge
-        return knowledge
+        knowledges = Knowledge.select().where(
+            (Knowledge.space_id == space.space_id)
+            & ((Knowledge.task == task) | (Knowledge.task == None))
+        )
+        knowledge_str = ""
+        for i, knowledge in enumerate(knowledges):
+            knowledge_str += f"{i+1}. {knowledge.knowledge}\n\n"
+        return knowledge_str
     except:
-        return None
+        return ""
+
+
+def split_knowledge(knowledges: str) -> List[str]:
+    """
+    Split the knowledge into a list of knowledge.
+
+    Parameters
+    ----------
+    knowledges: str
+        The knowledge.
+
+    Returns
+    -------
+    List[str]
+        The list of knowledge.
+
+    Examples
+    --------
+    >>> split_knowledge("1. A\n2. B\n3. C\n")
+    ["A", "B", "C"]
+    """
+    return [
+        k.strip()
+        for k in re.findall(
+            r"\n\d+\.([\s\S]+?)(?=\n+\d+\.)", "\n" + knowledges + "\n999."
+        )
+    ]
```
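
Two details in this file reward a closer look. The LengthBasedExampleSelector budget is plain arithmetic: max_length = TOKEN_LIMIT - prefix_token - suffix_token - TOKEN_COMPLETION_LIMIT - RELAX_TOKEN, i.e. few-shot examples may consume whatever remains of the 4096-token window after the fixed prompt text, the 800-token completion reserve, and 500 tokens of slack. The split_knowledge regex is the subtler part: each numbered item is captured between \n<digits>. markers, and the appended "\n999." sentinel gives the final item a boundary for the lookahead to match; without it the last point would be dropped. A standalone check of the doctest's claim:

```python
import re


def split_knowledge(knowledges: str) -> list:
    # The "\n999." sentinel terminates the last item so the (?=\n+\d+\.)
    # lookahead has a boundary to anchor on after every numbered point.
    return [
        k.strip()
        for k in re.findall(
            r"\n\d+\.([\s\S]+?)(?=\n+\d+\.)", "\n" + knowledges + "\n999."
        )
    ]


print(split_knowledge("1. A\n2. B\n3. C\n"))  # ['A', 'B', 'C']
```

Because the capture uses [\s\S] rather than ., items may span multiple lines without needing re.DOTALL.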
