Skip to content

Commit 493279c

Browse files
authored
perf(data status): batch remote checks (#10792)
1 parent 95c5ea8 commit 493279c

File tree

1 file changed

+127
-62
lines changed

1 file changed

+127
-62
lines changed

dvc/repo/data.py

Lines changed: 127 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -1,20 +1,25 @@
11
import os
22
import posixpath
3-
from collections.abc import Iterable
3+
from collections import defaultdict
4+
from collections.abc import Iterable, Iterator, Mapping
45
from typing import TYPE_CHECKING, Optional, TypedDict, Union
56

6-
from dvc.fs.callbacks import DEFAULT_CALLBACK
7+
from dvc.fs.callbacks import DEFAULT_CALLBACK, Callback, TqdmCallback
78
from dvc.log import logger
8-
from dvc.scm import RevError
99
from dvc.ui import ui
10-
from dvc_data.index.view import DataIndexView
1110

1211
if TYPE_CHECKING:
13-
from dvc.fs.callbacks import Callback
1412
from dvc.repo import Repo
1513
from dvc.scm import Git, NoSCM
16-
from dvc_data.index import BaseDataIndex, DataIndex, DataIndexKey
14+
from dvc_data.index import (
15+
BaseDataIndex,
16+
DataIndex,
17+
DataIndexEntry,
18+
DataIndexKey,
19+
DataIndexView,
20+
)
1721
from dvc_data.index.diff import Change
22+
from dvc_objects.fs.base import FileSystem
1823

1924
logger = logger.getChild(__name__)
2025

@@ -23,21 +28,6 @@ def posixpath_to_os_path(path: str) -> str:
2328
return path.replace(posixpath.sep, os.path.sep)
2429

2530

26-
def _adapt_typ(typ: str) -> str:
27-
from dvc_data.index.diff import ADD, DELETE, MODIFY
28-
29-
if typ == MODIFY:
30-
return "modified"
31-
32-
if typ == ADD:
33-
return "added"
34-
35-
if typ == DELETE:
36-
return "deleted"
37-
38-
return typ
39-
40-
4131
def _adapt_path(change: "Change") -> str:
4232
isdir = False
4333
if change.new and change.new.meta:
@@ -50,25 +40,65 @@ def _adapt_path(change: "Change") -> str:
5040
return os.path.sep.join(key)
5141

5242

43+
def _get_missing_paths(
    to_check: Mapping["FileSystem", Mapping[str, Iterable["DataIndexEntry"]]],
    batch_size: Optional[int] = None,
    callback: "Callback" = DEFAULT_CALLBACK,
) -> Iterator[str]:
    """Yield os-separated key paths for entries whose storage path is missing.

    Args:
        to_check: filesystem -> {storage path -> entries stored at that path}.
        batch_size: forwarded to ``fs.exists`` when batching; ``1`` forces
            one-by-one checks. ``None`` lets local filesystems go one-by-one
            and remotes use the fs default batching.
        callback: progress callback ticked once per path checked.

    Yields:
        ``os.path.sep``-joined entry keys for every entry whose storage path
        does not exist; directory entries get a trailing separator.
    """
    for fs, path_entries in to_check.items():
        # Local filesystems gain nothing from batching; a plain per-path
        # loop also lets the callback tick naturally via ``wrap``.
        one_by_one = batch_size == 1 or (batch_size is None and fs.protocol == "local")
        if one_by_one:
            existing = list(callback.wrap(map(fs.exists, path_entries)))
        else:
            existing = fs.exists(
                list(path_entries), batch_size=batch_size, callback=callback
            )

        for path, found in zip(path_entries, existing):
            if found:
                continue

            # One storage path may back several index entries; report each.
            for entry in path_entries[path]:
                assert entry.key
                parts = entry.key
                if entry.meta and entry.meta.isdir:
                    # Trailing "" yields a path ending in the separator,
                    # marking the entry as a directory.
                    parts = (*parts, "")
                yield os.path.sep.join(parts)
67+
68+
class StorageCallback(Callback):
    """Proxy callback that forwards progress ticks to a parent callback.

    Its own size is pinned at 0 so that ``fs.exists`` cannot resize it;
    only incremental progress is propagated to ``parent_cb``.
    """

    def __init__(self, parent_cb: Callback) -> None:
        super().__init__(size=0, value=0)
        self.parent_cb = parent_cb

    def set_size(self, size: int) -> None:
        """Ignore size changes (no-op so `fs.exists` cannot set the size)."""

    def relative_update(self, value: int = 1) -> None:
        """Forward an incremental update straight to the parent."""
        self.parent_cb.relative_update(value)

    def absolute_update(self, value: int) -> None:
        """Convert an absolute position into a relative tick for the parent.

        NOTE(review): ``self.value`` is never advanced here — the overridden
        ``relative_update`` only forwards — so this assumes a caller uses
        either absolute or relative updates, not both; confirm against the
        base ``Callback`` semantics.
        """
        self.parent_cb.relative_update(value - self.value)
82+
83+
5384
def _diff(
5485
old: "BaseDataIndex",
5586
new: "BaseDataIndex",
5687
*,
88+
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
5789
granular: bool = False,
5890
not_in_cache: bool = False,
91+
batch_size: Optional[int] = None,
5992
callback: "Callback" = DEFAULT_CALLBACK,
60-
filter_keys: Optional[list["DataIndexKey"]] = None,
6193
) -> dict[str, list[str]]:
62-
from dvc_data.index.diff import UNCHANGED, UNKNOWN, diff
94+
from dvc_data.index.diff import ADD, DELETE, MODIFY, UNCHANGED, UNKNOWN, diff
6395

64-
ret: dict[str, list[str]] = {}
96+
ret: dict[str, list[str]] = defaultdict(list)
97+
change_types = {MODIFY: "modified", ADD: "added", DELETE: "deleted"}
6598

66-
def _add_change(typ, change):
67-
typ = _adapt_typ(typ)
68-
if typ not in ret:
69-
ret[typ] = []
70-
71-
ret[typ].append(_adapt_path(change))
99+
to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
100+
lambda: defaultdict(list)
101+
)
72102

