Skip to content

Tracking cache requests #1566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 83 additions & 14 deletions fsspec/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,13 @@ class BaseCache:

def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
self.blocksize = blocksize
self.nblocks = 0
self.fetcher = fetcher
self.size = size
self.hit_count = 0
self.miss_count = 0
# the bytes that we actually requested
self.total_requested_bytes = 0

def _fetch(self, start: int | None, stop: int | None) -> bytes:
if start is None:
Expand All @@ -68,6 +73,36 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
return b""
return self.fetcher(start, stop)

def _reset_stats(self) -> None:
    """Reset hit and miss counts for a more granular report e.g. by file."""
    self.hit_count = 0
    self.miss_count = 0
    self.total_requested_bytes = 0

def _log_stats(self) -> str:
"""Return a formatted string of the cache statistics."""
if self.hit_count == 0 and self.miss_count == 0:
# a cache that does nothing, this is for logs only
return ""
return " , %s: %d hits, %d misses, %d total requested bytes" % (
self.name,
self.hit_count,
self.miss_count,
self.total_requested_bytes,
)

def __repr__(self) -> str:
    # Multi-line debug dump of cache geometry plus hit/miss statistics.
    # NOTE(review): interior whitespace of this f-string was mangled by
    # extraction — confirm layout against the upstream file before relying
    # on exact output.
    # TODO: use rich for better formatting
    return f"""
<{self.__class__.__name__}:
block size : {self.blocksize}
block count : {self.nblocks}
file size : {self.size}
cache hits : {self.hit_count}
cache misses: {self.miss_count}
total requested bytes: {self.total_requested_bytes}>
"""


class MMapCache(BaseCache):
"""memory-mapped sparse file cache
Expand Down Expand Up @@ -126,13 +161,18 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
start_block = start // self.blocksize
end_block = end // self.blocksize
need = [i for i in range(start_block, end_block + 1) if i not in self.blocks]
hits = [i for i in range(start_block, end_block + 1) if i in self.blocks]
self.miss_count += len(need)
self.hit_count += len(hits)
while need:
# TODO: not a for loop so we can consolidate blocks later to
# make fewer fetch calls; this could be parallel
i = need.pop(0)

sstart = i * self.blocksize
send = min(sstart + self.blocksize, self.size)
logger.debug(f"MMap get block #{i} ({sstart}-{send}")
self.total_requested_bytes += send - sstart
logger.debug(f"MMap get block #{i} ({sstart}-{send})")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few of these (not your code, but shows up in the diff due to small corrections). Debug of the form ("test %s", arg) is preferred, so that the string doesn't need to be evaluated in the case that debug is not required.

self.cache[sstart:send] = self.fetcher(sstart, send)
self.blocks.add(i)

Expand Down Expand Up @@ -176,16 +216,20 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
l = end - start
if start >= self.start and end <= self.end:
# cache hit
self.hit_count += 1
return self.cache[start - self.start : end - self.start]
elif self.start <= start < self.end:
# partial hit
self.miss_count += 1
part = self.cache[start - self.start :]
l -= len(part)
start = self.end
else:
# miss
self.miss_count += 1
part = b""
end = min(self.size, end + self.blocksize)
self.total_requested_bytes += end - start
self.cache = self.fetcher(start, end) # new block replaces old
self.start = start
self.end = self.start + len(self.cache)
Expand All @@ -202,24 +246,39 @@ class FirstChunkCache(BaseCache):
name = "first"

def __init__(self, blocksize: int, fetcher: Fetcher, size: int) -> None:
    """Set up a cache holding only the first block of the file.

    A block size larger than the file would buffer the whole thing, so
    it is clamped to the file size.
    """
    super().__init__(min(blocksize, size), fetcher, size)
    # lazily-populated copy of the first `blocksize` bytes
    self.cache: bytes | None = None

def _fetch(self, start: int | None, end: int | None) -> bytes:
start = start or 0
end = end or self.size
if start > self.size:
logger.debug("FirstChunkCache: requested start > file size")
return b""

end = min(end, self.size)

if start < self.blocksize:
if self.cache is None:
self.miss_count += 1
if end > self.blocksize:
self.total_requested_bytes += end
data = self.fetcher(0, end)
self.cache = data[: self.blocksize]
return data[start:]
self.cache = self.fetcher(0, self.blocksize)
self.total_requested_bytes += self.blocksize
part = self.cache[start:end]
if end > self.blocksize:
self.total_requested_bytes += end - self.blocksize
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment: isn't the number of bytes returned by fetcher the more important number? If we request 100bytes, we might only get 10 back; probably the calling code will request more. OTOH, the various implementations of the request code might keep making new requests until enough bytes arrive.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above, perhaps we are missing another variable to keep track of the total_bytes_returned?

part += self.fetcher(self.blocksize, end)
self.hit_count += 1
return part
else:
self.miss_count += 1
self.total_requested_bytes += end - start
return self.fetcher(start, end)


Expand Down Expand Up @@ -256,12 +315,6 @@ def __init__(
self.maxblocks = maxblocks
self._fetch_block_cached = functools.lru_cache(maxblocks)(self._fetch_block)

def __repr__(self) -> str:
return (
f"<BlockCache blocksize={self.blocksize}, "
f"size={self.size}, nblocks={self.nblocks}>"
)

def cache_info(self):
"""
The statistics on the block cache.
Expand Down Expand Up @@ -319,6 +372,8 @@ def _fetch_block(self, block_number: int) -> bytes:

