Skip to content

Commit 1fbf637

Browse files
committed
【add】highway_conv2d-layer
1 parent 07df6ff commit 1fbf637

File tree

6 files changed

+617
-11
lines changed

6 files changed

+617
-11
lines changed

code/my_tensorflow/demo.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
from typing import (
2+
TypeVar, Iterator, Iterable, overload, Container,
3+
Sequence, MutableSequence, Mapping, MutableMapping, Tuple, List, Any, Dict, Callable, Generic,
4+
Set, AbstractSet, FrozenSet, MutableSet, Sized, Reversible, SupportsInt, SupportsFloat,
5+
SupportsBytes, SupportsAbs, SupportsRound, IO, Union, ItemsView, KeysView, ValuesView,
6+
ByteString, Optional, AnyStr, Type,
7+
)
18
from src.utils import tf_dtype
29

310
print(tf_dtype)
@@ -6,8 +13,10 @@
613
import numpy as np
714

815
from src.layers import multi_dense, Dense
9-
from tensorlayer.layers import DenseLayer
10-
from keras.layers import Dense
16+
17+
from tensorlayer.layers import DenseLayer, Conv2dLayer
18+
19+
from keras.layers import Dense, Conv2D
1120

1221
import keras.initializers
1322

@@ -25,3 +34,22 @@
2534

2635
sess.run(o)
2736

37+
38+
def f(x):
39+
"""
40+
41+
Args:
42+
x(function):
43+
44+
Returns:
45+
46+
"""
47+
48+
49+
def b():
50+
""""""
51+
52+
53+
f(1)
54+
55+
sorted()
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
1+
"""常见层与常用层的写法
2+
"""
3+
14
from .dense import *
5+
from .cnn import *
26
from .highway import *
7+

code/my_tensorflow/src/layers/cnn.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
"""卷积层
2+
3+
References:
4+
tensorlayer.layers.Conv2dLayer
5+
"""
6+
import tensorflow as tf
7+
8+
from ..activations import relu
9+
from ..utils import get_wb
10+
11+
12+
def conv2d(x, kernel_size, out_channels,
13+
act_fn=relu,
14+
strides=1,
15+
padding="SAME",
16+
name=None,
17+
reuse=False):
18+
"""2-D 卷积层
19+
Input shape: [batch_size, in_w, in_h, in_channels]
20+
Output shape: [batch_size, out_w, out_h, out_channels]
21+
22+
Args:
23+
x(tf.Tensor):
24+
kernel_size(int or list of int):
25+
out_channels(int):
26+
act_fn(function):
27+
strides(int or list of int):
28+
padding(str):
29+
name(str):
30+
reuse(bool):
31+
32+
Returns:
33+
34+
"""
35+
if isinstance(kernel_size, int):
36+
kernel_size = [kernel_size] * 2
37+
if isinstance(strides, int):
38+
strides = [strides] * 4
39+
40+
assert len(kernel_size) == 2
41+
assert len(strides) == 4
42+
43+
in_channels = int(x.get_shape()[-1])
44+
kernel_shape = list(kernel_size) + [in_channels, out_channels]
45+
46+
with tf.variable_scope(name or "conv2d", reuse=reuse):
47+
W, b = get_wb(kernel_shape)
48+
49+
o = tf.nn.conv2d(x, W, strides=strides, padding=padding) + b
50+
o = act_fn(o)
51+
52+
return o
53+

