У меня есть модель tf-lite, которая принимает матрицу чисел с плавающей точкой (1, 96, 64, 1). Я хочу загрузить это и сделать прогноз. Я сделал это легко в python, но нужно перенести это на swift. Кто-нибудь знает как это сделать?
Вот мой код python, если он полезен. Спасибо !!
def get_model():
# get model (get_model)
model_path = "./model_vggish_cv2.1757-0.21.tflite"
model = tf.lite.Interpreter(model_path=model_path)
# setup model_idx
model_idx = [0, 0]
model_idx[0] = model.get_input_details()[0]['index']
model_idx[1] = model.get_output_details()[0]['index']
# input size
input_shape = model.get_input_details()[0]['shape']
# adjust model
model.resize_tensor_input(
model_idx[0], [1, input_shape[1], input_shape[2], input_shape[3]])
model.allocate_tensors()
return model, model_idx
model_features = model_features.reshape(
(1, 96, 64, 1)).astype(np.float32)
model, model_idx = get_model()
model.set_tensor(model_idx[0], model_features)
model.invoke()
confidence = model.get_tensor(model_idx[1]).squeeze()