Я использую keras для построения сети, однако мне нужно объявить tf-тензор для выполнения некоторых вычислений cuda в методах set_abstraction_msg
и set_abstraction
, а затем повернуть тензор обратно к некоторым видам форм, которые керасируютможет соответствовать и соответствовать.Как я могу это сделать?
Приведенный ниже метод обычно называется getModel
, но вместо этого я назвал его pointnet2
.
Приведенный ниже код в основном предназначен для первого объявления заполнителя тензорного потока,второй - вычисления cuda и применение Conv2D
и BatchNormalization
в методах set_abstraction_msg
и set_abstraction
, третье - некоторые операции Dense
, BatchNormalization
и Dropout
.
def pointnet2(nb_classes):
input_points = tf.placeholder(tf.float32, shape=(16, 1024, 3))
model_input = Input(tensor=input_points)
sa1_xyz, sa1_points = set_abstraction_msg(model_input,
None,
512,
[0.1, 0.2, 0.4],
[16, 32, 128],
[[32, 32, 64], [64, 64, 128], [64, 96, 128]])
sa2_xyz, sa2_points = set_abstraction_msg(sa1_xyz,
sa1_points,
128,
[0.2, 0.4, 0.8],
[32, 64, 128],
[[64, 64, 128], [128, 128, 256], [128, 128, 256]])
sa3_xyz, sa3_points = set_abstraction(sa2_xyz,
sa2_points,
[256, 512, 1024])
# point_net_cls
c = Dense(512, activation='relu')(sa3_points)
c = BatchNormalization()(c)
c = Dropout(0.5)(c)
c = Dense(256, activation='relu')(c)
c = BatchNormalization()(c)
c = Dropout(0.5)(c)
c = Dense(nb_classes, activation='softmax')(c)
prediction = Flatten()(c)
model = Model(inputs=model_input, outputs=prediction)
# turn tf tensor to keras
return model
Iпопробовал Input(tensor=input_points)
.Оказалось, что prediction
- это тензор, который для меня префект.Но я хочу, чтобы тензор обратился к форме keras в конце, и приведенный выше код получил ошибку вроде этой в этой строке Model(inputs=model_input, outputs=prediction)
:
Output tensors to a Model must be the output of a TensorFlow Layer (thus holding past layer metadata). Found: Tensor("flatten/Reshape:0", shape=(16, 40), dtype=float32)
Для получения дополнительной информации полный проект кода находится здесь: https://github.com/HarborZeng/pointnet2-keras