Как оказалось, здесь нет ничего принципиально нового, если вы уже делали это на Python. Предположим, что модель называется «predict», а вход модели — «inputs»; тогда код на Python выглядит следующим образом:
import logging
import grpc
from grpc import RpcError
from types_pb2 import DT_FLOAT
from tensor_pb2 import TensorProto
from tensor_shape_pb2 import TensorShapeProto
from predict_pb2 import PredictRequest
from prediction_service_pb2_grpc import PredictionServiceStub
class ModelClient:
    """Small facade around the TensorFlow Serving gRPC prediction API.

    An instance remembers the model name, the input tensor dimensions, the
    tensor dtype and the model version.  The gRPC channel and stub are
    created by :meth:`connect`; :meth:`predict` calls it with the default
    arguments when no connection has been set up yet.
    """

    # Connection state; populated by connect().
    host = None
    port = None
    chan = None
    stub = None
    logger = logging.getLogger(__name__)

    def __init__(self, name, dims, dtype=DT_FLOAT, version=1):
        """Store the model name, input dimensions, dtype and version.

        Each entry of ``dims`` is wrapped in a ``TensorShapeProto.Dim``.
        """
        self.model = name
        self.dims = [TensorShapeProto.Dim(size=extent) for extent in dims]
        self.dtype = dtype
        self.version = version

    @property
    def hostport(self):
        """Return the current target formatted as ``host:port``."""
        return "{}:{}".format(self.host, self.port)

    def connect(self, host='localhost', port=8500):
        """Open an insecure gRPC channel and create the prediction stub.

        ``port`` is converted with ``int()`` before it is stored.
        """
        self.host = host
        self.port = int(port)

        target = self.hostport
        self.logger.info(f"Connecting to {target}...")
        self.chan = grpc.insecure_channel(target)

        self.logger.info("Initializing prediction gRPC stub.")
        self.stub = PredictionServiceStub(self.chan)

    def tensor_proto_from_measurement(self, measurement):
        """Wrap ``measurement`` in a ``TensorProto`` and return it.

        The proto uses the dtype and dimensions given to the constructor.
        """
        self.logger.info("Assembling measurement tensor.")
        shape = TensorShapeProto(dim=self.dims)
        # NOTE(review): the raw bytes go into string_val even though dtype
        # defaults to DT_FLOAT -- confirm the server decodes this as intended.
        payload = bytes(measurement)
        return TensorProto(
            dtype=self.dtype,
            tensor_shape=shape,
            string_val=[payload],
        )

    def predict(self, measurement, timeout=10):
        """Send ``measurement`` to the Predict RPC and return the response.

        Connects with the default host and port first if any piece of the
        connection state is still unset.  The model version is only put on
        the request when it is greater than zero.

        Returns ``None`` (after logging the error) when the call raises
        ``RpcError``.
        """
        connection_state = (self.host, self.port, self.chan, self.stub)
        if any(part is None for part in connection_state):
            self.connect()

        self.logger.info("Creating request.")
        request = PredictRequest()
        spec = request.model_spec
        spec.name = self.model
        if self.version > 0:
            spec.version.value = self.version

        tensor = self.tensor_proto_from_measurement(measurement)
        request.inputs['inputs'].CopyFrom(tensor)

        self.logger.info("Attempting to predict against TF Serving API.")
        try:
            response = self.stub.Predict(request, timeout=timeout)
        except RpcError as err:
            self.logger.error(err)
            self.logger.error('Predict failed.')
            return None
        return response
