Модель Swift TensorFlow Lite не для классификации изображений - PullRequest
0 голосов
/ 08 апреля 2020

У меня есть модель tf-lite, которая принимает матрицу чисел с плавающей точкой (1, 96, 64, 1). Я хочу загрузить это и сделать прогноз. Я сделал это легко в python, но нужно перенести это на swift. Кто-нибудь знает как это сделать?

Вот мой код python, если он полезен. Спасибо !!

def get_model():
    # get model (get_model)
    model_path = "./model_vggish_cv2.1757-0.21.tflite"
    model = tf.lite.Interpreter(model_path=model_path)

    # setup model_idx
    model_idx = [0, 0]
    model_idx[0] = model.get_input_details()[0]['index']
    model_idx[1] = model.get_output_details()[0]['index']
    # input size
    input_shape = model.get_input_details()[0]['shape']
    # adjust model
    model.resize_tensor_input(
        model_idx[0], [1, input_shape[1], input_shape[2], input_shape[3]])
    model.allocate_tensors()
    return model, model_idx

model_features = model_features.reshape(
                (1, 96, 64, 1)).astype(np.float32)
model, model_idx = get_model()
model.set_tensor(model_idx[0], model_features)
model.invoke()
confidence = model.get_tensor(model_idx[1]).squeeze()
...