"""ResNet model definition for qualia_core.learningmodel.keras (Keras/TensorFlow backend)."""


class ResBlock:
    # Not an actual Keras model because it is harder to serialize and get it to work on third-party tools.
    # Instead it is a plain callable that builds its layers in __init__ and wires them in __call__,
    # so it composes with the Keras functional API.

    # Channel expansion factor of the block (1 for a basic, non-bottleneck ResNet block).
    expansion = 1

    def __init__(self,
                 layers_t,  # module providing dimension-specific layers (Conv, ZeroPadding, MaxPooling)
                 in_filters,
                 filters,
                 kernel_size,
                 stride,
                 padding,
                 batch_norm: bool=False,
                 bn_momentum: float=0.99,
                 force_projection_with_stride: bool=True):
        """Build the layers of one residual block.

        :param layers_t: layer-factory module (1D or 2D variants) used for Conv/ZeroPadding/MaxPooling
        :param in_filters: number of channels of the block's input tensor
        :param filters: number of output channels of each convolution
        :param kernel_size: convolution kernel size
        :param stride: downsampling factor; implemented with MaxPooling instead of strided conv
        :param padding: zero-padding applied before each convolution
        :param batch_norm: insert BatchNormalization after each convolution (convs then drop their bias)
        :param bn_momentum: momentum for the BatchNormalization moving statistics
        :param force_projection_with_stride: also add a 1x1 projection on the shortcut when stride != 1,
            even if input and output channel counts already match
        """
        super().__init__()
        from tensorflow.keras.layers import BatchNormalization, Activation, Add
        # NOTE(review): unique_object_name is a private TensorFlow API (tensorflow.python.*);
        # it is used here to give deterministic unique names to selected conv layers.
        from tensorflow.python.keras.backend import unique_object_name

        self.in_filters = in_filters
        self.filters = filters
        self.stride = stride
        self.batch_norm = batch_norm
        self.force_projection_with_stride = force_projection_with_stride

        # Main path, first convolution: pad -> conv -> [batchnorm] -> [maxpool] -> relu.
        # Convolutions always use strides=1; downsampling is done by MaxPooling(stride).
        self.padding_1 = layers_t.ZeroPadding(padding=padding)
        self.conv_1 = layers_t.Conv(filters=filters,
                                    kernel_size=kernel_size,
                                    strides=1,
                                    use_bias=not batch_norm)  # bias is redundant when followed by BN
        if batch_norm:
            self.batchnorm_1 = BatchNormalization(momentum=bn_momentum)
        if self.stride != 1:
            self.maxpool_1 = layers_t.MaxPooling(stride)
        self.relu_1 = Activation('relu')

        # Main path, second convolution: pad -> conv -> [batchnorm].
        self.padding_2 = layers_t.ZeroPadding(padding=padding)
        self.conv_2 = layers_t.Conv(name=unique_object_name('conv_ref', zero_based=True),
                                    filters=filters,
                                    kernel_size=kernel_size,
                                    strides=1,
                                    use_bias=not batch_norm)
        if batch_norm:
            self.batchnorm_2 = BatchNormalization(momentum=bn_momentum)

        # Shortcut path: downsample with MaxPooling when stride != 1; add a 1x1 projection conv
        # when channel counts differ, or when stride != 1 and force_projection_with_stride is set.
        # (Precedence: `a or (b and c)`.)
        if self.stride != 1:
            self.maxpool_shortcut = layers_t.MaxPooling(stride)
        if self.in_filters != self.expansion*self.filters or force_projection_with_stride and self.stride != 1:
            self.conv_shortcut = layers_t.Conv(name=unique_object_name('conv_shortcut', zero_based=True),
                                               filters=self.expansion * filters,
                                               kernel_size=1,
                                               strides=1,
                                               use_bias=not batch_norm)
            if batch_norm:
                self.batchnorm_shortcut = BatchNormalization(momentum=bn_momentum)

        self.add = Add()
        self.relu_2 = Activation('relu')

    def __call__(self, inputs):
        """Apply the residual block to `inputs` and return the output tensor.

        Main path: pad -> conv -> [bn] -> [pool] -> relu -> pad -> conv -> [bn].
        Shortcut:  [1x1 conv -> [bn]] -> [pool].
        The two paths are summed and passed through a final ReLU.
        """
        x = self.padding_1(inputs)
        x = self.conv_1(x)
        if self.batch_norm:
            x = self.batchnorm_1(x)
        if self.stride != 1:
            x = self.maxpool_1(x)
        x = self.relu_1(x)

        x = self.padding_2(x)
        x = self.conv_2(x)
        if self.batch_norm:
            x = self.batchnorm_2(x)

        # shortcut
        tmp = inputs
        # Must mirror the projection condition used in __init__ (there with the local parameter,
        # here with the stored attribute).
        if self.in_filters != self.expansion*self.filters or self.force_projection_with_stride and self.stride != 1:
            tmp = self.conv_shortcut(tmp)
            if self.batch_norm:
                tmp = self.batchnorm_shortcut(tmp)
        # Pooling applies whenever stride != 1, with or without projection
        # (matches the creation condition of maxpool_shortcut in __init__).
        if self.stride != 1:
            tmp = self.maxpool_shortcut(tmp)

        x = self.add([x, tmp])
        x = self.relu_2(x)
        return x

    @property
    def output_shape(self):
        # Shape of the block's output, delegated to the last layer of the block.
        return self.relu_2.output_shape