73103
for change in diff(
74104
old,
@@ -84,9 +114,7 @@ def _add_change(typ, change):
84114
# still appear in the view. As a result, keys like `dir/` will be present
85115
# even if only `dir/file` matches the filter.
86116
# We need to skip such entries to avoid showing root of tracked directories.
87-
if filter_keys and not any(
88-
change.key[: len(filter_key)] == filter_key for filter_key in filter_keys
89-
):
117+
if filter_keys and not any(change.key[: len(fk)] == fk for fk in filter_keys):
90118
continue
91119

92120
if (
@@ -101,18 +129,27 @@ def _add_change(typ, change):
101129
# NOTE: emulating previous behaviour
102130
continue
103131

104-
if (
105-
not_in_cache
106-
and change.old
107-
and change.old.hash_info
108-
and not old.storage_map.cache_exists(change.old)
109-
):
110-
# NOTE: emulating previous behaviour
111-
_add_change("not_in_cache", change)
132+
if not_in_cache and change.old and change.old.hash_info:
133+
old_entry = change.old
134+
cache_fs, cache_path = old.storage_map.get_cache(old_entry)
135+
# check later in batches
136+
to_check[cache_fs][cache_path].append(old_entry)
112137

113-
_add_change(change.typ, change)
138+
change_typ = change_types.get(change.typ, change.typ)
139+
ret[change_typ].append(_adapt_path(change))
114140

115-
return ret
141+
total_items = sum(
142+
len(entries) for paths in to_check.values() for entries in paths.values()
143+
)
144+
with TqdmCallback(size=total_items, desc="Checking cache", unit="entry") as cb:
145+
missing_items = list(
146+
_get_missing_paths(
147+
to_check, batch_size=batch_size, callback=StorageCallback(cb)
148+
),
149+
)
150+
if missing_items:
151+
ret["not_in_cache"] = missing_items
152+
return dict(ret)
116153

117154

118155
class GitInfo(TypedDict, total=False):
@@ -153,8 +190,10 @@ def _git_info(scm: Union["Git", "NoSCM"], untracked_files: str = "all") -> GitIn
153190

154191
def filter_index(
155192
index: Union["DataIndex", "DataIndexView"],
156-
filter_keys: Optional[list["DataIndexKey"]] = None,
193+
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
157194
) -> "BaseDataIndex":
195+
from dvc_data.index.view import DataIndexView
196+
158197
if not filter_keys:
159198
return index
160199

@@ -187,8 +226,9 @@ def filter_fn(key: "DataIndexKey") -> bool:
187226

188227
def _diff_index_to_wtree(
189228
repo: "Repo",
190-
filter_keys: Optional[list["DataIndexKey"]] = None,
229+
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
191230
granular: bool = False,
231+
batch_size: Optional[int] = None,
192232
) -> dict[str, list[str]]:
193233
from .index import build_data_index
194234

@@ -214,16 +254,18 @@ def _diff_index_to_wtree(
214254
filter_keys=filter_keys,
215255
granular=granular,
216256
not_in_cache=True,
257+
batch_size=batch_size,
217258
callback=pb.as_callback(),
218259
)
219260

220261

221262
def _diff_head_to_index(
222263
repo: "Repo",
223264
head: str = "HEAD",
224-
filter_keys: Optional[list["DataIndexKey"]] = None,
265+
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
225266
granular: bool = False,
226267
) -> dict[str, list[str]]:
268+
from dvc.scm import RevError
227269
from dvc_data.index import DataIndex
228270

229271
index = repo.index.data["repo"]
@@ -278,9 +320,10 @@ def _transform_git_paths_to_dvc(repo: "Repo", files: Iterable[str]) -> list[str]
278320

279321
def _get_entries_not_in_remote(
280322
repo: "Repo",
281-
filter_keys: Optional[list["DataIndexKey"]] = None,
323+
filter_keys: Optional[Iterable["DataIndexKey"]] = None,
282324
granular: bool = False,
283325
remote_refresh: bool = False,
326+
batch_size: Optional[int] = None,
284327
) -> list[str]:
285328
"""Get entries that are not in remote storage."""
286329
from dvc.repo.worktree import worktree_view
@@ -293,7 +336,13 @@ def _get_entries_not_in_remote(
293336
view = filter_index(data_index, filter_keys=filter_keys) # type: ignore[arg-type]
294337

295338
missing_entries = []
296-
with ui.progress(desc="Checking remote", unit="entry") as pb:
339+
340+
to_check: dict[FileSystem, dict[str, list[DataIndexEntry]]] = defaultdict(
341+
lambda: defaultdict(list)
342+
)
343+
344+
storage_map = view.storage_map
345+
with TqdmCallback(size=0, desc="Checking remote", unit="entry") as cb:
297346
for key, entry in view.iteritems(shallow=not granular):
298347
if not (entry and entry.hash_info):
299348
continue
@@ -309,13 +358,28 @@ def _get_entries_not_in_remote(
309358
continue
310359

311360
k = (*key, "") if entry.meta and entry.meta.isdir else key
312-
try:
313-
if not view.storage_map.remote_exists(entry, refresh=remote_refresh):
314-
missing_entries.append(os.path.sep.join(k))
315-
pb.update()
316-
except StorageKeyError:
317-
pass
318-
361+
if remote_refresh:
362+
# on remote_refresh, collect all entries to check
363+
# then check them in batches below
364+
try:
365+
remote_fs, remote_path = storage_map.get_remote(entry)
366+
to_check[remote_fs][remote_path].append(entry)
367+
cb.size += 1
368+
cb.relative_update(0) # try to update the progress bar
369+
except StorageKeyError:
370+
pass
371+
else:
372+
try:
373+
if not storage_map.remote_exists(entry, refresh=remote_refresh):
374+
missing_entries.append(os.path.sep.join(k))
375+
cb.relative_update() # no need to update the size
376+
except StorageKeyError:
377+
pass
378+
missing_entries.extend(
379+
_get_missing_paths(
380+
to_check, batch_size=batch_size, callback=StorageCallback(cb)
381+
)
382+
)
319383
return missing_entries
320384

321385

@@ -324,7 +388,7 @@ def _matches_target(p: str, targets: Iterable[str]) -> bool:
324388
return any(p == t or p.startswith(t + sep) for t in targets)
325389

326390

327-
def _prune_keys(filter_keys: list["DataIndexKey"]) -> list["DataIndexKey"]:
391+
def _prune_keys(filter_keys: Iterable["DataIndexKey"]) -> list["DataIndexKey"]:
328392
sorted_keys = sorted(set(filter_keys), key=len)
329393
result: list[DataIndexKey] = []
330394

@@ -342,6 +406,7 @@ def status(
342406
untracked_files: str = "no",
343407
not_in_remote: bool = False,
344408
remote_refresh: bool = False,
409+
batch_size: Optional[int] = None,
345410
head: str = "HEAD",
346411
) -> Status:
347412
from dvc.scm import NoSCMError, SCMError
@@ -352,19 +417,19 @@ def status(
352417
filter_keys = _prune_keys(filter_keys)
353418

354419
uncommitted_diff = _diff_index_to_wtree(
355-
repo, filter_keys=filter_keys, granular=granular
420+
repo, filter_keys=filter_keys, granular=granular, batch_size=batch_size
356421
)
357422
unchanged = set(uncommitted_diff.pop("unchanged", []))
358-
entries_not_in_remote = (
359-
_get_entries_not_in_remote(
423+
424+
entries_not_in_remote: list[str] = []
425+
if not_in_remote:
426+
entries_not_in_remote = _get_entries_not_in_remote(
360427
repo,
361428
filter_keys=filter_keys,
362429
granular=granular,
363430
remote_refresh=remote_refresh,
431+
batch_size=batch_size,
364432
)
365-
if not_in_remote
366-
else []
367-
)
368433

369434
try:
370435
committed_diff = _diff_head_to_index(

0 commit comments

Comments
 (0)