Skip to content

Commit 46ccf1d

Browse files
authored
Merge pull request JuliaLogging#129 from JamieMair/add-hparams-api
Add hparams API
2 parents 881a2fd + e4b5ff9 commit 46ccf1d

File tree

13 files changed

+389
-10
lines changed

13 files changed

+389
-10
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,4 @@ test/test_logs
66
docs/Manifest.toml
77

88
gen/proto
9-
gen/protojl
9+
gen/protojl

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,16 @@ StatsBase = "0.27, 0.28, 0.29, 0.30, 0.31, 0.32, 0.33, 0.34"
2020
julia = "1.6"
2121

2222
[extras]
23+
Cairo = "159f3aea-2a34-519c-b102-8c37f9878175"
24+
Fontconfig = "186bb1d3-e1f7-5a2c-a377-96d770f13627"
25+
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
2326
ImageMagick = "6218d12a-5da1-5696-b52f-db25d2ecc6d1"
2427
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
2528
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
2629
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
2730
Minio = "4281f0d9-7ae0-406e-9172-b7277c1efa20"
2831
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2932
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
30-
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
31-
Cairo="159f3aea-2a34-519c-b102-8c37f9878175"
32-
Fontconfig="186bb1d3-e1f7-5a2c-a377-96d770f13627"
3333
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3434
TestImages = "5e47fb64-e119-507b-a336-dd2b206d9990"
3535
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

