
Putting STN into Practice for Handwritten Digit Recognition

迦娜王
3 years ago · Tensorflow

A Spatial Transformer layer (spatial transformer network) can rectify an image without any manual calibration. In this post we add one to a handwritten digit recognizer to improve its accuracy. Before starting, make sure the base handwritten digit recognition code runs; we will make small modifications on top of it.

Prerequisites

If you are not yet familiar with STNs, read the original paper and a walkthrough of it first.
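In short, the localization network regresses a 2×3 affine matrix θ for each image, and the grid generator maps every target pixel (x_t, y_t) back to a source sampling location. This is the equation from the STN paper that the affine_grid_generator function below implements:

$$\begin{pmatrix} x_s \\ y_s \end{pmatrix} = \begin{pmatrix} \theta_{11} & \theta_{12} & \theta_{13} \\ \theta_{21} & \theta_{22} & \theta_{23} \end{pmatrix} \begin{pmatrix} x_t \\ y_t \\ 1 \end{pmatrix}$$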

Building the LocNet

A LocNet is typically a few convolution and pooling layers followed by a regression layer, but the details depend on the problem: if the input images are large, use convolution and pooling first to condense the features; if they are small, fully connected layers alone may suffice.

Add the following to model.py (it uses the get_weights, get_bias, conv2d_relu and max_pool helpers from the base code, and needs import numpy as np at the top of the file):


def loc_net(inputs, keep_prob):
    inputs = tf.reshape(inputs, [-1, 28, 28, 1])
    with tf.variable_scope("loc_layer_1"):
        # conv (stride 2) + pool: condense the 28x28 input
        w1 = get_weights([5, 5, 1, 32])
        b1 = get_bias([32])
        x = conv2d_relu(inputs, w1, b1, strides=[1, 2, 2, 1])
        x = max_pool(x)

    with tf.variable_scope("loc_layer_2"):
        x = conv2d_relu(x, get_weights([3, 3, 32, 64]), get_bias([64]))
        x = max_pool(x)
        # flatten: 4x4x64 = 1024 features
        x = tf.reshape(x, [-1, 1024])

    with tf.variable_scope("loc_layer_3"):
        # small fully connected layer with dropout
        fc1 = tf.nn.tanh(tf.matmul(x, get_weights([1024, 32])) + get_bias([32]))
        fc1 = tf.nn.dropout(fc1, keep_prob)

    with tf.variable_scope("loc_layer_4"):
        # regress the 6 affine parameters; the bias starts at the identity transform
        initial = np.array([[1., 0, 0], [0, 1., 0]])
        initial = initial.astype('float32')
        initial = initial.flatten()
        w2 = get_weights([32, 6])
        b2 = tf.Variable(initial_value=initial)
        fc2 = tf.nn.tanh(tf.matmul(fc1, w2) + b2)
    return fc2

This function defines LocNet's forward pass; of course, more than one architecture would work.
Note that the final bias must be initialized to the identity transform, initial = np.array([[1., 0, 0], [0, 1., 0]]); otherwise the affine-transformed images come out badly distorted.
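For reference, loc_net leans on a few helpers defined in the base model.py. If your copy does not already have them, here is a minimal sketch; the names, initializers, and SAME padding are assumptions, so adjust them to match your base code:

def get_weights(shape):
    # truncated-normal initialization for conv/fc kernels (assumed initializer)
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def get_bias(shape):
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d_relu(x, w, b, strides=[1, 1, 1, 1]):
    # SAME padding: 28x28 -> 14x14 at stride 2, size preserved at stride 1
    return tf.nn.relu(tf.nn.conv2d(x, w, strides=strides, padding='SAME') + b)

def max_pool(x):
    # 2x2 pooling; SAME padding rounds 7x7 up to 4x4
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')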

Spatial Transformer Network layer implementation

No need to reinvent the wheel here: I found an existing implementation on GitHub and use it as-is. Many thanks to the original author!
transformer.py

import tensorflow as tf

def spatial_transformer_network(input_fmap, theta, out_dims=None, **kwargs):
    """
    Spatial Transformer Network layer implementation as described in [1].

    The layer is composed of 3 elements:

    - localization_net: takes the original image as input and outputs
      the parameters of the affine transformation that should be applied
      to the input image.

    - affine_grid_generator: generates a grid of (x,y) coordinates that
      correspond to a set of points where the input should be sampled
      to produce the transformed output.

    - bilinear_sampler: takes as input the original image and the grid
      and produces the output image using bilinear interpolation.

    Input
    -----
    - input_fmap: output of the previous layer. Can be input if spatial
      transformer layer is at the beginning of architecture. Should be
      a tensor of shape (B, H, W, C).

    - theta: affine transform tensor of shape (B, 6). Permits cropping,
      translation and isotropic scaling. Initialize to identity matrix.
      It is the output of the localization network.

    Returns
    -------
    - out_fmap: transformed input feature map. Tensor of size (B, H, W, C).

    Notes
    -----
    [1]: 'Spatial Transformer Networks', Jaderberg et al.,
         (https://arxiv.org/abs/1506.02025)

    """
    # grab input dimensions
    B = tf.shape(input_fmap)[0]
    H = tf.shape(input_fmap)[1]
    W = tf.shape(input_fmap)[2]

    # reshape theta to (B, 2, 3)
    theta = tf.reshape(theta, [B, 2, 3])

    # generate grids of same size or upsample/downsample if specified
    if out_dims:
        out_H = out_dims[0]
        out_W = out_dims[1]
        batch_grids = affine_grid_generator(out_H, out_W, theta)
    else:
        batch_grids = affine_grid_generator(H, W, theta)

    x_s = batch_grids[:, 0, :, :]
    y_s = batch_grids[:, 1, :, :]

    # sample input with grid to get output
    out_fmap = bilinear_sampler(input_fmap, x_s, y_s)

    return out_fmap

def get_pixel_value(img, x, y):
    """
    Utility function to get pixel value for coordinate
    vectors x and y from a  4D tensor image.

    Input
    -----
    - img: tensor of shape (B, H, W, C)
    - x: tensor of shape (B, H, W) with x coordinates
    - y: tensor of shape (B, H, W) with y coordinates

    Returns
    -------
    - output: tensor of shape (B, H, W, C)
    """
    shape = tf.shape(x)
    batch_size = shape[0]
    height = shape[1]
    width = shape[2]

    batch_idx = tf.range(0, batch_size)
    batch_idx = tf.reshape(batch_idx, (batch_size, 1, 1))
    b = tf.tile(batch_idx, (1, height, width))

    indices = tf.stack([b, y, x], 3)

    return tf.gather_nd(img, indices)

def affine_grid_generator(height, width, theta):
    """
    This function returns a sampling grid, which when
    used with the bilinear sampler on the input feature
    map, will create an output feature map that is an
    affine transformation [1] of the input feature map.

    Input
    -----
    - height: desired height of grid/output. Used
      to downsample or upsample.

    - width: desired width of grid/output. Used
      to downsample or upsample.

    - theta: affine transform matrices of shape (num_batch, 2, 3).
      For each image in the batch, we have 6 theta parameters of
      the form (2x3) that define the affine transformation T.

    Returns
    -------
    - normalized grid (-1, 1) of shape (num_batch, 2, H, W).
      The 2nd dimension has 2 components: (x, y) which are the
      sampling points of the original image for each point in the
      target image.

    Note
    ----
    [1]: the affine transformation allows cropping, translation,
         and isotropic scaling.
    """
    num_batch = tf.shape(theta)[0]

    # create normalized 2D grid
    x = tf.linspace(-1.0, 1.0, width)
    y = tf.linspace(-1.0, 1.0, height)
    x_t, y_t = tf.meshgrid(x, y)

    # flatten
    x_t_flat = tf.reshape(x_t, [-1])
    y_t_flat = tf.reshape(y_t, [-1])

    # reshape to [x_t, y_t , 1] - (homogeneous form)
    ones = tf.ones_like(x_t_flat)
    sampling_grid = tf.stack([x_t_flat, y_t_flat, ones])

    # repeat grid num_batch times
    sampling_grid = tf.expand_dims(sampling_grid, axis=0)
    sampling_grid = tf.tile(sampling_grid, tf.stack([num_batch, 1, 1]))

    # cast to float32 (required for matmul)
    theta = tf.cast(theta, 'float32')
    sampling_grid = tf.cast(sampling_grid, 'float32')

    # transform the sampling grid - batch multiply
    batch_grids = tf.matmul(theta, sampling_grid)
    # batch grid has shape (num_batch, 2, H*W)

    # reshape to (num_batch, H, W, 2)
    batch_grids = tf.reshape(batch_grids, [num_batch, 2, height, width])

    return batch_grids

def bilinear_sampler(img, x, y):
    """
    Performs bilinear sampling of the input images according to the
    normalized coordinates provided by the sampling grid. Note that
    the sampling is done identically for each channel of the input.

    To test if the function works properly, output image should be
    identical to input image when theta is initialized to identity
    transform.

    Input
    -----
    - img: batch of images in (B, H, W, C) layout.
    - grid: x, y which is the output of affine_grid_generator.

    Returns
    -------
    - out: interpolated images according to grids. Same size as grid.
    """
    H = tf.shape(img)[1]
    W = tf.shape(img)[2]
    max_y = tf.cast(H - 1, 'int32')
    max_x = tf.cast(W - 1, 'int32')
    zero = tf.zeros([], dtype='int32')

    # rescale x and y from [-1, 1] to [0, W-1] / [0, H-1]
    x = tf.cast(x, 'float32')
    y = tf.cast(y, 'float32')
    x = 0.5 * ((x + 1.0) * tf.cast(max_x, 'float32'))
    y = 0.5 * ((y + 1.0) * tf.cast(max_y, 'float32'))

    # grab 4 nearest corner points for each (x_i, y_i)
    x0 = tf.cast(tf.floor(x), 'int32')
    x1 = x0 + 1
    y0 = tf.cast(tf.floor(y), 'int32')
    y1 = y0 + 1

    # clip to range [0, H-1/W-1] to not violate img boundaries
    x0 = tf.clip_by_value(x0, zero, max_x)
    x1 = tf.clip_by_value(x1, zero, max_x)
    y0 = tf.clip_by_value(y0, zero, max_y)
    y1 = tf.clip_by_value(y1, zero, max_y)

    # get pixel value at corner coords
    Ia = get_pixel_value(img, x0, y0)
    Ib = get_pixel_value(img, x0, y1)
    Ic = get_pixel_value(img, x1, y0)
    Id = get_pixel_value(img, x1, y1)

    # recast as float for delta calculation
    x0 = tf.cast(x0, 'float32')
    x1 = tf.cast(x1, 'float32')
    y0 = tf.cast(y0, 'float32')
    y1 = tf.cast(y1, 'float32')

    # calculate deltas
    wa = (x1-x) * (y1-y)
    wb = (x1-x) * (y-y0)
    wc = (x-x0) * (y1-y)
    wd = (x-x0) * (y-y0)

    # add dimension for addition
    wa = tf.expand_dims(wa, axis=3)
    wb = tf.expand_dims(wb, axis=3)
    wc = tf.expand_dims(wc, axis=3)
    wd = tf.expand_dims(wd, axis=3)

    # compute output
    out = tf.add_n([wa*Ia, wb*Ib, wc*Ic, wd*Id])

    return out
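As the docstring notes, a quick sanity check is to feed an identity theta: the output should then reproduce the input (up to interpolation error). A minimal TF1 sketch of that check, assuming transformer.py sits next to the script:

import numpy as np
import tensorflow as tf
import transformer

# one random "image" of MNIST shape
img = np.random.rand(1, 28, 28, 1).astype('float32')

inputs = tf.placeholder(tf.float32, [None, 28, 28, 1])
identity_theta = tf.constant([[1., 0., 0., 0., 1., 0.]])  # flattened 2x3 identity
out = transformer.spatial_transformer_network(inputs, identity_theta)

with tf.Session() as sess:
    result = sess.run(out, feed_dict={inputs: img})
    print(np.abs(result - img).max())  # should be close to 0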

Modifying the original forward pass

Insert the LocNet at the top of the forward method (remember to import transformer in model.py):

theta = loc_net(input_data, loc_keep_prob)  # predict the 6 affine parameters per image
input_data = tf.reshape(input_data, [-1, 28, 28, 1])
t_data = transformer.spatial_transformer_network(input_data, theta, out_dims=[28, 28])  # rectified input

[Figure: reference diagram of the modified forward pass]
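Since the reference image is only a screenshot, here is a hedged sketch of the shape the modified forward takes; original_conv_layers is a hypothetical stand-in for whatever classifier layers your base model.py already defines:

import transformer

def forward(input_data, keep_prob, loc_keep_prob):
    # 1. predict the affine parameters with the localization network
    theta = loc_net(input_data, loc_keep_prob)

    # 2. warp the input with the spatial transformer
    input_data = tf.reshape(input_data, [-1, 28, 28, 1])
    t_data = transformer.spatial_transformer_network(input_data, theta, out_dims=[28, 28])

    # 3. feed the rectified image into the original classifier layers
    logits = original_conv_layers(t_data, keep_prob)  # hypothetical stand-in name

    # also return the STN output so it can be visualized
    return logits, t_data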

Updating train.py and predict.py

Note that forward now takes an extra loc_keep_prob argument, so train.py and predict.py need matching changes. Because the edits are small, the original post showed them only as screenshots. The forward method also returns the STN output, so you can see what the transformed images actually look like.

train.py
[Screenshot: changes to train.py]

predict.py
[Screenshot: changes to predict.py]
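In lieu of the screenshots, a hedged sketch of the wiring train.py needs; x, y, train_op, and accuracy are placeholder names for whatever your base training script already defines:

# extra placeholder for the LocNet dropout
loc_keep_prob = tf.placeholder(tf.float32, name='loc_keep_prob')

# forward now returns the STN output as well
logits, t_data = forward(x, keep_prob, loc_keep_prob)

# training step: keep a little dropout on the LocNet's fully connected layer
sess.run(train_op, feed_dict={x: batch_x, y: batch_y,
                              keep_prob: 0.5, loc_keep_prob: 0.9})

# evaluation / prediction: disable dropout, and fetch t_data to inspect the warp
acc, warped = sess.run([accuracy, t_data],
                       feed_dict={x: test_x, y: test_y,
                                  keep_prob: 1.0, loc_keep_prob: 1.0})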

Results

The top row shows the original digits; the bottom row shows them after rectification by the STN.
[Figure: original vs. STN-rectified digits]
