Writing your own Keras layers
For simple, stateless custom operations, you can probably get by with a layers.core.Lambda layer. But for custom operations that hold trainable weights, you should implement your own layer.
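
For example, a minimal sketch of such a stateless operation wrapped in a Lambda layer (the input shape and the squaring function are just illustrative):

import numpy as np
from keras.layers import Input, Lambda
from keras.models import Model

inp = Input(shape=(10,))
# Element-wise squaring needs no trainable weights, so a Lambda layer is enough
out = Lambda(lambda x: x ** 2)(inp)
model = Model(inp, out)
print(model.predict(np.ones((1, 10))))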

Here is the skeleton of a Keras layer as of Keras 2.0 (if you are on an older version, please upgrade). You only need to implement three methods:

build(input_shape): This is where you define your weights. It must set self.built = True, which can be done by calling super([Layer], self).build().
call(x): This is where the layer's logic lives. Unless you want your layer to support masking, you only need to care about the first argument passed to call: the input tensor.
compute_output_shape(input_shape): If your layer changes the shape of its input, declare the shape transformation logic here, so that Keras can automatically infer the shapes of downstream layers.
from keras import backend as K
from keras.engine.topology import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer
        self.kernel = self.add_weight(name='kernel', 
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
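
A minimal usage sketch for this layer (the input size of 16 and output_dim of 8 are arbitrary):

from keras.layers import Input
from keras.models import Model

inp = Input(shape=(16,))
out = MyLayer(output_dim=8)(inp)   # builds a (16, 8) kernel
model = Model(inp, out)
model.compile(optimizer='sgd', loss='mse')
print(model.output_shape)  # (None, 8)
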
You can also define Keras layers with multiple input tensors and multiple output tensors. To do so, you should assume that the inputs and outputs of build(input_shape), call(x), and compute_output_shape(input_shape) are all lists. Here is an example, similar to the one above:

from keras import backend as K
from keras.engine.topology import Layer

class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        # Create a trainable weight variable for this layer
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[0][1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert isinstance(x, list)
        a, b = x
        return [K.dot(a, self.kernel) + b, K.mean(b, axis=-1)]

    def compute_output_shape(self, input_shape):
        assert isinstance(input_shape, list)
        shape_a, shape_b = input_shape
        return [(shape_a[0], self.output_dim), shape_b[:-1]]
shape_a[0] is the number of samples (the batch dimension).
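
A usage sketch for this two-input variant (the input sizes are arbitrary): the layer is called on a list of tensors and returns a list.

from keras.layers import Input
from keras.models import Model

a = Input(shape=(16,))
b = Input(shape=(8,))
out1, out2 = MyLayer(output_dim=8)([a, b])  # kernel shape (16, 8)
model = Model([a, b], [out1, out2])
print(model.output_shape)  # [(None, 8), (None,)]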


Keras Dense layer (excerpt from the built-in implementation):

def call(self, inputs):
    output = K.dot(inputs, self.kernel)
    if self.use_bias:
        # backend operation, see https://keras-cn.readthedocs.io/en/latest/backend/
        output = K.bias_add(output, self.bias, data_format='channels_last')
    if self.activation is not None:
        output = self.activation(output)
    return output

def compute_output_shape(self, input_shape):
    assert input_shape and len(input_shape) >= 2
    assert input_shape[-1]
    output_shape = list(input_shape)
    output_shape[-1] = self.units  # -1 is the last element; the output dimension equals the number of units
    return tuple(output_shape)
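
As a quick sanity check of this shape logic (the sizes below are arbitrary), only the last dimension changes, to units:

from keras.layers import Input, Dense
from keras.models import Model

inp = Input(shape=(32,))
out = Dense(units=10)(inp)
model = Model(inp, out)
print(model.output_shape)  # (None, 10)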


Slicing data with a custom Keras Lambda layer, and passing arguments to Lambda:

import numpy as np
from keras.layers import Input, Lambda, Reshape, Concatenate
from keras.models import Model
from keras.utils import plot_model

def slice(x, index):
    # Take one channel from the last axis; the result has shape (batch, 4)
    return x[:, :, index]

a = Input(shape=(4, 2))
x1 = Lambda(slice, output_shape=(4,), arguments={'index': 0})(a)
x2 = Lambda(slice, output_shape=(4,), arguments={'index': 1})(a)
x1 = Reshape((4, 1, 1))(x1)
x2 = Reshape((4, 1, 1))(x2)
output = Concatenate()([x1, x2])
model = Model(a, output)
x_test = np.array([[[1, 2], [2, 3], [3, 4], [4, 5]]])
print(model.predict(x_test))
plot_model(model, to_file='lambda.png', show_shapes=True)

 

 

In the printed shapes, None stands for the number of samples (the batch dimension).
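
For instance, for the slicing model built above, the static output shape is:

print(model.output_shape)  # (None, 4, 1, 2) -- None is the samples (batch) dimension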

Example 1: Cropping2D layer (this example follows the Keras 1.x API, which uses get_output_shape_for and dim_ordering rather than compute_output_shape and data_format)

from keras import backend as K
from keras.engine.topology import Layer, InputSpec

class Cropping2D(Layer):

    def __init__(self, cropping=((0, 0), (0, 0)), dim_ordering='default', **kwargs):
        super(Cropping2D, self).__init__(**kwargs)
        if dim_ordering == 'default':
            dim_ordering = K.image_dim_ordering()
        self.cropping = tuple(cropping)
        assert len(self.cropping) == 2, 'cropping must be a tuple length of 2'
        assert len(self.cropping[0]) == 2, 'cropping[0] must be a tuple length of 2'
        assert len(self.cropping[1]) == 2, 'cropping[1] must be a tuple length of 2'
        assert dim_ordering in {'tf', 'th'}, 'dim_ordering must be in {tf, th}'
        self.dim_ordering = dim_ordering
        self.input_spec = [InputSpec(ndim=4)]

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]

    def get_output_shape_for(self, input_shape):
        if self.dim_ordering == 'th':
            return (input_shape[0],
                    input_shape[1],
                    input_shape[2] - self.cropping[0][0] - self.cropping[0][1],
                    input_shape[3] - self.cropping[1][0] - self.cropping[1][1])
        elif self.dim_ordering == 'tf':
            return (input_shape[0],
                    input_shape[1] - self.cropping[0][0] - self.cropping[0][1],
                    input_shape[2] - self.cropping[1][0] - self.cropping[1][1],
                    input_shape[3])
        else:
            raise Exception('Invalid dim_ordering: ' + self.dim_ordering)

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        if self.dim_ordering == 'th':
            return x[:,
                     :,
                     self.cropping[0][0]:input_shape[2]-self.cropping[0][1],
                     self.cropping[1][0]:input_shape[3]-self.cropping[1][1]]
        elif self.dim_ordering == 'tf':
            return x[:,
                     self.cropping[0][0]:input_shape[1]-self.cropping[0][1],
                     self.cropping[1][0]:input_shape[2]-self.cropping[1][1],
                     :]

    def get_config(self):
        config = {'cropping': self.cropping}
        base_config = super(Cropping2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
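
A usage sketch with the equivalent built-in Cropping2D from keras.layers (the image size and crop amounts below are arbitrary):

from keras.layers import Input, Cropping2D
from keras.models import Model

inp = Input(shape=(28, 28, 3))                    # 'tf' / channels_last ordering
out = Cropping2D(cropping=((1, 1), (2, 2)))(inp)  # crop 1 row top/bottom, 2 columns left/right
model = Model(inp, out)
print(model.output_shape)  # (None, 26, 24, 3)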

 

Tips

Remember: the layer's operation must be differentiable with respect to its input and to the trainable weights you define. Build it from Keras backend (K) operations, which are differentiable; look them up in the backend documentation.

tensor.shape.eval() returns the concrete shape as integers (Theano backend); you will probably need to print tensor shapes a lot while debugging.
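
With the TensorFlow backend, K.int_shape gives the static shape as a plain tuple, which is handy to print while debugging (a minimal sketch; the layer body is illustrative):

from keras import backend as K

def call(self, x):
    print(K.int_shape(x))        # e.g. (None, 16): static shape, None for the batch dimension
    return K.dot(x, self.kernel)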

 

References:

https://blog.csdn.net/lujiandong1/article/details/54936185

https://keras.io/zh/layers/writing-your-own-keras-layers/

https://keunwoochoi.wordpress.com/2016/11/18/for-beginners-writing-a-custom-keras-layer/