tf.device () появилось исключение - PullRequest
0 голосов
/ 29 января 2019

Когда я использую tf.device () для назначения номера GPU, это оказалось исключением.Я впервые задаю вопрос в переполнении стека, если в нем есть какая-то ошибка, пожалуйста, прости меня и скажи мне.

Когда я добавляю allow_soft_placement = True в коде, это работает.

1 Ответ

0 голосов
/ 29 января 2019
def init_graph(self):
    """
    init bert graph
    """
    with self.tf_instance.device('device:GPU:{}'.format(str(self.gpu_no))):
        # add tokenizer
        from bert import tokenization
        self.tokenizer = tokenization.FullTokenizer(self.args.vocab_file)
        from bert import modeling
        bert_config = modeling.BertConfig.from_json_file(self.args.config_file)
        self.model = modeling.BertModel(config=bert_config,
                                        is_training=False,
                                        input_ids=self.input_ids,
                                        input_mask=self.input_mask,
                                        token_type_ids=self.input_type_ids,
                                        use_one_hot_embeddings=False)

        # get output weights and output bias
        reader = self.tf_instance.train.NewCheckpointReader(self.args.ckpt_file)
        output_weights = reader.get_tensor('output_weights')
        output_bias = reader.get_tensor('output_bias')

        # get result op
        output_layer = self.model.get_pooled_output()
        logits = self.tf_instance.matmul(output_layer, output_weights, transpose_b=True)
        logits = self.tf_instance.nn.bias_add(logits, output_bias)
        self.probabilities = self.tf_instance.nn.softmax(logits, axis=-1)

        sess_config = self.tf_instance.ConfigProto()
        sess_config.gpu_options.allow_growth = True

        graph = self.probabilities.graph
        saver = self.tf_instance.train.Saver()
        self.sess = self.tf_instance.Session(config=sess_config, graph=graph)
        self.sess.run(self.tf_instance.global_variables_initializer())
        self.tf_instance.reset_default_graph()
        saver.restore(self.sess, self.args.ckpt_file)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...