Skip to content

Commit df18831

Browse files
committed
Add padding arround image
1 parent 61e0c5a commit df18831

File tree

3 files changed

+27
-17
lines changed

3 files changed

+27
-17
lines changed

SourceCode/ConvLSTM/BasicConvLSTMCell.py

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
1-
21
import tensorflow as tf
3-
4-
2+
from tensorflow.contrib.layers.python.layers.layers import layer_norm
53

64
class ConvRNNCell(object):
75
"""Abstract object representing an Convolutional RNN cell.
86
"""
7+
98
def __call__(self, inputs, state, scope=None):
109
"""Run this RNN cell on inputs, starting from the given state.
1110
"""
@@ -37,10 +36,10 @@ def zero_state(self, batch_size):
3736
if self._state_is_tuple:
3837
if self._data_format == 'NHWC':
3938
zeros = (tf.zeros([batch_size, shape[0], shape[1], num_features]),
40-
tf.zeros([batch_size, shape[0], shape[1], num_features]))
39+
tf.zeros([batch_size, shape[0], shape[1], num_features]))
4140
else:
4241
zeros = (tf.zeros([batch_size, num_features, shape[0], shape[1]]),
43-
tf.zeros([batch_size, num_features, shape[0], shape[1]]))
42+
tf.zeros([batch_size, num_features, shape[0], shape[1]]))
4443
else:
4544
if self._data_format == 'NHWC':
4645
zeros = tf.zeros([batch_size, shape[0], shape[1], num_features * 2])
@@ -110,6 +109,7 @@ def __call__(self, inputs, state, scope=None, reuse=None, clip_cell=None):
110109
new_state = tf.concat(axis=channel_axis, values=[new_c, new_h])
111110
return new_h, new_state
112111

112+
113113
class LayerNormConvLSTMCell(ConvRNNCell):
114114
"""Basic Conv LSTM recurrent network cell. The
115115
"""
@@ -155,23 +155,36 @@ def __call__(self, inputs, state, scope=None, reuse=None):
155155
else:
156156
c, h = tf.split(axis=channel_axis, num_or_size_splits=2, value=state)
157157
concat_i = _conv_linear([inputs], self.filter_size, self.num_features * 4, False,
158-
data_format=self._data_format, scope='Input')
158+
data_format=self._data_format, scope='Input')
159159
concat_h = _conv_linear([inputs], self.filter_size, self.num_features * 4, False,
160-
data_format=self._data_format, scope='H')
160+
data_format=self._data_format, scope='H')
161161
bias_term = tf.get_variable("Bias", [self.num_features * 4], dtype=inputs.dtype,
162162
initializer=tf.constant_initializer(0.0, dtype=inputs.dtype))
163-
concat = tf.nn.bias_add(tf.contrib.layers.layer_norm(concat_i) + tf.contrib.layers.layer_norm(concat_h),
163+
if self._data_format == 'NCHW':
164+
concat_i = tf.transpose(concat_i, (0, 2, 3, 1))
165+
concat_h = tf.transpose(concat_h, (0, 2, 3, 1))
166+
normed_i = layer_norm(concat_i)
167+
normed_h = layer_norm(concat_h)
168+
if self._data_format == 'NCHW':
169+
normed_i = tf.transpose(normed_i, (0, 3, 1, 2))
170+
normed_h = tf.transpose(normed_h, (0, 3, 1, 2))
171+
172+
concat = tf.nn.bias_add(normed_i + normed_h,
164173
bias_term, data_format=self._data_format)
165174
# i = input_gate, j = new_input, f = forget_gate, o = output_gate
166175
i, j, f, o = tf.split(axis=channel_axis, num_or_size_splits=4, value=concat)
167176
new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * self._activation(j))
168-
new_c_norm = tf.contrib.layers.layer_norm(new_c)
177+
if self._data_format == 'NCHW':
178+
new_c = tf.transpose(new_c, (0, 2, 3, 1))
179+
new_c_norm = layer_norm(new_c)
180+
if self._data_format == 'NCHW':
181+
new_c_norm = tf.transpose(new_c_norm, (0, 3, 1, 2))
169182
new_h = self._activation(new_c_norm) * tf.nn.sigmoid(o)
170183

171184
if self._state_is_tuple:
172-
new_state = (new_c, new_h)
185+
new_state = (new_c_norm, new_h)
173186
else:
174-
new_state = tf.concat(axis=channel_axis, values=[new_c, new_h])
187+
new_state = tf.concat(axis=channel_axis, values=[new_c_norm, new_h])
175188
return new_h, new_state
176189

177190

@@ -261,13 +274,11 @@ def __call__(self, inputs, h, scope=None, reuse=None):
261274
channel_axis = 3 if self._data_format == 'NHWC' else 1
262275

263276
concat = tf.nn.sigmoid(_conv_linear([inputs, h], self.filter_size, self.num_features * 2, True,
264-
data_format=self._data_format, scope='gates'))
277+
data_format=self._data_format, scope='gates'))
265278

266279
z, r = tf.split(axis=channel_axis, num_or_size_splits=2, value=concat)
267280
i = tf.nn.tanh(_conv_linear([inputs, tf.multiply(r, h)], self.filter_size, self.num_features, True,
268281
data_format=self._data_format, scope='input'))
269-
new_h = tf.add(tf.multiply(z, h), tf.multiply(1-z, i))
282+
new_h = tf.add(tf.multiply(z, h), tf.multiply(1 - z, i))
270283

271284
return new_h
272-
273-

SourceCode/DataHandeling.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import threading
88
import numpy as np
99
import pandas as pd
10-
pd.__version__
1110

1211
# import numpy as np
1312
# import matplotlib.pyplot as plt
@@ -953,6 +952,7 @@ def tif2png_dir(data_dir: str, out_dir: str, filename_format='t*.tif'):
953952
pad_x = 8-(img_size[1] % 8)
954953
else:
955954
pad_x = 0
955+
img = cv2.copyMakeBorder(img, 16, 16, 16, 16, cv2.BORDER_REFLECT_101)
956956
if pad_x or pad_y:
957957
img = cv2.copyMakeBorder(img, 0, pad_y, 0, pad_x, cv2.BORDER_REFLECT_101)
958958
base_name = os.path.basename(tif_filename)

SourceCode/Layers.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import tensorflow as tf
22
from tensorflow.contrib.layers.python.layers.layers import layer_norm
33

4-
tf.load_op_library()
54
__author__ = 'assafarbelle'
65

76

0 commit comments

Comments
 (0)