Commit 94f9847

alvarosg authored and diegolascasas committed
Making integration test pull an actual dataset and do a few steps of training and evaluation.
PiperOrigin-RevId: 331524333
1 parent ef672a0 commit 94f9847

File tree

12 files changed, +1594 −0 lines changed


learning_to_simulate/README.md

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
# Learning to Simulate Complex Physics with Graph Networks

This is a model implementation for the ICML 2020 submission (also available on
arXiv: [arxiv.org/abs/2002.09405](https://arxiv.org/abs/2002.09405)). If you
use the code here please cite this paper:

    @article{sanchezgonzalez2020learning,
      title={Learning to Simulate Complex Physics with Graph Networks},
      author={Alvaro Sanchez-Gonzalez and
              Jonathan Godwin and
              Tobias Pfaff and
              Rex Ying and
              Jure Leskovec and
              Peter W. Battaglia},
      url={https://arxiv.org/abs/2002.09405},
      year={2020},
      eprint={2002.09405},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
    }

## Contents

* `model.py`: implementation of the graph network used as the learnable part
  of the model.
* `model_demo.py`: example connecting the model to dummy input data.

## Running demo

(From one directory above)

    pip install -r learning_to_simulate/requirements.txt
    python -m learning_to_simulate.model_demo
learning_to_simulate/connectivity_utils.py

Lines changed: 104 additions & 0 deletions
@@ -0,0 +1,104 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tools to compute the connectivity of the graph."""

import numpy as np
from sklearn import neighbors
import tensorflow.compat.v1 as tf


def _compute_connectivity(positions, radius):
  """Gets the indices of connected edges with radius connectivity.

  Args:
    positions: Positions of nodes in the graph. Shape:
      [num_nodes_in_graph, num_dims].
    radius: Radius of connectivity.

  Returns:
    sender indices [num_edges_in_graph]
    receiver indices [num_edges_in_graph]
  """
  tree = neighbors.KDTree(positions)
  receivers_list = tree.query_radius(positions, r=radius)
  num_nodes = len(positions)
  senders = np.repeat(range(num_nodes), [len(a) for a in receivers_list])
  receivers = np.concatenate(receivers_list, axis=0)
  return senders, receivers


def _compute_connectivity_for_batch(positions, n_node, radius):
  """`_compute_connectivity` for a batch of graphs.

  Args:
    positions: Positions of nodes in the batch of graphs. Shape:
      [num_nodes_in_batch, num_dims].
    n_node: Number of nodes for each graph in the batch. Shape:
      [num_graphs_in_batch].
    radius: Radius of connectivity.

  Returns:
    sender indices [num_edges_in_batch]
    receiver indices [num_edges_in_batch]
    number of edges per graph [num_graphs_in_batch]
  """

  # TODO(alvarosg): Consider if we want to support batches here or not.
  # Separate the positions corresponding to particles in different graphs.
  positions_per_graph_list = np.split(positions, np.cumsum(n_node[:-1]), axis=0)
  receivers_list = []
  senders_list = []
  n_edge_list = []
  num_nodes_in_previous_graphs = 0

  # Compute connectivity for each graph in the batch.
  for positions_graph_i in positions_per_graph_list:
    senders_graph_i, receivers_graph_i = _compute_connectivity(
        positions_graph_i, radius)

    num_edges_graph_i = len(senders_graph_i)
    n_edge_list.append(num_edges_graph_i)

    # Because the inputs will be concatenated, we need to add offsets to the
    # sender and receiver indices according to the number of nodes in previous
    # graphs in the same batch.
    receivers_list.append(receivers_graph_i + num_nodes_in_previous_graphs)
    senders_list.append(senders_graph_i + num_nodes_in_previous_graphs)

    num_nodes_graph_i = len(positions_graph_i)
    num_nodes_in_previous_graphs += num_nodes_graph_i

  # Concatenate all of the results.
  senders = np.concatenate(senders_list, axis=0).astype(np.int32)
  receivers = np.concatenate(receivers_list, axis=0).astype(np.int32)
  n_edge = np.stack(n_edge_list).astype(np.int32)

  return senders, receivers, n_edge


def compute_connectivity_for_batch_pyfunc(positions, n_node, radius):
  """`_compute_connectivity_for_batch` wrapped in a pyfunc."""
  senders, receivers, n_edge = tf.py_function(
      _compute_connectivity_for_batch,
      [positions, n_node, radius], [tf.int32, tf.int32, tf.int32])
  senders.set_shape([None])
  receivers.set_shape([None])
  n_edge.set_shape(n_node.get_shape())
  return senders, receivers, n_edge
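As a quick sanity check (not part of the commit), a minimal sketch of how these helpers behave. Note that `KDTree.query_radius` includes each query point among its own neighbors, so the edge lists contain self-edges, and the ordering within each node's neighbor list is not guaranteed:

    import numpy as np

    # Three 2-D points: the first two are within radius of each other, the
    # third is isolated, so we expect self-edges plus the 0<->1 pair.
    positions = np.array([[0.0, 0.0],
                          [0.0, 1.0],
                          [5.0, 5.0]])
    senders, receivers = _compute_connectivity(positions, radius=1.5)
    print(senders)    # e.g. [0 0 1 1 2]
    print(receivers)  # e.g. [0 1 0 1 2] (per-node order may vary)

    # Batched version: two graphs of 3 nodes each, concatenated along axis 0.
    batch_positions = np.concatenate([positions, positions + 10.0], axis=0)
    n_node = np.array([3, 3])
    senders_b, receivers_b, n_edge_b = _compute_connectivity_for_batch(
        batch_positions, n_node, radius=1.5)
    print(n_edge_b)  # e.g. [5 5]; the second graph's indices are offset by 3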
learning_to_simulate/download_dataset.sh

Lines changed: 32 additions & 0 deletions
@@ -0,0 +1,32 @@
#!/bin/bash
# Copyright 2020 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Usage:
#   bash download_dataset.sh ${DATASET_NAME} ${OUTPUT_DIR}
# Example:
#   bash download_dataset.sh WaterDrop /tmp/

set -e

DATASET_NAME="${1}"
OUTPUT_DIR="${2}/${DATASET_NAME}"

BASE_URL="https://storage.googleapis.com/learning-to-simulate-complex-physics/Datasets/${DATASET_NAME}/"

# Quote the path so names containing spaces do not word-split.
mkdir -p "${OUTPUT_DIR}"
for file in metadata.json train.tfrecord valid.tfrecord test.tfrecord
do
  wget -O "${OUTPUT_DIR}/${file}" "${BASE_URL}${file}"
done
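Each downloaded dataset directory then holds a `metadata.json` plus train/valid/test TFRecords. A minimal sketch for checking the download; the field names inside `metadata.json` are dataset-specific, so this only prints whatever keys are present, and the `/tmp/WaterDrop` path just follows the example invocation above:

    import json
    import os

    dataset_dir = "/tmp/WaterDrop"  # path produced by the example above

    # Print the dataset-specific metadata fields.
    with open(os.path.join(dataset_dir, "metadata.json")) as f:
      metadata = json.load(f)
    print(sorted(metadata.keys()))

    # Confirm all three splits arrived.
    for split in ("train", "valid", "test"):
      path = os.path.join(dataset_dir, split + ".tfrecord")
      print(split, os.path.getsize(path), "bytes")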
learning_to_simulate/model.py

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
# Lint as: python3
# pylint: disable=g-bad-file-header
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Graph network implementation accompanying the ICML 2020 submission.

"Learning to Simulate Complex Physics with Graph Networks"

Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
Jure Leskovec, Peter W. Battaglia

https://arxiv.org/abs/2002.09405

The Sonnet `EncodeProcessDecode` module provided here implements the learnable
parts of the model.
It assumes an encoder preprocessor has already built a graph with
connectivity and features as described in the paper, with features normalized
to zero-mean unit-variance.

Dependencies include TensorFlow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""

import graph_nets as gn
import sonnet as snt
import tensorflow as tf


def build_mlp(
    hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
  """Builds an MLP."""
  return snt.nets.MLP(
      output_sizes=[hidden_size] * num_hidden_layers + [output_size])


class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode function approximator for learnable simulator."""

  def __init__(
      self,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      output_size: int,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of message passing steps.
      output_size: Output size of the decoded node representations as required
        by the downstream update function.
      name: Name of the model.
    """

    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._output_size = output_size

    with self._enter_variable_scope():
      self._networks_builder()

  def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Forward pass of the learnable dynamics model."""

    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the networks."""

    def build_mlp_with_layer_norm():
      mlp = build_mlp(
          hidden_size=self._mlp_hidden_size,
          num_hidden_layers=self._mlp_num_hidden_layers,
          output_size=self._latent_size)
      return snt.Sequential([mlp, snt.LayerNorm()])

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        edge_model_fn=build_mlp_with_layer_norm,
        node_model_fn=build_mlp_with_layer_norm)
    self._encoder_network = gn.modules.GraphIndependent(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared
    # parameters that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for _ in range(self._num_message_passing_steps):
      self._processor_networks.append(
          gn.modules.InteractionNetwork(
              edge_model_fn=build_mlp_with_layer_norm,
              node_model_fn=build_mlp_with_layer_norm))

    # The decoder MLP decodes node latent features into the output size.
    self._decoder_network = build_mlp(
        hidden_size=self._mlp_hidden_size,
        num_hidden_layers=self._mlp_num_hidden_layers,
        output_size=self._output_size)

  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
      input_graph = input_graph.replace(
          nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `m` message passing steps in the latent graphs.
    # (In the shared parameters case, just reuse the same `processor_network`.)
    latent_graph_prev_k = latent_graph_0
    for processor_network_k in self._processor_networks:
      latent_graph_k = self._process_step(
          processor_network_k, latent_graph_prev_k)
      latent_graph_prev_k = latent_graph_k

    # `latent_graph_prev_k` holds the result of the final step (and reduces to
    # the input graph if there are zero message passing steps).
    latent_graph_m = latent_graph_prev_k
    return latent_graph_m

  def _process_step(
      self, processor_network_k: snt.Module,
      latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Single step of message passing with node/edge residual connections."""

    # One step of message passing.
    latent_graph_k = processor_network_k(latent_graph_prev_k)

    # Add residuals.
    latent_graph_k = latent_graph_k.replace(
        nodes=latent_graph_k.nodes + latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges + latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Decodes from the latent graph."""
    return self._decoder_network(latent_graph.nodes)
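For reference, a minimal sketch of pushing dummy data through the module, in the spirit of the `model_demo.py` mentioned in the README; the shapes, hyperparameter values, and `tf.compat.v1` session setup here are illustrative assumptions, not the repo's actual demo:

    import graph_nets as gn
    import tensorflow.compat.v1 as tf

    from learning_to_simulate import model

    # Dummy graph: 10 nodes with 4 features, 20 edges with 3 features.
    input_graph = gn.graphs.GraphsTuple(
        nodes=tf.random.normal([10, 4]),
        edges=tf.random.normal([20, 3]),
        globals=tf.random.normal([1, 2]),
        senders=tf.random.uniform([20], maxval=10, dtype=tf.int32),
        receivers=tf.random.uniform([20], maxval=10, dtype=tf.int32),
        n_node=tf.constant([10]),
        n_edge=tf.constant([20]))

    net = model.EncodeProcessDecode(
        latent_size=16,
        mlp_hidden_size=16,
        mlp_num_hidden_layers=2,
        num_message_passing_steps=3,
        output_size=2)
    per_node_output = net(input_graph)  # decoded node features, shape [10, 2]

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      print(sess.run(per_node_output).shape)  # (10, 2)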
