STN in Practice: Handwritten Digit Recognition
The Spatial Transformer Layer (spatial transformer network, STN) can rectify images without any calibration. In this post we apply it to handwritten digit recognition to improve accuracy. Before starting, make sure the baseline handwritten digit recognition code runs end to end; we will make small modifications on top of it.
Prerequisites
If you are not familiar with STN, read an explanation of the Spatial Transformer Networks paper first.
Building the LocNet architecture
LocNet typically consists of convolution and pooling layers followed by a regression layer, but the exact design depends on the problem: if the images are large, apply convolution and pooling first to condense the features; if the images are small, fully connected layers alone are enough.
Add the following to model.py:
def loc_net(inputs, keep_prob):
    inputs = tf.reshape(inputs, [-1, 28, 28, 1])
    with tf.variable_scope("loc_layer_1"):
        w1 = get_weights([5, 5, 1, 32])
        b1 = get_bias([32])
        x = conv2d_relu(inputs, w1, b1, strides=[1, 2, 2, 1])
        x = max_pool(x)
    with tf.variable_scope("loc_layer_2"):
        x = conv2d_relu(x, get_weights([3, 3, 32, 64]), get_bias([64]))
        x = max_pool(x)
    # 28x28 -> 14x14 (stride-2 conv) -> 7x7 (pool) -> 4x4 (pool), so 4*4*64 = 1024
    # (assuming the conv/pool helpers use SAME padding and 2x2/stride-2 pooling)
    x = tf.reshape(x, [-1, 1024])
    with tf.variable_scope("loc_layer_3"):
        fc1 = tf.nn.tanh(tf.matmul(x, get_weights([1024, 32])) + get_bias([32]))
        fc1 = tf.nn.dropout(fc1, keep_prob)
    with tf.variable_scope("loc_layer_4"):
        # bias initialized to the identity affine transform [[1, 0, 0], [0, 1, 0]]
        initial = np.array([[1., 0, 0], [0, 1., 0]])
        initial = initial.astype('float32')
        initial = initial.flatten()
        w2 = get_weights([32, 6])
        b2 = tf.Variable(initial_value=initial)
        fc2 = tf.nn.tanh(tf.matmul(fc1, w2) + b2)
    return fc2
This function defines the forward pass of LocNet. In principle more than one network structure will work. Note that the final bias must be initialized exactly like this: initial = np.array([[1., 0, 0], [0, 1., 0]]), i.e. the identity affine transform; otherwise the affine-transformed images will come out badly distorted.
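To see why this initialization matters, here is a small NumPy sketch (illustration only, not part of the model code): with theta = [1, 0, 0, 0, 1, 0] the affine transform maps every sampling coordinate to itself, so at the start of training the STN behaves as a no-op instead of feeding the classifier a scrambled image.

import numpy as np

# identity affine parameters, exactly the flattened bias used in loc_net
theta = np.array([[1., 0., 0.],
                  [0., 1., 0.]], dtype=np.float32)

# a few normalized sampling coordinates in homogeneous form [x, y, 1]
grid = np.array([[-1.0, -1.0, 1.0],
                 [ 0.0,  0.5, 1.0],
                 [ 1.0,  1.0, 1.0]], dtype=np.float32).T  # shape (3, N)

sampled = theta @ grid                  # shape (2, N): where each output pixel samples from
print(np.allclose(sampled, grid[:2]))   # True -> identity transform, image unchanged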
Implementing the Spatial Transformer Network layer
Rather than reinvent the wheel, I found an existing implementation on GitHub and use it here directly. Many thanks to the original author!
transformer.py
import tensorflow as tf


def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
    """
    Spatial Transformer Network layer implementation as described in [1].

    The layer is composed of 3 elements:
    - localization_net: takes the original image as input and outputs
      the parameters of the affine transformation that should be applied
      to the input image.
    - affine_grid_generator: generates a grid of (x,y) coordinates that
      correspond to a set of points where the input should be sampled
      to produce the transformed output.
    - bilinear_sampler: takes as input the original image and the grid
      and produces the output image using bilinear interpolation.

    Input
    -----
    - input_fmap: output of the previous layer. Can be input if spatial
      transformer layer is at the beginning of architecture. Should be
      a tensor of shape (B, H, W, C).
    - theta: affine transform tensor of shape (B, 6). Permits cropping,
      translation and isotropic scaling. Initialize to identity matrix.
      It is the output of the localization network.

    Returns
    -------
    - out_fmap: transformed input feature map. Tensor of size (B, H, W, C).

    Notes
    -----
    [1]: 'Spatial Transformer Networks', Jaderberg et. al,
         (https://arxiv.org/abs/1506.02025)
    """
    # grab input dimensions
    B = tf.shape(input_fmap)[0]
    H = tf.shape(input_fmap)[1]
    W = tf.shape(input_fmap)[2]

    # reshape theta to (B, 2, 3)
    theta = tf.reshape(theta, [B, 2, 3])

    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        batch_grids = affine_grid_generator(out_H, out_W, theta)
    else:
        batch_grids = affine_grid_generator(H, W, theta)

    x_s = batch_grids[:, 0, :, :]
    y_s = batch_grids[:, 1, :, :]

    # sample input with grid to get output
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)

    return out_fmap
def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a 4D tensor image.

    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: flattened tensor of shape (B*H*W,)
    - y: flattened tensor of shape (B*H*W,)

    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)
def affine_grid_generator(height, width, theta):
    """
    This function returns a sampling grid, which when
    used with the bilinear sampler on the input feature
    map, will create an output feature map that is an
    affine transformation [1] of the input feature map.

    Input
    -----
    - height: desired height of grid/output. Used
      to downsample or upsample.
    - width: desired width of grid/output. Used
      to downsample or upsample.
    - theta: affine transform matrices of shape (num_batch, 2, 3).
      For each image in the batch, we have 6 theta parameters of
      the form (2x3) that define the affine transformation T.

    Returns
    -------
    - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
      The 2nd dimension has 2 components: (x, y) which are the
      sampling points of the original image for each point in the
      target image.

    Note
    ----
    [1]: the affine transformation allows cropping, translation,
         and isotropic scaling.
    """
    num_batch = tf.shape(theta)[0]

    # create normalized 2D grid
    x = tf.linspace(-1.0, 1.0, width)
    y = tf.linspace(-1.0, 1.0, height)
    x_t, y_t = tf.meshgrid(x, y)

    # flatten
    x_t_flat = tf.reshape(x_t, [-1])
    y_t_flat = tf.reshape(y_t, [-1])

    # reshape to [x_t, y_t , 1] - (homogeneous form)
    ones = tf.ones_like(x_t_flat)
    sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])

    # repeat grid num_batch times
    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))

    # cast to float32 (required for matmul)
    theta = tf.cast(theta, 'float32')
    sampling_grid = tf.cast(sampling_grid, 'float32')

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)
    # batch grid has shape (num_batch, 2, H*W)

    # reshape to (num_batch, 2, H, W)
    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])

    return batch_grids
def bilinear_sampler(img, x, y):
    """
    Performs bilinear sampling of the input images according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.

    To test if the function works properly, output image should be
    identical to input image when theta is initialized to identity
    transform.

    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - grid: x, y which is the output of affine_grid_generator.

    Returns
    -------
    - out: interpolated images according to grids. Same size as grid.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # rescale x and y to [0, W-1/H-1]
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x-1, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y-1, 'float32'))

    # grab 4 nearest corner points for each (x_i, y_i)
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    # clip to range [0, H-1/W-1] to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # get pixel value at corner coords
    Ia = get_pixel_value(img, x0, y0)
    Ib = get_pixel_value(img, x0, y1)
    Ic = get_pixel_value(img, x1, y0)
    Id = get_pixel_value(img, x1, y1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')

    # calculate deltas
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    # add dimension for addition
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute output
    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])

    return out
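Before wiring the layer into the model, a quick sanity check (my own snippet, not part of the borrowed project) is to run an identity theta through the layer, as the docstring suggests; the output should closely match the input:

import numpy as np
import tensorflow as tf
import transformer

# a random 28x28 single-channel "image"
img = np.random.rand(1, 28, 28, 1).astype(np.float32)
# identity affine parameters, the same values used to initialize the LocNet bias
identity_theta = np.array([[1., 0., 0., 0., 1., 0.]], dtype=np.float32)

input_fmap = tf.placeholder(tf.float32, [None, 28, 28, 1])
theta = tf.placeholder(tf.float32, [None, 6])
out_fmap = transformer.spatial_transformer_network(input_fmap, theta, out_dims=[28, 28])

with tf.Session() as sess:
    out = sess.run(out_fmap, feed_dict={input_fmap: img, theta: identity_theta})
    # the STN should roughly reproduce the input (small interpolation error is expected)
    print(np.abs(out - img).max())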
Modifying the original forward pass
Insert the LocNet at the top of the forward method (remember to import transformer in model.py):
theta = loc_net(input_data, loc_keep_prob)
input_data = tf.reshape(input_data, [-1, 28, 28, 1])
t_data = transformer.spatial_transformer_network(input_data, theta, out_dims=[28, 28])
(Reference figure)
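For concreteness, here is a minimal sketch of what the modified forward might look like. The classification layers (conv_layer_1 through fc_layer_2) and the 10-class output are assumptions based on a typical MNIST network, not the original model.py, and the helpers are assumed to use SAME padding with 2x2/stride-2 pooling:

def forward(input_data, keep_prob, loc_keep_prob):
    # 1. predict the affine parameters and rectify the input with the STN
    theta = loc_net(input_data, loc_keep_prob)
    input_data = tf.reshape(input_data, [-1, 28, 28, 1])
    t_data = transformer.spatial_transformer_network(input_data, theta, out_dims=[28, 28])

    # 2. feed the rectified image into the original classification network
    with tf.variable_scope("conv_layer_1"):
        x = conv2d_relu(t_data, get_weights([5, 5, 1, 32]), get_bias([32]))
        x = max_pool(x)                       # 28x28 -> 14x14
    with tf.variable_scope("conv_layer_2"):
        x = conv2d_relu(x, get_weights([5, 5, 32, 64]), get_bias([64]))
        x = max_pool(x)                       # 14x14 -> 7x7
        x = tf.reshape(x, [-1, 7 * 7 * 64])
    with tf.variable_scope("fc_layer_1"):
        fc1 = tf.nn.relu(tf.matmul(x, get_weights([7 * 7 * 64, 1024])) + get_bias([1024]))
        fc1 = tf.nn.dropout(fc1, keep_prob)
    with tf.variable_scope("fc_layer_2"):
        logits = tf.matmul(fc1, get_weights([1024, 10])) + get_bias([10])

    # also return the STN output so the rectified digits can be visualized
    return logits, t_data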
Modifying train.py and predict.py
Note that the forward method now takes an extra loc_keep_prob argument, so train.py and predict.py have to be updated as well. The changes are small, so they are shown as screenshots; in addition, the STN output is returned from the model so we can inspect what the STN actually does to the images. A sketch of the equivalent code follows the screenshot placeholders below.
(Screenshot: changes to train.py)
(Screenshot: changes to predict.py)
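Roughly, the changes amount to adding a loc_keep_prob placeholder, feeding it in every feed_dict, and unpacking the extra STN output. The following self-contained sketch (TensorFlow 1.x) illustrates the idea at prediction time; the placeholder names and the forward return values follow the sketch above and are assumptions, not the original train.py/predict.py:

import numpy as np
import tensorflow as tf
import model  # the modified model.py containing loc_net and forward

# placeholders (names mirror a typical MNIST train.py and are assumptions)
x = tf.placeholder(tf.float32, [None, 784], name="x")
keep_prob = tf.placeholder(tf.float32, name="keep_prob")
loc_keep_prob = tf.placeholder(tf.float32, name="loc_keep_prob")

# forward now takes the extra loc_keep_prob and also returns the STN output
logits, t_data = model.forward(x, keep_prob, loc_keep_prob)
prediction = tf.argmax(logits, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.rand(4, 784).astype(np.float32)  # stand-in for MNIST images
    # at prediction time dropout is disabled; also fetch the rectified images
    preds, rectified = sess.run([prediction, t_data],
                                feed_dict={x: batch, keep_prob: 1.0, loc_keep_prob: 1.0})
    print(preds.shape, rectified.shape)  # (4,) and (4, 28, 28, 1)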
Results
The top row shows the original images; the bottom row shows the rectified ones.
This article was originally written by @迦娜王 and published on 松鼠乐园. Reproduction without permission is prohibited.