Skip to content

Commit acb3806

Browse files
fix update method
Change-Id: I433c25b2d80cdf6e483b59f61ff29bb8d2dc6595
1 parent fb9995c commit acb3806

File tree

3 files changed

+59
-24
lines changed

3 files changed

+59
-24
lines changed

google/generativeai/caching.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -47,21 +47,19 @@ def __enter__(self):
4747
def __exit__(self, exc_type, exc_value, exc_tb):
4848
self.delete()
4949

50-
def _to_dict(self) -> protos.CachedContent:
50+
def _to_dict(self, **input_only_update_fields) -> protos.CachedContent:
5151
proto_paths = {
5252
"name": self.name,
53-
"model": self.model,
5453
}
54+
proto_paths.update(input_only_update_fields)
5555
return protos.CachedContent(**proto_paths)
5656

5757
def _apply_update(self, path, value):
5858
parts = path.split(".")
5959
for part in parts[:-1]:
6060
self = getattr(self, part)
61-
if parts[-1] == "ttl":
62-
value = self.expire_time + datetime.timedelta(seconds=value["seconds"])
63-
parts[-1] = "expire_time"
64-
setattr(self, parts[-1], value)
61+
if path[-1] != "ttl":
62+
setattr(self, parts[-1], value)
6563

6664
@classmethod
6765
def _decode_cached_content(cls, cached_content: protos.CachedContent) -> CachedContent:
@@ -112,7 +110,7 @@ def _prepare_create_request(
112110
contents = content_types.to_contents(contents)
113111

114112
if ttl:
115-
ttl = caching_types.to_ttl(ttl)
113+
ttl = caching_types.to_expiration(ttl)
116114

117115
cached_content = protos.CachedContent(
118116
name=name,
@@ -236,25 +234,35 @@ def update(
236234
if client is None:
237235
client = get_default_cache_client()
238236

237+
if "ttl" in updates and "expire_time" in updates:
238+
raise ValueError(
239+
"`expiration` is a _oneof field. Please provide either `ttl` or `expire_time`."
240+
)
241+
242+
field_mask = field_mask_pb2.FieldMask()
243+
239244
updates = flatten_update_paths(updates)
240245
for update_path in updates:
241-
if update_path == "ttl":
246+
if update_path == "ttl" or update_path == "expire_time":
242247
updates = updates.copy()
243248
update_path_val = updates.get(update_path)
244-
updates[update_path] = caching_types.to_ttl(update_path_val)
249+
updates[update_path] = caching_types.to_expiration(update_path_val)
245250
else:
246251
raise ValueError(
247252
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{update_path}` instead."
248253
)
249-
field_mask = field_mask_pb2.FieldMask()
250254

251-
for path in updates.keys():
252-
field_mask.paths.append(path)
255+
field_mask.paths.append(update_path)
256+
253257
for path, value in updates.items():
254258
self._apply_update(path, value)
255259

256260
request = protos.UpdateCachedContentRequest(
257-
cached_content=self._to_dict(), update_mask=field_mask
261+
cached_content=self._to_dict(**updates), update_mask=field_mask
258262
)
259-
client.update_cached_content(request)
263+
updated_cc = client.update_cached_content(request)
264+
updated_cc = self._decode_cached_content(updated_cc)
265+
for path, value in dataclasses.asdict(updated_cc).items():
266+
self._apply_update(path, value)
267+
260268
return self

google/generativeai/types/caching_types.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from typing_extensions import TypedDict
2020
import re
2121

22-
__all__ = ["TTL"]
22+
__all__ = ["ExpirationTypes", "ExpireTime", "TTL"]
2323

2424

2525
_VALID_CACHED_CONTENT_NAME = r"([a-z0-9-\.]+)$"
@@ -33,18 +33,36 @@ def valid_cached_content_name(name: str) -> bool:
3333

3434

3535
class TTL(TypedDict):
36+
# Represents datetime.datetime.now() + desired ttl
3637
seconds: int
38+
nanos: int = 0
39+
40+
class ExpireTime(TypedDict):
41+
# Represents seconds of UTC time since Unix epoch
42+
seconds: int
43+
nanos: int = 0
3744

3845

39-
ExpirationTypes = Union[TTL, int, datetime.timedelta]
46+
ExpirationTypes = Union[TTL, ExpireTime, int, datetime.timedelta, datetime.datetime]
4047

4148

42-
def to_ttl(expiration: Optional[ExpirationTypes]) -> TTL:
43-
if isinstance(expiration, datetime.timedelta):
44-
return {"seconds": int(expiration.total_seconds())}
49+
def to_expiration(expiration: Optional[ExpirationTypes]) -> TTL:
50+
if isinstance(expiration, datetime.timedelta): # consider `ttl`
51+
return {
52+
"seconds": int(expiration.total_seconds()),
53+
"nanos": int(expiration.microseconds * 1000),
54+
}
55+
elif isinstance(expiration, datetime.datetime): # consider `expire_time`
56+
timestamp = expiration.timestamp()
57+
seconds = int(timestamp)
58+
nanos = int((seconds % 1) * 1000)
59+
return {
60+
"seconds": seconds,
61+
"nanos": nanos,
62+
}
4563
elif isinstance(expiration, dict):
4664
return expiration
47-
elif isinstance(expiration, int):
65+
elif isinstance(expiration, int): # consider `ttl`
4866
return {"seconds": expiration}
4967
else:
5068
raise TypeError(

tests/test_caching.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,10 +210,19 @@ def test_update_cached_content_invalid_update_paths(self):
210210
with self.assertRaises(ValueError):
211211
cc.update(updates=update_masks)
212212

213-
def test_update_cached_content_valid_update_paths(self):
214-
update_masks = dict(
215-
ttl=datetime.timedelta(hours=2),
216-
)
213+
@parameterized.named_parameters(
214+
[
215+
dict(
216+
testcase_name="ttl",
217+
update_masks=dict(ttl=datetime.timedelta(hours=2))
218+
),
219+
dict(
220+
testcase_name="expire_time",
221+
update_masks=dict(expire_time=datetime.datetime(2024, 6, 5, 12, 12, 12, 23))
222+
)
223+
]
224+
)
225+
def test_update_cached_content_valid_update_paths(self, update_masks):
217226

218227
cc = caching.CachedContent.get(name="cachedContents/test-cached-content")
219228
cc = cc.update(updates=update_masks)

0 commit comments

Comments
 (0)