docs/make.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@ makedocs(
1010
"Backends" => "custom_behaviour.md",
1111
"Reading back data" => "deserialization.md",
1212
"Extending" => "extending_behaviour.md",
13-
"Explicit Interface" => "explicit_interface.md"
13+
"Explicit Interface" => "explicit_interface.md",
14+
"Hyperparameter logging" => "hyperparameters.md"
1415
],
1516
"Examples" => Any[
1617
"Flux.jl" => "examples/flux.md"
1718
"Optim.jl" => "examples/optim.md"
19+
"Hyperparameter tuning" => "examples/hyperparameter_tuning.md"
1820
]
1921
],
2022
format = Documenter.HTML(
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# Hyperparameter tuning
2+
3+
We will start this example by setting up a simple random walk experiment, and seeing the effect of the hyperparameter `bias` on the results.
4+
5+
First, import the packages we will need with:
6+
```julia
7+
using TensorBoardLogger, Logging
8+
using Random
9+
```
10+
Next, we will create a function which runs the experiment and logs the results, include the hyperparameters stored in the `config` dictionary.
11+
```julia
12+
function run_experiment(id, config)
13+
logger = TBLogger("random_walk/run$id", tb_append)
14+
15+
# Specify all the metrics we want to track in a list
16+
metric_names = ["scalar/position"]
17+
write_hparams!(logger, config, metric_names)
18+
19+
epochs = config["epochs"]
20+
sigma = config["sigma"]
21+
bias = config["bias"]
22+
with_logger(logger) do
23+
x = 0.0
24+
for i in 1:epochs
25+
x += sigma * randn() + bias
26+
@info "scalar" position = x
27+
end
28+
end
29+
nothing
30+
end
31+
```
32+
Now we can write a script which runs an experiment over a set of parameter values.
33+
```julia
34+
id = 0
35+
for bias in LinRange(-0.1, 0.1, 11)
36+
for epochs in [50, 100]
37+
config = Dict(
38+
"bias"=>bias,
39+
"epochs"=>epochs,
40+
"sigma"=>0.1
41+
)
42+
run_experiment(id, config)
43+
id += 1
44+
end
45+
end
46+
```
47+
48+
Below is an example of the dashboard you get when you open Tensorboard with the command:
49+
```sh
50+
tensorboard --logdir=random_walk
51+
```
52+
53+
![tuning plot](tuning.png)

docs/src/examples/tuning.png

269 KB
Loading

docs/src/hyperparameters.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Hyperparameter logging
2+
3+
In additition to logging the experiments, you may wish to also visualise the effect of hyperparameters on some plotted metrics. This can be done by logging the hyperparameters via the `write_hparams!` function, which takes a dictionary mapping hyperparameter names to their values (currently limited to `Real`, `Bool` or `String` types), along with the names of any metrics that you want to view the effects of.
4+
5+
You can see how the HParams dashboard in Tensorboard can be used to tune hyperparameters on the [tensorboard website](https://www.tensorflow.org/tensorboard/hyperparameter_tuning_with_hparams).
6+
7+
## API
8+
```@docs
9+
write_hparams!
10+
```

docs/src/index.md

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,9 @@ We also support logging custom types from a the following third-party libraries:
117117

118118
## Explicit logging
119119

120-
In alternative, you can also log data to TensorBoard through its functional interface,
121-
by calling the relevant method with a tag string and the data. For information
122-
on this interface refer to [Explicit interface](@ref)...
120+
As an alternative, you can also log data to TensorBoard through its functional interface, by calling the relevant method with a tag string and the data. For information on this interface refer to [Explicit interface](@ref).
121+
122+
## Hyperparameter tuning
123+
124+
Many experiments rely on hyperparameters, which can be difficult to tune. Tensorboard allows you to visualise the effect of your hyperparameters on your metrics, giving you an intuition for the correct hyperparameters for your task. For information on this API, see the [Hyperparameter logging](@ref) manual page.
125+

examples/HParams.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
using TensorBoardLogger #import the TensorBoardLogger package
2+
using Logging #import Logging package
3+
using Random # Exports randn
4+
5+
# Run 10 experiments to see a plot
6+
for j in 1:10
7+
logger = TBLogger("random_walks/run$j", tb_append)
8+
9+
sigma = 0.1
10+
epochs = 200
11+
bias = (rand()*2 - 1) / 10 # create a random bias
12+
use_seed = false
13+
# Add in the a dummy loss metric
14+
with_logger(logger) do
15+
x = 0.0
16+
for i in 1:epochs
17+
x += sigma * randn() + bias
18+
@info "scalar" loss = x
19+
end
20+
end
21+
22+
# Hyperparameter is a dictionary of parameter names to their values. This
23+
# supports numerical types, bools and strings. Non-bool numerical types
24+
# are converted to Float64 to be displayed.
25+
hparams_config = Dict{String, Any}(
26+
"sigma"=>sigma,
27+
"epochs"=>epochs,
28+
"bias"=>bias,
29+
"use_seed"=>use_seed,
30+
"method"=>"MC"
31+
)
32+
# Specify a list of tags that you want to show up in the hyperparameter
33+
# comparison
34+
metrics = ["scalar/loss"]
35+
36+
# Write the hyperparameters and metrics config to the logger.
37+
write_hparams!(logger, hparams_config, metrics)
38+
end

gen/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,5 @@ FilePathsBase = "48062228-2e41-5def-b9a4-89aafe57970f"
44
Glob = "c27321d9-0574-5035-807b-f59d2c89b15c"
55
ProtoBuf = "3349acd9-ac6a-5e09-bcdb-63829b23a429"
66

7-
[comapt]
7+
[compat]
88
ProtoBuf = "0.9.1"

src/TensorBoardLogger.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using Base.CoreLogging: CoreLogging, AbstractLogger, LogLevel, Info,
2121

2222
export TBLogger, reset!, set_step!, increment_step!, set_step_increment!, with_TBLogger_hold_step
2323
export log_histogram, log_value, log_vector, log_text, log_image, log_images,
24-
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar
24+
log_audio, log_audios, log_graph, log_embeddings, log_custom_scalar, write_hparams!
2525
export map_summaries, TBReader
2626

2727
export ImageFormat, L, CL, LC, LN, NL, NCL, NLC, CLN, LCN, HW, WH, HWC, WHC,
@@ -62,6 +62,7 @@ include("ImageFormat.jl")
6262
const TB_PLUGIN_JLARRAY_NAME = "_jl_tbl_array_sz"
6363

6464
include("TBLogger.jl")
65+
include("hparams.jl")
6566
include("utils.jl") # CRC Utils
6667
include("event.jl")
6768
include("Loggers/base.jl")

0 commit comments

Comments
 (0)