Ниже приведён рабочий (черновой) перевод на C++:
#include <iostream>
#include <memory>
#include <string>
#include <grpcpp/grpcpp.h>
#include "grpcpp/create_channel.h"
#include "grpcpp/security/credentials.h"
#include "google/protobuf/map.h"
#include "types.grpc.pb.h"
#include "tensor.grpc.pb.h"
#include "tensor_shape.grpc.pb.h"
#include "predict.grpc.pb.h"
#include "prediction_service.grpc.pb.h"
using grpc::Channel;
using grpc::ClientContext;
using grpc::Status;
using tensorflow::TensorProto;
using tensorflow::TensorShapeProto;
using tensorflow::serving::PredictRequest;
using tensorflow::serving::PredictResponse;
using tensorflow::serving::PredictionService;
typedef google::protobuf::Map<std::string, tensorflow::TensorProto> OutMap;
class ServingClient {
public:
ServingClient(std::shared_ptr<Channel> channel)
: stub_(PredictionService::NewStub(channel)) {}
// Assembles the client's payload, sends it and presents the response back
// from the server.
std::string callPredict(const std::string& model_name,
const float& measurement) {
// Data we are sending to the server.
PredictRequest request;
request.mutable_model_spec()->set_name(model_name);
// Container for the data we expect from the server.
PredictResponse response;
// Context for the client. It could be used to convey extra information to
// the server and/or tweak certain RPC behaviors.
ClientContext context;
google::protobuf::Map<std::string, tensorflow::TensorProto>& inputs =
*request.mutable_inputs();
tensorflow::TensorProto proto;
proto.set_dtype(tensorflow::DataType::DT_FLOAT);
proto.add_float_val(measurement);
proto.mutable_tensor_shape()->add_dim()->set_size(5);
proto.mutable_tensor_shape()->add_dim()->set_size(8);
proto.mutable_tensor_shape()->add_dim()->set_size(105);
inputs["inputs"] = proto;
// The actual RPC.
Status status = stub_->Predict(&context, request, &response);
// Act upon its status.
if (status.ok()) {
std::cout << "call predict ok" << std::endl;
std::cout << "outputs size is " << response.outputs_size() << std::endl;
OutMap& map_outputs = *response.mutable_outputs();
OutMap::iterator iter;
int output_index = 0;
for (iter = map_outputs.begin(); iter != map_outputs.end(); ++iter) {
tensorflow::TensorProto& result_tensor_proto = iter->second;
std::string section = iter->first;
std::cout << std::endl << section << ":" << std::endl;
if ("classes" == section) {
int titer;
for (titer = 0; titer != result_tensor_proto.int64_val_size(); ++titer) {
std::cout << result_tensor_proto.int64_val(titer) << ", ";
}
} else if ("scores" == section) {
int titer;
for (titer = 0; titer != result_tensor_proto.float_val_size(); ++titer) {
std::cout << result_tensor_proto.float_val(titer) << ", ";
}
}
std::cout << std::endl;
++output_index;
}
return "Done.";
} else {
std::cout << "gRPC call return code: " << status.error_code() << ": "
<< status.error_message() << std::endl;
return "RPC failed";
}
}
private:
std::unique_ptr<PredictionService::Stub> stub_;
};
Обратите внимание, что размерности тензора здесь жёстко заданы в коде, а не передаются в качестве параметров.
С приведённым выше классом основная программа может выглядеть следующим образом:
int main(int argc, char** argv) {
float measurement[5*8*105] = { ... data ... };
ServingClient sclient(grpc::CreateChannel(
"localhost:8500", grpc::InsecureChannelCredentials()));
std::string model("predict");
std::string reply = sclient.callPredict(model, *measurement);
std::cout << "Predict received: " << reply << std::endl;
return 0;
}
Используемый Makefile заимствован из примеров gRPC для C++; переменная PROTOS_PATH задана относительно расположения Makefile, а цель сборки выглядит следующим образом (предполагается, что файл с приложением на C++ называется predict.cc):
# Link the predict binary from the generated protobuf/gRPC objects and the
# application object. The recipe line must start with a tab character.
predict: types.pb.o types.grpc.pb.o \
         tensor_shape.pb.o tensor_shape.grpc.pb.o \
         resource_handle.pb.o resource_handle.grpc.pb.o \
         model.pb.o model.grpc.pb.o \
         tensor.pb.o tensor.grpc.pb.o \
         predict.pb.o predict.grpc.pb.o \
         prediction_service.pb.o prediction_service.grpc.pb.o \
         predict.o
	$(CXX) $^ $(LDFLAGS) -o $@