Как реализовать суммирование на керасах? - PullRequest
0 голосов
/ 27 апреля 2020

Я довольно новичок в керасе, но я хочу сравнить суммирование пула со средней производительностью пула. Я попытался написать простой слой суммирования, как этот:

class SumPool2d():
    def __init__(self,k=2):
        super(SumPool2d, self).__init__()
        self.pool = AveragePooling2D(pool_size=(k,k))
        self.kernel_size = k*k

    def forward(self, x):
        return self.kernel_size*self.pool(x)

Вот моя сетевая архитектура:

model = Sequential()

# first layer
model.add(Conv2D(batch_size, (kernel_size, kernel_size), activation='relu', padding='same', input_shape=(img_w, img_h, 3)))
model.add(MaxPooling2D((max_pool, max_pool))) 
# second layer
model.add(Conv2D(2*batch_size, (kernel_size, kernel_size), activation='relu', padding='same'))
model.add(MaxPooling2D((max_pool, max_pool)))
# third layer
model.add(Conv2D(4*batch_size, (kernel_size, kernel_size), activation='relu', padding='same')) 
model.add(MaxPooling2D((max_pool, max_pool))) 
# fourth layer
model.add(Conv2D(4*batch_size, (kernel_size, kernel_size), activation='relu', padding='same')) 
model.add(SumPool2d())
# fifth layer
model.add(Conv2D(4*batch_size, (kernel_size, kernel_size), activation='relu', padding='same'))
model.add(MaxPooling2D((max_pool, max_pool)))     
model.add(Flatten())
model.add(Dense(8*batch_size, activation='relu'))
model.add(Dense(8*batch_size, activation='relu'))
model.add(Dropout(0.2))     
model.add(Dense(8*batch_size, activation='relu'))                                           
model.add(Dense(gems_count, activation='softmax'))

, и я получаю эту ошибку:

Traceback (most recent call last):
  File "Classification_Train.py", line 142, in <module>
    model.add(SumPool2d()) #reduce the spatial size of incoming features
  File "/home/scotch/anaconda3/envs/tf36/lib/python3.7/site-packages/keras/engine/sequential.py", line 133, in add
    'Found: ' + str(layer))
TypeError: The added layer must be an instance of class Layer. Found: <__main__.SumPool2d object at 0x7fcf1e282d50>

Любая помощь будет полезна. Благодаря.

...