Skip to content

Commit 3750525

Browse files
Adding log table to log dictionary or df as artifacts in MLflow run (mlflow#8467)
Signed-off-by: Sunish Sheth <[email protected]>
1 parent 6110956 commit 3750525

File tree

5 files changed

+277
-0
lines changed

5 files changed

+277
-0
lines changed

mlflow/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@
145145
log_dict,
146146
log_image,
147147
log_figure,
148+
log_table,
148149
active_run,
149150
get_run,
150151
start_run,
@@ -200,6 +201,7 @@
200201
"log_text",
201202
"log_dict",
202203
"log_figure",
204+
"log_table",
203205
"log_image",
204206
"active_run",
205207
"start_run",

mlflow/tracking/client.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
and model versions. This is a lower level API than the :py:mod:`mlflow.tracking.fluent` module,
44
and is exposed in the :py:mod:`mlflow.tracking` module.
55
"""
6+
import mlflow
67
import contextlib
78
import logging
89
import json
@@ -41,8 +42,11 @@
4142
_validate_model_alias_name,
4243
_validate_model_version,
4344
)
45+
from mlflow.utils.mlflow_tags import MLFLOW_LOGGED_ARTIFACTS
46+
from mlflow.utils.annotations import experimental
4447

4548
if TYPE_CHECKING:
49+
import pandas # pylint: disable=unused-import
4650
import matplotlib # pylint: disable=unused-import
4751
import plotly # pylint: disable=unused-import
4852
import numpy # pylint: disable=unused-import
@@ -1386,6 +1390,102 @@ def _normalize_to_uint8(x):
13861390
else:
13871391
raise TypeError("Unsupported image object type: '{}'".format(type(image)))
13881392

1393+
@experimental
1394+
def log_table(
1395+
self,
1396+
run_id: str,
1397+
data: Union[Dict[str, Any], "pandas.DataFrame"],
1398+
artifact_file: str,
1399+
) -> None:
1400+
"""
1401+
Log a table to MLflow Tracking as a JSON artifact. If the artifact_file already exists
1402+
in the run, the data would be appended to the existing artifact_file.
1403+
1404+
:param run_id: String ID of the run.
1405+
:param data: Dictionary or pandas.DataFrame to log.
1406+
:param artifact_file: The run-relative artifact file path in posixpath format to which
1407+
the table is saved (e.g. "dir/file.json").
1408+
:return: None
1409+
1410+
.. test-code-block:: python
1411+
:caption: Dictionary Example
1412+
1413+
import mlflow
1414+
from mlflow import MlflowClient
1415+
1416+
table_dict = {
1417+
"inputs": ["What is MLflow?", "What is Databricks?"],
1418+
"outputs": ["MLflow is ...", "Databricks is ..."],
1419+
"toxicity": [0.0, 0.0],
1420+
}
1421+
1422+
client = MlflowClient()
1423+
run = client.create_run(experiment_id="0")
1424+
client.log_table(
1425+
run.info.run_id, data=table_dict, artifact_file="qabot_eval_results.json"
1426+
)
1427+
1428+
.. test-code-block:: python
1429+
:caption: Pandas DF Example
1430+
1431+
import mlflow
1432+
import pandas as pd
1433+
from mlflow import MlflowClient
1434+
1435+
table_dict = {
1436+
"inputs": ["What is MLflow?", "What is Databricks?"],
1437+
"outputs": ["MLflow is ...", "Databricks is ..."],
1438+
"toxicity": [0.0, 0.0],
1439+
}
1440+
df = pd.DataFrame.from_dict(table_dict)
1441+
1442+
client = MlflowClient()
1443+
run = client.create_run(experiment_id="0")
1444+
client.log_table(run.info.run_id, data=df, artifact_file="qabot_eval_results.json")
1445+
1446+
"""
1447+
import pandas as pd
1448+
1449+
if not isinstance(data, (pd.DataFrame, dict)):
1450+
raise MlflowException.invalid_parameter_value(
1451+
"data must be a pandas.DataFrame or a dictionary"
1452+
)
1453+
1454+
data = pd.DataFrame(data)
1455+
with tempfile.TemporaryDirectory() as tmpdir:
1456+
norm_path = posixpath.normpath(artifact_file)
1457+
artifact_dir = posixpath.dirname(norm_path)
1458+
artifact_dir = None if artifact_dir == "" else artifact_dir
1459+
1460+
artifacts = [f.path for f in self.list_artifacts(run_id, path=artifact_dir)]
1461+
if artifact_file in artifacts:
1462+
downloaded_artifact_path = mlflow.artifacts.download_artifacts(
1463+
run_id=run_id, artifact_path=artifact_file, dst_path=tmpdir
1464+
)
1465+
existing_predictions = pd.read_json(downloaded_artifact_path, orient="split")
1466+
data = pd.concat([existing_predictions, data], ignore_index=True)
1467+
_logger.info(
1468+
"Appending new table to already existing artifact "
1469+
f"{artifact_file} for run {run_id}."
1470+
)
1471+
else:
1472+
_logger.info(f"Creating a new {artifact_file} for run {run_id}.")
1473+
1474+
with self._log_artifact_helper(run_id, artifact_file) as artifact_path:
1475+
data.to_json(artifact_path, orient="split", index=False)
1476+
1477+
run = self.get_run(run_id)
1478+
1479+
# Get the current value of the tag
1480+
current_tag_value = json.loads(run.data.tags.get(MLFLOW_LOGGED_ARTIFACTS, "[]"))
1481+
tag_value = {"path": artifact_file, "type": "table"}
1482+
1483+
# Append the new tag value to the list if one doesn't exists
1484+
if tag_value not in current_tag_value:
1485+
current_tag_value.append(tag_value)
1486+
# Set the tag with the updated list
1487+
self.set_tag(run_id, MLFLOW_LOGGED_ARTIFACTS, json.dumps(current_tag_value))
1488+
13891489
def _record_logged_model(self, run_id, mlflow_model):
13901490
"""
13911491
Record logged model info with the tracking server.

mlflow/tracking/fluent.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
from mlflow.utils.validation import _validate_run_id, _validate_experiment_id_type
4242
from mlflow.utils.time_utils import get_current_time_millis
4343
from mlflow.utils.databricks_utils import is_in_databricks_runtime
44+
from mlflow.utils.annotations import experimental
4445

4546

4647
if TYPE_CHECKING:
@@ -988,6 +989,56 @@ def log_image(image: Union["numpy.ndarray", "PIL.Image.Image"], artifact_file: s
988989
MlflowClient().log_image(run_id, image, artifact_file)
989990

990991

992+
@experimental
993+
def log_table(
994+
data: Union[Dict[str, Any], "pandas.DataFrame"],
995+
artifact_file: str,
996+
) -> None:
997+
"""
998+
Log a table to MLflow Tracking as a JSON artifact. If the artifact_file already exists
999+
in the run, the data would be appended to the existing artifact_file.
1000+
1001+
:param data: Dictionary or pandas.DataFrame to log.
1002+
:param artifact_file: The run-relative artifact file path in posixpath format to which
1003+
the table is saved (e.g. "dir/file.json").
1004+
:return: None
1005+
1006+
.. test-code-block:: python
1007+
:caption: Dictionary Example
1008+
1009+
import mlflow
1010+
1011+
table_dict = {
1012+
"inputs": ["What is MLflow?", "What is Databricks?"],
1013+
"outputs": ["MLflow is ...", "Databricks is ..."],
1014+
"toxicity": [0.0, 0.0],
1015+
}
1016+
1017+
with mlflow.start_run():
1018+
# Log the dictionary as a table
1019+
mlflow.log_table(data=table_dict, artifact_file="qabot_eval_results.json")
1020+
1021+
.. test-code-block:: python
1022+
:caption: Pandas DF Example
1023+
1024+
import mlflow
1025+
import pandas as pd
1026+
1027+
table_dict = {
1028+
"inputs": ["What is MLflow?", "What is Databricks?"],
1029+
"outputs": ["MLflow is ...", "Databricks is ..."],
1030+
"toxicity": [0.0, 0.0],
1031+
}
1032+
df = pd.DataFrame.from_dict(table_dict)
1033+
1034+
with mlflow.start_run():
1035+
# Log the df as a table
1036+
mlflow.log_table(data=df, artifact_file="qabot_eval_results.json")
1037+
"""
1038+
run_id = _get_or_start_run().info.run_id
1039+
MlflowClient().log_table(run_id, data, artifact_file)
1040+
1041+
9911042
def _record_logged_model(mlflow_model):
9921043
run_id = _get_or_start_run().info.run_id
9931044
MlflowClient()._record_logged_model(run_id, mlflow_model)

mlflow/utils/mlflow_tags.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
MLFLOW_DOCKER_IMAGE_ID = "mlflow.docker.image.id"
2727
# Indicates that an MLflow run was created by an autologging integration
2828
MLFLOW_AUTOLOGGING = "mlflow.autologging"
29+
# Indicates the artifacts type and path that are logged
30+
MLFLOW_LOGGED_ARTIFACTS = "mlflow.loggedArtifacts"
2931

3032
MLFLOW_DATABRICKS_NOTEBOOK_ID = "mlflow.databricks.notebookID"
3133
MLFLOW_DATABRICKS_NOTEBOOK_PATH = "mlflow.databricks.notebookPath"

tests/tracking/test_tracking.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import ast
12
import pathlib
23
from collections import namedtuple
34
import filecmp
@@ -847,3 +848,124 @@ def test_search_runs_multiple_experiments():
847848
assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_1 > 0", ViewType.ALL)) == 1
848849
assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_2 = 2", ViewType.ALL)) == 1
849850
assert len(MlflowClient().search_runs(experiment_ids, "metrics.m_3 < 4", ViewType.ALL)) == 1
851+
852+
853+
@pytest.mark.skipif(
854+
"MLFLOW_SKINNY" in os.environ,
855+
reason="Skinny client does not support the np or pandas dependencies",
856+
)
857+
def test_log_table():
858+
import pandas as pd
859+
860+
table_dict = {
861+
"inputs": ["What is MLflow?", "What is Databricks?"],
862+
"outputs": ["MLflow is ...", "Databricks is ..."],
863+
"toxicity": [0.0, 0.0],
864+
}
865+
artifact_file = "qabot_eval_results.json"
866+
TAG_NAME = "mlflow.loggedArtifacts"
867+
run_id = None
868+
869+
with pytest.raises(
870+
MlflowException, match="data must be a pandas.DataFrame or a dictionary"
871+
) as e:
872+
with mlflow.start_run() as run:
873+
# Log the incorrect data format as a table
874+
mlflow.log_table(data="incorrect-data-format", artifact_file=artifact_file)
875+
assert e.value.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE)
876+
877+
with mlflow.start_run() as run:
878+
# Log the dictionary as a table
879+
mlflow.log_table(data=table_dict, artifact_file=artifact_file)
880+
run_id = run.info.run_id
881+
882+
run = mlflow.get_run(run_id)
883+
artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
884+
table_data = pd.read_json(artifact_path, orient="split")
885+
assert table_data.shape[0] == 2
886+
assert table_data.shape[1] == 3
887+
888+
# Get the current value of the tag
889+
current_tag_value = ast.literal_eval(run.data.tags.get(TAG_NAME, "[]"))
890+
assert {"path": artifact_file, "type": "table"} in current_tag_value
891+
assert len(current_tag_value) == 1
892+
893+
table_df = pd.DataFrame.from_dict(table_dict)
894+
with mlflow.start_run(run_id=run_id):
895+
# Log the dataframe as a table
896+
mlflow.log_table(data=table_df, artifact_file=artifact_file)
897+
898+
run = mlflow.get_run(run_id)
899+
artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
900+
table_data = pd.read_json(artifact_path, orient="split")
901+
assert table_data.shape[0] == 4
902+
assert table_data.shape[1] == 3
903+
# Get the current value of the tag
904+
current_tag_value = ast.literal_eval(run.data.tags.get(TAG_NAME, "[]"))
905+
assert {"path": artifact_file, "type": "table"} in current_tag_value
906+
assert len(current_tag_value) == 1
907+
908+
artifact_file_new = "qabot_eval_results_new.json"
909+
with mlflow.start_run(run_id=run_id):
910+
# Log the dataframe as a table to new artifact file
911+
mlflow.log_table(data=table_df, artifact_file=artifact_file_new)
912+
913+
run = mlflow.get_run(run_id)
914+
artifact_path = mlflow.artifacts.download_artifacts(
915+
run_id=run_id, artifact_path=artifact_file_new
916+
)
917+
table_data = pd.read_json(artifact_path, orient="split")
918+
assert table_data.shape[0] == 2
919+
assert table_data.shape[1] == 3
920+
# Get the current value of the tag
921+
current_tag_value = ast.literal_eval(run.data.tags.get(TAG_NAME, "[]"))
922+
assert {"path": artifact_file_new, "type": "table"} in current_tag_value
923+
assert len(current_tag_value) == 2
924+
925+
926+
@pytest.mark.skipif(
927+
"MLFLOW_SKINNY" in os.environ,
928+
reason="Skinny client does not support the np or pandas dependencies",
929+
)
930+
def test_log_table_with_subdirectory():
931+
import pandas as pd
932+
933+
table_dict = {
934+
"inputs": ["What is MLflow?", "What is Databricks?"],
935+
"outputs": ["MLflow is ...", "Databricks is ..."],
936+
"toxicity": [0.0, 0.0],
937+
}
938+
artifact_file = "dir/foo.json"
939+
TAG_NAME = "mlflow.loggedArtifacts"
940+
run_id = None
941+
942+
with mlflow.start_run() as run:
943+
# Log the dictionary as a table
944+
mlflow.log_table(data=table_dict, artifact_file=artifact_file)
945+
run_id = run.info.run_id
946+
947+
run = mlflow.get_run(run_id)
948+
artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
949+
table_data = pd.read_json(artifact_path, orient="split")
950+
assert table_data.shape[0] == 2
951+
assert table_data.shape[1] == 3
952+
953+
# Get the current value of the tag
954+
current_tag_value = ast.literal_eval(run.data.tags.get(TAG_NAME, "[]"))
955+
assert {"path": artifact_file, "type": "table"} in current_tag_value
956+
assert len(current_tag_value) == 1
957+
958+
table_df = pd.DataFrame.from_dict(table_dict)
959+
with mlflow.start_run(run_id=run_id):
960+
# Log the dataframe as a table
961+
mlflow.log_table(data=table_df, artifact_file=artifact_file)
962+
963+
run = mlflow.get_run(run_id)
964+
artifact_path = mlflow.artifacts.download_artifacts(run_id=run_id, artifact_path=artifact_file)
965+
table_data = pd.read_json(artifact_path, orient="split")
966+
assert table_data.shape[0] == 4
967+
assert table_data.shape[1] == 3
968+
# Get the current value of the tag
969+
current_tag_value = ast.literal_eval(run.data.tags.get(TAG_NAME, "[]"))
970+
assert {"path": artifact_file, "type": "table"} in current_tag_value
971+
assert len(current_tag_value) == 1

0 commit comments

Comments
 (0)