To customize a Keras layer or override an existing one, you need to implement three methods:

- `build(input_shape)`: this is where the weights are defined, i.e. which parameters participate in training. The layer is marked as built via `self.built = True`, which is usually done by calling the parent class's `build` through `super(MyLayer, self).build(input_shape)`.
- `call(x)`: this is where the layer's logic is written; it is the computation that maps the input tensor to the output tensor. You usually only need to care about the first argument (the input tensor), unless you want your layer to support masking.
- `compute_output_shape(input_shape)`: if your layer changes the shape of its input tensor, declare the output shape here, so that Keras can automatically infer the shape of every layer.
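As a minimal sketch of the three methods working together (my own toy example, not from the referenced post; the class name `FeatureScale` is invented):

```python
from keras.layers import Layer


class FeatureScale(Layer):
    """Toy layer: multiplies each input feature by a learned weight."""
    def build(self, input_shape):
        # input_shape arrives as e.g. (batch_size, features);
        # only the last dimension is needed for the weight shape
        self.kernel = self.add_weight(
            name='kernel', shape=(input_shape[-1],), initializer='ones'
        )
        super(FeatureScale, self).build(input_shape)  # sets self.built = True

    def call(self, x):
        # input tensor -> output tensor; the batch dimension is untouched
        return x * self.kernel

    def compute_output_shape(self, input_shape):
        # the shape is unchanged, so return it as-is
        return input_shape
```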
Common questions:

- 1. On first contact with custom layers, the `input_shape` argument of `build` tends to raise questions: where does it come from? In practice, we specify the input dimensions at the input layer, every layer reports its own output shape, and Keras infers the rest automatically from the computation graph.
- 2. Do we need to account for the batch size when overriding a layer? A Keras layer is a tensor-to-tensor mapping in which the batch dimension is kept unchanged by default, so when changing dimensions with `Reshape` we do not pass a batch_size dimension either.
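Both points can be checked in a few lines (a sketch; the layer sizes are arbitrary):

```python
from keras.layers import Input, Dense, Reshape
from keras.models import Model

# The Input layer fixes the feature dimensions; the batch size stays None.
inputs = Input(shape=(8,))   # reported shape: (None, 8)
x = Dense(6)(inputs)         # Keras infers (None, 6) from the graph
x = Reshape((2, 3))(x)       # no batch dimension passed to Reshape
Model(inputs, x).summary()   # shapes: (None, 8) -> (None, 6) -> (None, 2, 3)
```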
Reference: Keras custom layers.
Finally, here is a Conditional Layer Normalization example, from 《基于Conditional Layer Normalization的条件文本生成》 (conditional text generation based on Conditional Layer Normalization).
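For orientation (my summary, not text from the original post): plain Layer Normalization normalizes each sample over its feature axis,

$$
\mu = \frac{1}{d}\sum_{i=1}^{d} x_i,\qquad
\sigma^2 = \frac{1}{d}\sum_{i=1}^{d} (x_i - \mu)^2,\qquad
y = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta,
$$

while the conditional variant makes the gain and bias depend on a condition vector $c$,

$$
\gamma = \gamma_0 + W_\gamma c,\qquad \beta = \beta_0 + W_\beta c,
$$

with $W_\gamma$ and $W_\beta$ initialized to zero, so the layer starts out as ordinary Layer Normalization (this matches the `kernel_initializer='zeros'` in the code below).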
```python
from keras import activations, initializers
from keras import backend as K
from keras.layers import Dense, Layer


# A custom layer implements the three methods discussed above
class LayerNormalization(Layer):
    """(Conditional) Layer Normalization
    The hidden_* arguments are only used with conditional input (conditional=True):
      hidden_units      dimension to reduce the condition to, for when the input
                        condition matrix is too large (reduce first, then transform)
      hidden_activation usually a linear activation
    """
    def __init__(
        self,
        center=True,
        scale=True,
        epsilon=None,
        conditional=False,
        hidden_units=None,
        hidden_activation='linear',
        hidden_initializer='glorot_uniform',
        **kwargs
    ):
        super(LayerNormalization, self).__init__(**kwargs)
        self.center = center
        self.scale = scale
        self.conditional = conditional
        self.hidden_units = hidden_units
        self.hidden_activation = activations.get(hidden_activation)
        self.hidden_initializer = initializers.get(hidden_initializer)
        self.epsilon = epsilon or 1e-12

    def build(self, input_shape):
        super(LayerNormalization, self).build(input_shape)  # sets self.built = True

        if self.conditional:
            shape = (input_shape[0][-1],)
        else:
            shape = (input_shape[-1],)

        if self.center:
            self.beta = self.add_weight(
                shape=shape, initializer='zeros', name='beta'
            )
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape, initializer='ones', name='gamma'
            )

        if self.conditional:
            if self.hidden_units is not None:
                # used for dimensionality reduction of the condition
                self.hidden_dense = Dense(
                    units=self.hidden_units,
                    activation=self.hidden_activation,
                    use_bias=False,
                    kernel_initializer=self.hidden_initializer
                )
            if self.center:
                self.beta_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )
            if self.scale:
                self.gamma_dense = Dense(
                    units=shape[0], use_bias=False, kernel_initializer='zeros'
                )

    def call(self, inputs):
        """For conditional Layer Norm, the input is a list by convention;
        the second element is the condition.
        """
        if self.conditional:
            inputs, cond = inputs
            # reduce the condition's dimensionality first
            if self.hidden_units is not None:
                cond = self.hidden_dense(cond)
            # add axes so cond has the same number of dimensions as inputs
            for _ in range(K.ndim(inputs) - K.ndim(cond)):
                cond = K.expand_dims(cond, 1)
            if self.center:
                beta = self.beta_dense(cond) + self.beta
            if self.scale:
                gamma = self.gamma_dense(cond) + self.gamma
        else:
            if self.center:
                beta = self.beta
            if self.scale:
                gamma = self.gamma

        outputs = inputs
        if self.center:
            # layer normalization: statistics are taken over the last
            # (feature) axis of each sample, not over the batch
            mean = K.mean(outputs, axis=-1, keepdims=True)
            outputs = outputs - mean
        if self.scale:
            variance = K.mean(K.square(outputs), axis=-1, keepdims=True)
            std = K.sqrt(variance + self.epsilon)
            outputs = outputs / std
            outputs = outputs * gamma
        if self.center:
            outputs = outputs + beta

        return outputs

    # input_shape is a list in the conditional case; defines the output shape
    def compute_output_shape(self, input_shape):
        if self.conditional:
            return input_shape[0]
        else:
            return input_shape

    # merge the config of this class with the parent's config
    def get_config(self):
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'conditional': self.conditional,
            'hidden_units': self.hidden_units,
            'hidden_activation': activations.serialize(self.hidden_activation),
            'hidden_initializer':
                initializers.serialize(self.hidden_initializer),
        }
        base_config = super(LayerNormalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
```
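A usage sketch under assumed shapes (the tensor sizes and the `hidden_units=64` choice are mine; conditional use takes a `[inputs, condition]` list, as the `call` docstring says):

```python
from keras.layers import Input
from keras.models import Model

# Unconditional: behaves like ordinary Layer Normalization
x_in = Input(shape=(16, 128))       # e.g. (sequence_length, hidden_size)
x_out = LayerNormalization()(x_in)

# Conditional: the second input (e.g. a class embedding) shifts gamma
# and beta; hidden_units=64 projects the 256-dim condition down first
c_in = Input(shape=(256,))
y_out = LayerNormalization(conditional=True, hidden_units=64)([x_in, c_in])

model = Model([x_in, c_in], y_out)
model.summary()                     # output shape: (None, 16, 128)
```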