start = block_number * self.blocksize
end = start + self.blocksize
self.total_requested_bytes += end - start
self.miss_count += 1
logger.info("BlockCache fetching block %d", block_number)
block_contents = super()._fetch(start, end)
return block_contents
Expand All @@ -339,6 +394,7 @@ def _read_cache(
start_pos = start % self.blocksize
end_pos = end % self.blocksize

self.hit_count += 1
if start_block_number == end_block_number:
block: bytes = self._fetch_block_cached(start_block_number)
return block[start_pos:end_pos]
Expand Down Expand Up @@ -404,6 +460,7 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
):
# cache hit: we have all the required data
offset = start - self.start
self.hit_count += 1
return self.cache[offset : offset + end - start]

if self.blocksize:
Expand All @@ -418,27 +475,34 @@ def _fetch(self, start: int | None, end: int | None) -> bytes:
self.end is None or end > self.end
):
# First read, or extending both before and after
self.total_requested_bytes += bend - start
self.miss_count += 1
self.cache = self.fetcher(start, bend)
self.start = start
else:
assert self.start is not None
assert self.end is not None
self.miss_count += 1

if start < self.start:
if self.end is None or self.end - end > self.blocksize:
self.total_requested_bytes += bend - start
self.cache = self.fetcher(start, bend)
self.start = start
else:
self.total_requested_bytes += self.start - start
new = self.fetcher(start, self.start)
self.start = start
self.cache = new + self.cache
elif self.end is not None and bend > self.end:
if self.end > self.size:
pass
elif end - self.end > self.blocksize:
self.total_requested_bytes += bend - start
self.cache = self.fetcher(start, bend)
self.start = start
else:
self.total_requested_bytes += bend - self.end
new = self.fetcher(self.end, bend)
self.cache = self.cache + new

Expand Down Expand Up @@ -470,10 +534,13 @@ def __init__(
) -> None:
super().__init__(blocksize, fetcher, size) # type: ignore[arg-type]
if data is None:
self.miss_count += 1
self.total_requested_bytes += self.size
data = self.fetcher(0, self.size)
self.data = data

def _fetch(self, start: int | None, stop: int | None) -> bytes:
self.hit_count += 1
return self.data[start:stop]


Expand Down Expand Up @@ -551,6 +618,7 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
# are allowed to pad reads beyond the
# buffer with zero
out += b"\x00" * (stop - start - len(out))
self.hit_count += 1
return out
else:
# The request ends outside a known range,
Expand All @@ -572,6 +640,8 @@ def _fetch(self, start: int | None, stop: int | None) -> bytes:
f"IO/caching performance may be poor!"
)
logger.debug(f"KnownPartsOfAFile cache fetching {start}-{stop}")
self.total_requested_bytes += stop - start
self.miss_count += 1
return out + super()._fetch(start, stop)


Expand Down Expand Up @@ -676,12 +746,6 @@ def __init__(
self._fetch_future: Future[bytes] | None = None
self._fetch_future_lock = threading.Lock()

def __repr__(self) -> str:
return (
f"<BackgroundBlockCache blocksize={self.blocksize}, "
f"size={self.size}, nblocks={self.nblocks}>"
)

def cache_info(self) -> UpdatableLRU.CacheInfo:
"""
The statistics on the block cache.
Expand Down Expand Up @@ -799,6 +863,8 @@ def _fetch_block(self, block_number: int, log_info: str = "sync") -> bytes:
start = block_number * self.blocksize
end = start + self.blocksize
logger.info("BlockCache fetching block (%s) %d", log_info, block_number)
self.total_requested_bytes += end - start
self.miss_count += 1
block_contents = super()._fetch(start, end)
return block_contents

Expand All @@ -818,6 +884,9 @@ def _read_cache(
start_pos = start % self.blocksize
end_pos = end % self.blocksize

# kind of pointless to count this as a hit, but it is
self.hit_count += 1

if start_block_number == end_block_number:
block = self._fetch_block_cached(start_block_number)
return block[start_pos:end_pos]
Expand Down
9 changes: 8 additions & 1 deletion fsspec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1841,11 +1841,18 @@ def read(self, length=-1):
length = self.size - self.loc
if self.closed:
raise ValueError("I/O operation on closed file.")
logger.debug("%s read: %i - %i", self, self.loc, self.loc + length)
if length == 0:
# don't even bother calling fetch
return b""
out = self.cache._fetch(self.loc, self.loc + length)

logger.debug(
"%s read: %i - %i %s",
self,
self.loc,
self.loc + length,
self.cache._log_stats(),
)
self.loc += len(out)
return out

Expand Down
Loading