@@ -47,21 +47,19 @@ def __enter__(self):
47
47
def __exit__ (self , exc_type , exc_value , exc_tb ):
48
48
self .delete ()
49
49
50
- def _to_dict (self ) -> protos .CachedContent :
50
+ def _to_dict (self , ** input_only_update_fields ) -> protos .CachedContent :
51
51
proto_paths = {
52
52
"name" : self .name ,
53
- "model" : self .model ,
54
53
}
54
+ proto_paths .update (input_only_update_fields )
55
55
return protos .CachedContent (** proto_paths )
56
56
57
57
def _apply_update (self , path , value ):
58
58
parts = path .split ("." )
59
59
for part in parts [:- 1 ]:
60
60
self = getattr (self , part )
61
- if parts [- 1 ] == "ttl" :
62
- value = self .expire_time + datetime .timedelta (seconds = value ["seconds" ])
63
- parts [- 1 ] = "expire_time"
64
- setattr (self , parts [- 1 ], value )
61
+ if path [- 1 ] != "ttl" :
62
+ setattr (self , parts [- 1 ], value )
65
63
66
64
@classmethod
67
65
def _decode_cached_content (cls , cached_content : protos .CachedContent ) -> CachedContent :
@@ -112,7 +110,7 @@ def _prepare_create_request(
112
110
contents = content_types .to_contents (contents )
113
111
114
112
if ttl :
115
- ttl = caching_types .to_ttl (ttl )
113
+ ttl = caching_types .to_expiration (ttl )
116
114
117
115
cached_content = protos .CachedContent (
118
116
name = name ,
@@ -236,25 +234,35 @@ def update(
236
234
if client is None :
237
235
client = get_default_cache_client ()
238
236
237
+ if "ttl" in updates and "expire_time" in updates :
238
+ raise ValueError (
239
+ "`expiration` is a _oneof field. Please provide either `ttl` or `expire_time`."
240
+ )
241
+
242
+ field_mask = field_mask_pb2 .FieldMask ()
243
+
239
244
updates = flatten_update_paths (updates )
240
245
for update_path in updates :
241
- if update_path == "ttl" :
246
+ if update_path == "ttl" or update_path == "expire_time" :
242
247
updates = updates .copy ()
243
248
update_path_val = updates .get (update_path )
244
- updates [update_path ] = caching_types .to_ttl (update_path_val )
249
+ updates [update_path ] = caching_types .to_expiration (update_path_val )
245
250
else :
246
251
raise ValueError (
247
252
f"As of now, only `ttl` can be updated for `CachedContent`. Got: `{ update_path } ` instead."
248
253
)
249
- field_mask = field_mask_pb2 .FieldMask ()
250
254
251
- for path in updates . keys ():
252
- field_mask . paths . append ( path )
255
+ field_mask . paths . append ( update_path )
256
+
253
257
for path , value in updates .items ():
254
258
self ._apply_update (path , value )
255
259
256
260
request = protos .UpdateCachedContentRequest (
257
- cached_content = self ._to_dict (), update_mask = field_mask
261
+ cached_content = self ._to_dict (** updates ), update_mask = field_mask
258
262
)
259
- client .update_cached_content (request )
263
+ updated_cc = client .update_cached_content (request )
264
+ updated_cc = self ._decode_cached_content (updated_cc )
265
+ for path , value in dataclasses .asdict (updated_cc ).items ():
266
+ self ._apply_update (path , value )
267
+
260
268
return self
0 commit comments