Я пытаюсь создать ResNet34 Encoder как часть моего CNN со следующей функцией на Python 3.7.
import tensorflow as tf
from tensorpack import *
from tensorpack.models import BatchNorm, BNReLU, Conv2D, MaxPooling, FixedUnPooling
from tensorpack.tfutils.summary import add_moving_summary, add_param_summary
from .utils import *
import sys
sys.path.append("..") # adds higher directory to python modules path.
try: # HACK: import beyond current level, may need to restructure
from config import Config
except ImportError:
assert False, 'Fail to import config.py'
def res_blk(name, l, ch, ksize, count, split=1, strides=1, freeze=False):
ch_in = l.get_shape().as_list()
with tf.variable_scope(name):
for i in range(0, count):
with tf.variable_scope('block' + str(i)):
x = l if i == 0 else BNReLU('preact', l)
x = Conv2D('conv1', x, ch[0], ksize[0], activation=BNReLU)
x = Conv2D('conv2', x, ch[1], ksize[1], split=split,
strides=strides if i == 0 else 1, activation=BNReLU)
x = Conv2D('conv3', x, ch[2], ksize[2], activation=tf.identity)
if (strides != 1 or ch_in[1] != ch[2]) and i == 0:
l = Conv2D('convshortcut', l, ch[2], 1, strides=strides)
x = tf.stop_gradient(x) if freeze else x
l = l + x
l = BNReLU('bnlast',l)
return l
def encoder(i, freeze):
d1 = Conv2D('conv0', i, 64, 7, padding='valid', strides=1, activation=BNReLU)
d1 = res_blk('group0', d1, [ 64, 64], [3, 3], 3, strides=1, freeze=freeze)
d2 = res_blk('group1', d1, [128, 128], [3, 3], 4, strides=2, freeze=freeze)
d2 = tf.stop_gradient(d2) if freeze else d2
d3 = res_blk('group2', d2, [256, 256], [3, 3], 6, strides=2, freeze=freeze)
d3 = tf.stop_gradient(d3) if freeze else d3
d4 = res_blk('group3', d3, [512, 512], [3, 3], 3, strides=2, freeze=freeze)
d4 = tf.stop_gradient(d4) if freeze else d4
d4 = Conv2D('conv_bot', d4, 1024, 1, padding='same')
return [d1, d2, d3, d4]
Затем я получаю сообщение об ошибке
line 67, in encoder
d1 = res_blk('group0', d1, [ 64, 64], [3, 3], 3, strides=1, freeze=freeze)
File "....", line 34, in res_blk
x = Conv2D('conv3', x, ch[2], ksize[2], activation=tf.identity)
IndexError: list index out of range
В чем причина этой ошибки и как ее исправить? Исходным кодом был Resnet50, который работал нормально, т.е. код будет
d1 = res_blk('group0', d1, [ 64, 64, 256], [1, 3, 1], 3, strides=1, freeze=freeze)