Я пытался заставить пользовательскую операцию работать с версией TensorFlow для Windows.Я следовал руководству:
https://www.tensorflow.org/guide/extend/op#build_a_pip_package_for_your_custom_op
Мне удалось построить правильно, но, кажется, моя пользовательская операция не распознается, когда я пытаюсь выполнить тест на нем.
Кажется, он распознает встроенную пользовательскую операцию tf.user_ops.my_fact
, но не распознает созданный мной user_ops
: squared_out
.
Вот что я делал шаг за шагом:
- Клонирование хранилища:
git clone https://github.com/tensorflow/tensorflow.git
Я поместил мою реализацию C ++
squared_out.cc
в tenorflow \ tenorsflow \ core \ user_ops
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("SquaredOut")
.Input("to_square: int32")
.Output("squared: int32")
.SetShapeFn([](::tensorflow::shape_inference::InferenceContext *c) {
c->set_output(0, c->input(0));
return Status::OK();
});
class SquaredOutOp : public OpKernel {
public:
explicit SquaredOutOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
const Tensor& input_tensor = context->input(0);
auto input = input_tensor.flat<int32>();
Tensor* output_tensor = NULL;
OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor));
auto output_flat = output_tensor->flat<int32>();
const int N = input.size();
for(int i = 0; i < N; ++i){
output_flat(i) = input(i) * input(i);
}
}
};
REGISTER_KERNEL_BUILDER(Name("SquaredOut").Device(DEVICE_CPU), SquaredOutOp);
После этого я настраиваю и собираю библиотеку tenorflow с помощью bazel:
python ./configure.py
bazel build --config=opt //tensorflow/tools/pip_package:build_pip_package
bazel-bin\tensorflow\tools\pip_package\build_pip_package C:/tmp/tensorflow_pkg
Затем я установил .whl, созданный bazel:
pip install C:/tmp/tensorflow_pkg/tensorflow-version-cp35-cp35m-win_amd64.whl
Наконец, я попытался протестировать недавно импортированный
squared_out
из
user_ops
:
import tensorflow as tf
#This works fine with build in fact function
tf.user_ops.fact
Это вернуло:
<function my_fact at 0x0000020CE5601048>
Однако моя пользовательская операция нене работает:
tf.user_ops.squared_out
AttributeError: module 'tensorflow._api.v1.user_ops' has no attribute 'squared_out'