|
| 1 | +# Lint as: python3 |
| 2 | +# pylint: disable=g-bad-file-header |
| 3 | +# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. |
| 4 | +# |
| 5 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | +# you may not use this file except in compliance with the License. |
| 7 | +# You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, software |
| 12 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | +# See the License for the specific language governing permissions and |
| 15 | +# limitations under the License. |
| 16 | +# ============================================================================ |
| 17 | + |
"""Graph network implementation accompanying ICML 2020 submission.

  "Learning to Simulate Complex Physics with Graph Networks"

  Alvaro Sanchez-Gonzalez*, Jonathan Godwin*, Tobias Pfaff*, Rex Ying,
  Jure Leskovec, Peter W. Battaglia

  https://arxiv.org/abs/2002.09405

The Sonnet `EncodeProcessDecode` module provided here implements the learnable
parts of the model. It assumes an encoder preprocessor has already built a
graph with connectivity and features as described in the paper, with features
normalized to zero-mean unit-variance.

Dependencies include Tensorflow 1.x, Sonnet 1.x and the Graph Nets 1.1 library.
"""
| 35 | + |
| 36 | +import graph_nets as gn |
| 37 | +import sonnet as snt |
| 38 | +import tensorflow as tf |
| 39 | + |
| 40 | + |
def build_mlp(
    hidden_size: int, num_hidden_layers: int, output_size: int) -> snt.Module:
  """Constructs an MLP with `num_hidden_layers` hidden layers of `hidden_size`
  units each, followed by a final layer of `output_size` units."""
  layer_sizes = [hidden_size] * num_hidden_layers
  layer_sizes.append(output_size)
  return snt.nets.MLP(output_sizes=layer_sizes)
| 46 | + |
| 47 | + |
class EncodeProcessDecode(snt.AbstractModule):
  """Encode-Process-Decode function approximator for learnable simulator.

  Encodes the input graph's node/edge features into a latent graph, runs a
  fixed number of message-passing steps (with residual connections) over the
  latents, and decodes per-node outputs from the final latent graph.
  """

  def __init__(
      self,
      latent_size: int,
      mlp_hidden_size: int,
      mlp_num_hidden_layers: int,
      num_message_passing_steps: int,
      output_size: int,
      name: str = "EncodeProcessDecode"):
    """Inits the model.

    Args:
      latent_size: Size of the node and edge latent representations.
      mlp_hidden_size: Hidden layer size for all MLPs.
      mlp_num_hidden_layers: Number of hidden layers in all MLPs.
      num_message_passing_steps: Number of message passing steps.
      output_size: Output size of the decode node representations as required
        by the downstream update function.
      name: Name of the model.
    """
    super().__init__(name=name)

    self._latent_size = latent_size
    self._mlp_hidden_size = mlp_hidden_size
    self._mlp_num_hidden_layers = mlp_num_hidden_layers
    self._num_message_passing_steps = num_message_passing_steps
    self._output_size = output_size

    # Sonnet 1.x requires sub-modules/variables to be built inside the
    # module's own variable scope.
    with self._enter_variable_scope():
      self._networks_builder()

  def _build(self, input_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Forward pass of the learnable dynamics model.

    Args:
      input_graph: Graph with pre-normalized node/edge (and optional global)
        features.

    Returns:
      Per-node output tensor of size `output_size`.
    """
    # Encode the input_graph.
    latent_graph_0 = self._encode(input_graph)

    # Do `m` message passing steps in the latent graphs.
    latent_graph_m = self._process(latent_graph_0)

    # Decode from the last latent graph.
    return self._decode(latent_graph_m)

  def _networks_builder(self):
    """Builds the encoder, processor and decoder networks."""

    def build_mlp_with_layer_norm():
      # Each MLP is followed by LayerNorm, as described in the paper.
      mlp = build_mlp(
          hidden_size=self._mlp_hidden_size,
          num_hidden_layers=self._mlp_num_hidden_layers,
          output_size=self._latent_size)
      return snt.Sequential([mlp, snt.LayerNorm()])

    # The encoder graph network independently encodes edge and node features.
    encoder_kwargs = dict(
        edge_model_fn=build_mlp_with_layer_norm,
        node_model_fn=build_mlp_with_layer_norm)
    self._encoder_network = gn.modules.GraphIndependent(**encoder_kwargs)

    # Create `num_message_passing_steps` graph networks with unshared
    # parameters that update the node and edge latent features.
    # Note that we can use `modules.InteractionNetwork` because
    # it also outputs the messages as updated edge latent features.
    self._processor_networks = []
    for _ in range(self._num_message_passing_steps):
      self._processor_networks.append(
          gn.modules.InteractionNetwork(
              edge_model_fn=build_mlp_with_layer_norm,
              node_model_fn=build_mlp_with_layer_norm))

    # The decoder MLP decodes node latent features into the output size.
    # No LayerNorm here: the output is an unconstrained prediction.
    self._decoder_network = build_mlp(
        hidden_size=self._mlp_hidden_size,
        num_hidden_layers=self._mlp_num_hidden_layers,
        output_size=self._output_size)

  def _encode(
      self, input_graph: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Encodes the input graph features into a latent graph."""

    # Copy the globals to all of the nodes, if applicable.
    if input_graph.globals is not None:
      broadcasted_globals = gn.blocks.broadcast_globals_to_nodes(input_graph)
      input_graph = input_graph.replace(
          nodes=tf.concat([input_graph.nodes, broadcasted_globals], axis=-1),
          globals=None)

    # Encode the node and edge features.
    latent_graph_0 = self._encoder_network(input_graph)
    return latent_graph_0

  def _process(
      self, latent_graph_0: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Processes the latent graph with several steps of message passing."""

    # Do `m` message passing steps in the latent graphs.
    # (In the shared parameters case, just reuse the same `processor_network`.)
    # Fix: the previous version read the loop variable after the loop, which
    # raised UnboundLocalError when `num_message_passing_steps == 0`. With an
    # accumulator variable, zero steps correctly returns the encoded graph.
    latent_graph = latent_graph_0
    for processor_network_k in self._processor_networks:
      latent_graph = self._process_step(processor_network_k, latent_graph)
    return latent_graph

  def _process_step(
      self, processor_network_k: snt.Module,
      latent_graph_prev_k: gn.graphs.GraphsTuple) -> gn.graphs.GraphsTuple:
    """Single step of message passing with node/edge residual connections."""

    # One step of message passing.
    latent_graph_k = processor_network_k(latent_graph_prev_k)

    # Add residuals.
    latent_graph_k = latent_graph_k.replace(
        nodes=latent_graph_k.nodes + latent_graph_prev_k.nodes,
        edges=latent_graph_k.edges + latent_graph_prev_k.edges)
    return latent_graph_k

  def _decode(self, latent_graph: gn.graphs.GraphsTuple) -> tf.Tensor:
    """Decodes per-node outputs from the latent graph's node features."""
    return self._decoder_network(latent_graph.nodes)
0 commit comments