def ResNet(input_shape,
           output_shape,
           filters: tuple=(15, 15, 30, 60, 120),  # + (120,)
           kernel_sizes: tuple=(3, 3, 3, 3, 3),
           num_blocks: tuple=(2, 2, 2, 2),  # + (2,)
           strides: tuple=(1, 1, 2, 2, 2),  # + (2,)
           paddings: tuple=(1, 1, 1, 1, 1),
           prepool: int=1,
           batch_norm: bool=False,
           bn_momentum: float=0.99,
           force_projection_with_stride: bool=True,
           dims: int=1):
    """Build a ResNet as a Keras functional Model.

    Architecture: [AveragePooling] -> Conv -> [BatchNorm] -> ReLU
                  -> ResBlock stages -> MaxPooling (global) -> Flatten
                  -> Dense(output) -> Softmax.

    :param input_shape: shape of the input tensor (without the batch dimension)
    :param output_shape: sequence whose first element is the number of output classes
    :param filters: channel count for the first conv (filters[0]) and each residual stage (filters[1:])
    :param kernel_sizes: kernel size for the first conv and each residual stage
    :param num_blocks: number of ResBlocks per residual stage (one entry per stage)
    :param strides: stride of the first conv and of the first block of each stage
    :param paddings: zero-padding for the first conv and each residual stage
    :param prepool: average-pooling factor applied to the input when > 1
    :param batch_norm: insert BatchNormalization after convolutions
    :param bn_momentum: momentum for BatchNormalization moving statistics
    :param force_projection_with_stride: forwarded to ResBlock; force a 1x1 shortcut
        projection whenever a block has stride != 1
    :param dims: 1 for 1D (time series) layers, 2 for 2D (image) layers
    :raises ValueError: if dims is neither 1 nor 2
    :return: a compiled-ready tf.keras Model named 'ResNet'
    """
    from tensorflow.keras import Model
    from tensorflow.keras.layers import Input, Flatten, Dense, BatchNormalization, Activation

    # Select the dimension-specific layer factories.
    if dims == 1:
        import qualia_core.learningmodel.keras.layers1d as layers_t
    elif dims == 2:
        import qualia_core.learningmodel.keras.layers2d as layers_t
    else:
        raise ValueError('Only dims=1 or dims=2 supported')

    # Actual model definition
    # Conv → BatchNorm → ResBlock → … → ResBlock → MaxPooling → Flatten → FullyConnected(out)
    inputs = Input(shape=input_shape)

    # Optional input downsampling before any convolution.
    if prepool > 1:
        x = layers_t.AveragePooling(pool_size=prepool)(inputs)
    else:
        x = inputs

    # First non-residual convolution
    x = layers_t.ZeroPadding(padding=paddings[0])(x)
    x = layers_t.Conv(filters=filters[0],
                      kernel_size=kernel_sizes[0],
                      strides=strides[0],
                      use_bias=not batch_norm)(x)  # bias redundant when followed by BN
    if batch_norm:
        x = BatchNormalization(momentum=bn_momentum)(x)
    x = Activation('relu')(x)

    # Residual stages: the first block of each stage carries the stage stride,
    # the remaining blocks use stride 1.
    in_filters = filters[0]
    for f, kernel_size, layer_stride, padding, stage_num_blocks in zip(
            filters[1:], kernel_sizes[1:], strides[1:], paddings[1:], num_blocks):
        for block_stride in [layer_stride] + [1] * (stage_num_blocks - 1):
            block = ResBlock(layers_t=layers_t,
                             in_filters=in_filters,
                             filters=f,
                             kernel_size=kernel_size,
                             stride=block_stride,
                             padding=padding,
                             batch_norm=batch_norm,
                             bn_momentum=bn_momentum,
                             # Bug fix: previously not forwarded, so the caller's value was ignored
                             # and ResBlock's default (True) was always used.
                             force_projection_with_stride=force_projection_with_stride)
            x = block(x)
            in_filters = f * block.expansion

    # Reduce last conv layer output feature maps with a global (full-width) pooling.
    x = layers_t.MaxPooling(x.shape[1])(x)
    x = Flatten()(x)
    x = Dense(output_shape[0], use_bias=True)(x)
    outputs = Activation('softmax')(x)

    return Model(inputs=inputs, outputs=outputs, name='ResNet')