code/my_tensorflow/src/layers/highway.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
"""高速网络
1+
"""高速网络 Highway Network
2+
3+
注意 x 经过 Highway 之后维度应该保持不变
24
35
References:
46
https://github.com/fomorians/highway-fcn
@@ -18,27 +20,66 @@ def highway_dense(x, act_fn=relu, carry_bias=-1.0, name=None):
1820
`o = H(x, W)T(x, W) + x(1 - T(x, W))`
1921
"""
2022
n_input = int(x.get_shape()[-1])
21-
with tf.variable_scope(name or "highway"):
23+
with tf.variable_scope(name or "highway_dense"):
2224
W, b = get_wb([n_input, n_input])
2325

2426
with tf.variable_scope("transform"):
2527
W_T, b_T = get_wb([n_input, n_input], b_initializer=tf.initializers.constant(carry_bias))
2628

2729
H = act_fn(tf.matmul(x, W) + b)
2830
T = sigmoid(tf.matmul(x, W_T) + b_T)
29-
3031
o = tf.multiply(H, T) + tf.multiply(x, (1. - T))
3132

3233
return o
3334

3435

3536
def multi_highway_dense(x, n_layer, act_fn=relu, carry_bias=-1.0, name=None):
36-
"""多层 highway
37+
"""多层 highway_dense
3738
Input shape: [batch_size, n_input]
3839
Output shape: [batch_size, n_input]
3940
"""
40-
name = name or "highway"
41+
name = name or "highway_dense"
4142
for i in range(n_layer):
4243
x = highway_dense(x, act_fn=act_fn, carry_bias=carry_bias, name="{}-{}".format(name, i))
4344

4445
return x
46+
47+
48+
def highway_conv2d(x, kernel_size,
49+
carry_bias=-1.0,
50+
act_fn=relu,
51+
strides=1,
52+
padding="SAME",
53+
name=None):
54+
"""用于 conv2d 的 highway
55+
Input shape: [batch_size, in_w, in_h, in_channels]
56+
Output shape: [batch_size, in_w, in_w, in_channels]
57+
58+
公式
59+
`o = H(x, W)T(x, W) + x(1 - T(x, W))`
60+
"""
61+
if isinstance(kernel_size, int):
62+
kernel_size = [kernel_size] * 2
63+
if isinstance(strides, int):
64+
strides = [strides] * 4
65+
66+
assert len(kernel_size) == 2
67+
assert len(strides) == 4
68+
69+
in_channels = int(x.get_shape()[-1])
70+
kernel_shape = list(kernel_size) + [in_channels, in_channels]
71+
72+
with tf.variable_scope(name or "highway_conv2d"):
73+
W, b = get_wb(kernel_shape, b_initializer=tf.initializers.constant(carry_bias))
74+
75+
with tf.variable_scope("transform"):
76+
W_T, b_T = get_wb(kernel_shape)
77+
78+
H = act_fn(tf.nn.conv2d(x, W, strides=strides, padding=padding) + b)
79+
T = sigmoid(tf.nn.conv2d(x, W_T, strides=strides, padding=padding) + b_T)
80+
o = tf.multiply(H, T) + tf.multiply(x, (1. - T))
81+
return o
82+
83+
84+
# TODO(huay): 因为卷积的参数比较复杂,如果需要多层 highway_conv2d,不如单独设置参数
85+
# def multi_highway_conv2d():

code/my_tensorflow/src/utils/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,9 @@ def get_wb(shape,
2323
w_regularizer=l2_regularizer,
2424
b_regularizer=l2_regularizer):
2525
""""""
26-
n_in, n_unit = shape
27-
W = tf.get_variable('W', shape=[n_in, n_unit],
26+
W = tf.get_variable('W', shape=shape,
2827
dtype=tf_dtype, initializer=w_initializer, regularizer=w_regularizer)
29-
b = tf.get_variable('b', shape=[n_unit],
28+
b = tf.get_variable('b', shape=shape[-1:],
3029
dtype=tf_dtype, initializer=b_initializer, regularizer=b_regularizer)
3130
return W, b
3231

@@ -48,7 +47,7 @@ def get_params_dict():
4847
return param_dict
4948

5049

51-
def print_get_params_dict():
50+
def print_params_dict():
5251
""""""
5352
param_dict = get_params_dict()
5453
# pprint(param_dict, indent=2)

0 commit comments

Comments
 (0)