Skip to content

Commit 08c89cd

Browse files
committed
Implemented a new encoder for the hyperparameter dict - needs cleaning
1 parent 627db33 commit 08c89cd

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/hparams.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,15 +82,24 @@ end
8282

8383
# Overload the dictionary encoder
8484
function PB.encode(e::ProtoEncoder, i::Int, x::Dict{String,HValue})
85-
PB.Codecs.encode_tag(e, i, PB.Codecs.LENGTH_DELIMITED)
86-
PB.Codecs.vbyte_encode(e.io, UInt32(PB.Codecs._encoded_size(x)))
87-
85+
# PB.Codecs.encode_tag(e, i, PB.Codecs.LENGTH_DELIMITED)
86+
# remaining_size = PB.Codecs._encoded_size(x, i) - 2 # remove two for the field name and length
87+
# PB.Codecs.vbyte_encode(e.io, UInt32(remaining_size))
88+
8889
for (k, v) in x
90+
PB.Codecs.encode_tag(e, 1, PB.Codecs.LENGTH_DELIMITED)
91+
total_size = PB.Codecs._encoded_size(k, 1) + PB.Codecs._encoded_size(v, 2)
92+
PB.Codecs.vbyte_encode(e.io, UInt32(total_size)) # Add two for the wire type and length
8993
PB.Codecs.encode(e, 1, k)
9094
PB.Codecs.encode(e, 2, v)
9195
end
9296
return nothing
9397
end
98+
function PB.Codecs._encoded_size(x::Dict{String,HValue}, i::Int)
99+
# Field number and length is another 2 bytes
100+
# There are two bytes for each key value pair extra
101+
return mapreduce((xi) -> 2 + PB.Codecs._encoded_size(xi.first, 1) + PB.Codecs._encoded_size(xi.second, 2), +, x, init=0)
102+
end
94103

95104

96105
"""

0 commit comments

Comments
 (0)