BERT получить вложение уровня предложения после тонкой настройки - PullRequest
4 голосов
/ 20 марта 2020

Я наткнулся на эту страницу

1) Я хотел бы получить вложение уровня предложения (встраивание, данное токеном [CLS]) после выполнения тонкой настройки. Как я мог это сделать?

2) Я также заметил, что код на этой странице занимает много времени, чтобы вернуть результаты на тестовых данных. Это почему? Когда я обучал модель, это занимало меньше времени по сравнению с тем, когда я пытался получить тестовые прогнозы. Из кода на этой странице я не использовал нижеприведенные блоки кода

test_InputExamples = test.apply(lambda x: bert.run_classifier.InputExample(guid=None, 
                                                                       text_a = x[DATA_COLUMN], 
                                                                       text_b = None, 
                                                                       label = x[LABEL_COLUMN]), axis = 1

test_features = bert.run_classifier.convert_examples_to_features(test_InputExamples, label_list, MAX_SEQ_LENGTH, tokenizer)

test_input_fn = run_classifier.input_fn_builder(
        features=test_features,
        seq_length=MAX_SEQ_LENGTH,
        is_training=False,
        drop_remainder=False)

estimator.evaluate(input_fn=test_input_fn, steps=None)

Скорее я просто использовал функцию ниже для всех моих тестовых данных

def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences] # here, "" is just a dummy label
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], labels[prediction['labels']]) for sentence, prediction in zip(in_sentences, predictions)]

3) как я могу получить вероятность предсказания. Есть ли способ использовать keras predict метод?

update1

вопрос 2 обновление - не могли бы вы протестировать 20000 примеров обучения с использованием функции getPrediction? .... это занимает гораздо больше времени для меня ... даже больше, чем потрачено на обучение модели на 20000 примерах.

Ответы [ 2 ]

3 голосов
/ 26 марта 2020

Конечно, вот остальные изменения:

# model_fn_builder actually creates our model function
# using the passed parameters for num_labels, learning_rate, etc.
def model_fn_builder(num_labels, learning_rate, num_train_steps,
                     num_warmup_steps):
  """Returns `model_fn` closure for TPUEstimator."""
  def model_fn(features, labels, mode, params):  # pylint: disable=unused-argument
    """The `model_fn` for TPUEstimator."""

    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    is_predicting = (mode == tf.estimator.ModeKeys.PREDICT)

    # TRAIN and EVAL
    if not is_predicting:

      (loss, predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      train_op = bert.optimization.create_optimizer(
          loss, learning_rate, num_train_steps, num_warmup_steps, use_tpu=False)

      # Calculate evaluation metrics. 
      def metric_fn(label_ids, predicted_labels):
        accuracy = tf.metrics.accuracy(label_ids, predicted_labels)
        f1_score = tf.contrib.metrics.f1_score(
            label_ids,
            predicted_labels)
        auc = tf.metrics.auc(
            label_ids,
            predicted_labels)
        recall = tf.metrics.recall(
            label_ids,
            predicted_labels)
        precision = tf.metrics.precision(
            label_ids,
            predicted_labels) 
        true_pos = tf.metrics.true_positives(
            label_ids,
            predicted_labels)
        true_neg = tf.metrics.true_negatives(
            label_ids,
            predicted_labels)   
        false_pos = tf.metrics.false_positives(
            label_ids,
            predicted_labels)  
        false_neg = tf.metrics.false_negatives(
            label_ids,
            predicted_labels)
        return {
            "eval_accuracy": accuracy,
            "f1_score": f1_score,
            "auc": auc,
            "precision": precision,
            "recall": recall,
            "true_positives": true_pos,
            "true_negatives": true_neg,
            "false_positives": false_pos,
            "false_negatives": false_neg
        }

      eval_metrics = metric_fn(label_ids, predicted_labels)

      if mode == tf.estimator.ModeKeys.TRAIN:
        return tf.estimator.EstimatorSpec(mode=mode,
          loss=loss,
          train_op=train_op)
      else:
          return tf.estimator.EstimatorSpec(mode=mode,
            loss=loss,
            eval_metric_ops=eval_metrics)
    else:
      (predicted_labels, log_probs, probs, pooled_output) = create_model(
        is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

      predictions = {
          'log_probabilities': log_probs,
          'probabilities': probs,
          'labels': predicted_labels,
          'pooled_output': pooled_output
      }
      return tf.estimator.EstimatorSpec(mode, predictions=predictions)

  # Return the actual model function in the closure
  return model_fn


def getPrediction(in_sentences):
  labels = ["Negative", "Positive"]
  input_examples = [run_classifier.InputExample(guid="", text_a = x, text_b = None, label = 0) for x in in_sentences] # here, "" is just a dummy label
  input_features = run_classifier.convert_examples_to_features(input_examples, label_list, MAX_SEQ_LENGTH, tokenizer)
  predict_input_fn = run_classifier.input_fn_builder(features=input_features, seq_length=MAX_SEQ_LENGTH, is_training=False, drop_remainder=False)
  predictions = estimator.predict(predict_input_fn)
  return [(sentence, prediction['probabilities'], prediction['log_probabilities'], labels[prediction['labels']], prediction['pooled_output']) for sentence, prediction in zip(in_sentences, predictions)]

и первый вывод (другие обрезаны b c 30K символов ограничения на ответ):

[('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32),
  'Negative',
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
         -0.9999866 ,  0.5820049 ,  0.3257555 , -0.81900954, -0.8326617 ,
          0.87788117, -0.7791749 ,  0.11098853,  0.67873836,  0.9999771 ,
          0.9833652 , -0.8420576 ,  0.83076835,  0.37272754,  0.8667175 ,
          0.792386  , -0.82003427, -0.9999999 , -0.9382297 , -0.9713775 ,
          0.55752313,  1.        , -0.72632766, -0.4752956 , -0.9999852 ,
         -0.99974227, -0.9998661 , -0.3094257 , -0.93023825, -0.72663504,
          0.92974335, -0.8601105 , -0.8113003 ,  0.7660112 ,  0.9313508 ,
          0.21427669, -0.45660907,  0.99970686,  0.56852764, -0.9997675 ,
         -0.9999096 ,  0.8247045 ,  0.7205424 ,  0.47192624, -0.7523966 ,
         -0.9588541 , -0.48866934,  0.9809366 , -0.07110611, -0.99886   ,
         -0.63922834, -0.68144   , -1.        ,  0.8531816 ,  0.26078308,
         -0.99898577, -0.99968046,  0.6711601 ,  0.99857473, -0.99990964,
          1.        , -0.97127694, -0.10644457,  0.46306637, -0.32486317,
         -0.68167734,  0.43291137, -0.996574  ,  0.05164305,  0.9897354 ,
          0.93853104,  0.94800174,  0.9995697 ,  0.6532897 ,  0.93846226,
         -0.6281378 ,  0.5574107 ,  0.725278  ,  0.74160355, -0.6486919 ,
          0.88869256,  0.9439776 , -0.9654787 , -0.95139974, -0.9366148 ,
          0.17409436,  0.83473635, -0.87414986, -0.35965624, -0.8395183 ,
          0.5546853 ,  0.7452196 , -0.6152899 , -0.82187194, -0.65487677,
          0.94367695,  0.6834396 , -0.72266734,  0.99376386, -0.76821744,
          0.4485644 ,  0.99982166,  1.        ,  0.9260674 ,  0.9759094 ,
          0.9397613 ,  0.8128903 , -0.7918152 ,  0.30299878, -0.95160294,
          0.25385544, -0.57780135, -0.9999994 ,  0.9168113 , -0.36585295,
          0.9798102 ,  0.95976156, -0.99428   ,  0.6471789 , -0.9948078 ,
         -0.9686591 ,  0.93615085, -0.11481134,  0.87566274, -0.91601896,
          0.9952683 ,  0.26532048,  0.99861896,  0.79298306,  0.5872364 ,
         -0.56314534,  0.96794534,  0.9999797 ,  0.9879324 ,  0.5003342 ,
          0.9516269 , -0.8878316 , -0.9665091 , -0.88037425,  0.8356687 ,
         -0.71543014, -0.99985015, -0.9414574 ,  0.8681497 ,  0.950698  ,
         -0.8007153 ,  0.78748596,  0.9999305 ,  0.40210736,  0.4856055 ,
         -0.9390776 ,  0.63564163, -0.85989815, -0.8421344 , -0.99436   ,
          0.78081733, -0.97038007,  0.39290914,  0.7834218 ,  0.88715357,
         -0.03653741,  0.99126273, -0.96559966,  0.11924513, -0.99363935,
         -0.9901692 ,  0.963858  ,  0.5713922 ,  0.5676979 ,  0.69982123,
          0.858003  ,  0.9983819 , -0.87965024,  0.46213093, -0.3256273 ,
          0.77337253,  0.7246244 , -0.99894017, -0.9170495 , -0.98803675,
         -0.93148243,  0.09674019,  0.09448949, -0.7453027 , -0.78955775,
         -0.6304773 , -0.5597632 ,  0.992308  ,  0.7769483 ,  0.04146893,
         -0.15876745, -0.7682887 , -0.5231416 ,  0.7871302 ,  0.9503481 ,
         -0.9607153 ,  0.99047405, -0.9948017 , -0.82257754,  0.9990552 ,
          0.79346406, -0.78624016,  0.8760266 , -0.7855991 ,  0.13444276,
         -0.7183107 , -0.9999819 ,  0.7019429 , -0.918913  , -0.6569654 ,
          0.9998794 , -0.33805153, -0.9427715 ,  0.10419375, -0.94257164,
          0.9187495 , -0.9994855 , -0.99979955, -0.9277688 ,  0.6353426 ,
          0.9994905 ,  0.90688777,  0.9992008 ,  0.7817533 , -0.9996674 ,
         -0.999962  , -0.13310781, -0.82505953,  0.9997485 ,  0.82616794,
         -0.999998  ,  0.45386457,  0.6069964 ,  0.52272975,  0.8811922 ,
          0.52668494, -0.9994814 , -0.21601789, -0.99882716,  0.90246916,
          0.94196504,  0.30058604, -0.9876776 , -0.7699927 , -0.9980288 ,
          0.7727592 ,  0.9936947 ,  0.98021245, -0.77723926, -0.785372  ,
          0.5150317 ,  0.9983137 , -0.7461883 ,  0.3311537 , -0.63709795,
         -0.6487831 , -0.9173727 ,  0.9997706 , -0.9999893 , -1.        ,
          0.60389155, -0.6516268 , -0.95422006,  1.        ,  0.09109057,
         -0.99999994,  0.99998957,  1.        , -0.19451752,  0.94624877,
         -0.2761865 ,  1.        ,  0.52399474,  0.70230734,  0.5218801 ,
         -0.99716544, -0.70075685, -0.99992603,  1.        , -0.9785006 ,
          0.22457084, -0.5356722 , -0.9991887 ,  0.7062409 ,  0.66816545,
         -0.90308225, -0.8084922 ,  0.50301254, -0.7062079 ,  0.9998321 ,
          0.9823206 ,  0.9984027 ,  0.9948857 , -1.        , -0.7067878 ,
          0.975454  ,  0.87161005, -0.9882297 ,  0.8296374 , -0.88615334,
          0.4316883 ,  0.86287475, -0.9893329 , -0.9022001 , -0.68322754,
         -0.84212875,  0.78632677, -0.5131366 , -0.996949  , -0.75479275,
         -0.06342169,  0.92238575,  0.66769385,  0.9926053 , -0.78391105,
          0.9976865 ,  0.07086544,  0.34079495,  0.69730175, -0.99970955,
         -1.        , -0.9860551 ,  0.89584446, -0.96889114, -0.90435815,
          0.944296  , -1.        , -0.9931756 , -0.7014334 , -0.6742562 ,
         -0.96786517,  0.848328  ,  0.8903087 , -0.9998633 ,  0.73993397,
          0.99345684,  0.9691821 ,  0.87563246, -0.6073146 , -0.9999999 ,
          0.90763575,  0.30225936, -0.47824544,  0.7179979 ,  0.9450465 ,
          0.9715953 , -0.5422173 ,  0.99995065, -0.5920663 ,  0.92390317,
         -0.9670669 , -0.3623574 ,  0.74825   , -0.7817521 ,  0.9888685 ,
         -0.7653631 , -0.8933355 ,  0.9481424 ,  0.97803396, -0.9999731 ,
         -0.89597356,  0.35502487, -0.7190486 ,  0.30777818,  0.55025375,
          0.6365793 , -0.99094397, -1.        ,  0.93482614, -0.99970514,
          0.98721176,  0.14699097, -0.86038756, -0.68365514, -0.8104672 ,
          0.57238674,  0.97475344, -0.9963499 ,  0.98476464,  0.40495875,
         -0.7001948 , -0.40898973,  0.61900675, -1.        , -0.9371812 ,
         -0.62749994, -0.8841316 , -0.9999847 , -0.39386114, -0.925245  ,
         -0.99991447, -0.5872595 ,  0.5835767 ,  0.7003338 , -0.9761974 ,
          0.99995846,  0.33676207,  0.9079994 , -0.76412004, -0.7648706 ,
          0.68863285,  0.43983305,  0.74911463, -0.99995685, -0.6692586 ,
         -0.45761266, -0.9980771 , -1.        ,  0.31244457, -0.8834693 ,
          0.9388263 , -0.987405  ,  1.        ,  0.9512058 ,  0.23448633,
          0.37940192,  0.99989796,  0.8402514 , -0.84526414,  0.7378776 ,
         -0.9996204 , -0.99434114,  0.9987527 ,  0.5569713 ,  0.99648696,
         -0.9933159 , -0.13116199,  0.9999992 ,  0.9642579 , -0.48285434,
         -0.97517425,  0.7185596 ,  0.5286405 ,  0.9902838 ,  0.7796022 ,
         -0.80703837,  0.2376029 ,  0.534117  , -0.9999413 ,  0.99828076,
          0.9998345 ,  0.93249476,  0.3620626 ,  0.7567034 , -0.9222681 ,
          0.97832036,  0.9999682 ,  0.6433209 , -1.        ,  0.9268615 ,
         -0.9999511 , -0.9145363 , -0.9213852 ,  0.7606066 , -0.5501025 ,
         -0.99999434, -0.7783993 ,  0.9999771 ,  0.99980384,  0.987094  ,
          0.7531475 , -0.8551696 , -0.9973968 , -0.9999853 , -0.08913276,
         -0.9919206 , -0.49190572,  0.70230234, -0.31277484, -0.99999964,
          0.828591  ,  0.6363776 ,  0.86796165,  0.81575817,  0.7782955 ,
          0.9436437 , -1.        , -0.7509046 , -0.9946139 , -0.6647415 ,
          0.999543  ,  0.9312092 , -1.        ,  0.5639159 ,  0.9482462 ,
         -0.9289936 , -0.9678435 ,  0.60937124, -0.987818  ,  0.5511619 ,
          0.75886583, -0.48466644, -0.71833754,  0.8042149 ,  0.9154103 ,
         -0.8177468 ,  0.7195895 , -0.82283056,  0.24990956, -1.        ,
          0.7729634 ,  0.84048635,  0.7989596 ,  0.9469012 , -0.9898951 ,
         -0.92565274,  0.74726975,  0.78213847, -0.672894  , -0.58831286,
         -0.8039038 , -0.72197783,  0.5289216 , -0.9998796 , -0.9904479 ,
          0.9996592 , -0.28984115,  0.23964961, -0.7427149 , -0.662416  ,
         -1.        , -0.5538268 , -0.9945287 , -0.63471127,  0.5896127 ,
         -0.48429146,  0.9976076 , -0.94329506, -0.49143887,  0.7695602 ,
          0.8638134 , -0.82130384,  0.50105464,  0.9336961 , -0.24716294,
         -0.6922282 , -0.02228704,  0.75649065,  0.82303154, -0.30867255,
         -0.9602714 ,  0.64568967,  0.314201  , -0.4811752 ,  0.27952817,
          0.9227022 ,  0.88095886,  0.89470226,  1.        , -0.19237158,
          1.        , -0.991253  , -0.9991121 ,  0.5637482 , -0.75780976,
         -0.3904836 , -0.9881965 , -0.2912058 ,  0.9998215 ,  0.9869475 ,
         -0.12784953,  0.81566185,  0.9787118 , -0.17835459, -0.7027824 ,
          0.72269535, -0.18194303,  0.9968796 ,  0.03490257,  0.7751488 ,
         -1.        , -0.7761089 ,  0.85105944,  0.9968074 , -0.8156342 ,
          0.5300792 , -1.        ,  0.99626255, -0.7515625 , -0.6672005 ,
          0.9792111 ,  0.8660997 , -0.69161206,  0.32184905,  0.9071073 ,
          0.9999385 , -0.82744277, -0.99044186, -0.71309817, -0.5004305 ,
          0.70707524,  0.89751345, -0.6819585 , -0.9999414 , -0.45255637,
         -0.94375473, -0.91838425,  0.64272994,  0.9375524 ,  0.6609169 ,
         -0.88743365, -0.9534722 , -0.47888806, -1.        , -0.5251781 ,
          0.8274516 ,  0.9326824 ,  0.8961964 ,  0.5295862 ,  0.43714878,
         -0.7488347 , -0.75295556, -0.5187054 ,  0.75924635, -0.7862662 ,
          0.99981725, -0.80290836,  0.97651815,  0.99763787, -0.29619345,
         -0.1252967 ,  0.33606276, -0.65137684, -0.9680231 ,  0.77586985,
          0.22347753,  0.27245504, -0.07826214, -0.8383849 , -0.85373163,
          1.        , -0.4563588 , -0.91339815, -0.9999861 ,  0.66063935,
         -0.985843  , -0.7818757 , -0.7000497 , -0.6840764 ,  0.9995542 ,
          0.60819125,  0.80064404, -0.9776968 , -0.90925264, -0.6644932 ,
         -0.8771755 ,  0.71411085,  0.8113569 ,  0.9974196 , -0.75211936,
          0.63400257, -0.8272833 ,  0.99780786,  0.9965285 ,  0.59551436,
         -0.9876875 , -0.04439292,  0.9939223 ,  0.9993717 , -0.9965501 ,
         -0.9630328 , -0.9027949 , -0.48490363, -0.60193753, -0.6870232 ,
         -0.95355797, -0.67561924,  0.9997761 , -0.85473967,  0.998495  ,
         -0.95756954,  0.633171  ,  0.4570475 , -0.5316367 , -0.9663824 ,
          0.9567106 , -0.45497724,  0.12964879,  0.9964744 , -0.9711668 ,
          0.69636106, -0.9178346 ,  0.8313186 ,  0.69686604,  0.8141587 ,
         -0.33600506,  0.94798595,  0.8800869 ,  0.15029034, -0.91185665,
          0.6322724 , -0.9971475 ,  0.71948224,  0.9695236 ,  0.84242374,
          0.99995124,  0.5982563 , -0.98341423,  0.61301434,  0.9997318 ,
         -0.9981808 , -0.65651804, -0.8484874 , -0.9961815 ,  0.9030814 ,
          0.87141925,  0.8897381 , -0.92870414,  0.07134341,  0.8739935 ,
          0.91630197, -0.9465984 , -0.59741104, -1.        ,  0.9989559 ,
          0.99991184,  0.67439264,  0.92025673, -0.60730827,  0.8362061 ,
          1.        , -0.70801497,  0.9883806 , -0.9984141 ,  0.9919259 ,
         -0.998869  ,  0.9976203 ,  0.9888036 ,  0.8556838 , -0.9722744 ,
         -0.99810714,  0.8182833 ,  0.98808485,  0.6643728 ,  0.99212515,
         -0.99988   ,  0.26405996,  0.93139845,  0.99021816,  0.6846886 ,
          0.9986462 ,  0.92254627, -0.6406982 ], dtype=float32)),
 ('The acting was a bit lacking',
  array([0.9921152 , 0.00788479], dtype=float32),
  array([-0.00791603, -4.842819  ], dtype=float32),
  'Negative',
  array([ 0.67417824,  0.8235167 ,  0.99999565, -0.8565971 , -0.99499583,
          0.8219966 , -0.9185583 , -0.5234593 ,  0.99962074,  0.99999714,
          0.9507927 , -0.9996754 ,  0.22211392, -0.99826247,  0.7562492 ,
          0.93803996,  0.82738185,  0.4773049 , -0.73478544,  0.85207295,
3 голосов
/ 25 марта 2020

1) Из документации BERT

Выходной словарь содержит:

pooled_output: объединенный вывод всей последовательности с формой [batch_size, hidden_size]. sequence_output: представления каждого токена во входной последовательности в форме [batch_size, max_sequence_length, hidden_size].

Я добавил pooled_output вектор, который соответствует вектору CLS.

3) Вы получаете лог вероятности. Просто примените softmax, чтобы получить нормальные вероятности.

Теперь все, что осталось сделать - это сообщить модели об этом. Я оставил пробные журналы, но они больше не нужны.

Смотрите изменения кода:

def create_model(is_predicting, input_ids, input_mask, segment_ids, labels,
                 num_labels):
  """Creates a classification model."""

  bert_module = hub.Module(
      BERT_MODEL_HUB,
      trainable=True)
  bert_inputs = dict(
      input_ids=input_ids,
      input_mask=input_mask,
      segment_ids=segment_ids)
  bert_outputs = bert_module(
      inputs=bert_inputs,
      signature="tokens",
      as_dict=True)

  # Use "pooled_output" for classification tasks on an entire sentence.
  # Use "sequence_outputs" for token-level output.
  output_layer = bert_outputs["pooled_output"]

  pooled_output = output_layer

  hidden_size = output_layer.shape[-1].value

  # Create our own layer to tune for politeness data.
  output_weights = tf.get_variable(
      "output_weights", [num_labels, hidden_size],
      initializer=tf.truncated_normal_initializer(stddev=0.02))

  output_bias = tf.get_variable(
      "output_bias", [num_labels], initializer=tf.zeros_initializer())

  with tf.variable_scope("loss"):

    # Dropout helps prevent overfitting
    output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)

    # Convert labels into one-hot encoding
    one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)

    predicted_labels = tf.squeeze(tf.argmax(log_probs, axis=-1, output_type=tf.int32))
    # If we're predicting, we want predicted labels and the probabiltiies.
    if is_predicting:
      return (predicted_labels, log_probs, probs, pooled_output)

    # If we're train/eval, compute loss between predicted and actual label
    per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
    loss = tf.reduce_mean(per_example_loss)
    return (loss, predicted_labels, log_probs, probs, pooled_output)

Теперь в model_fn_builder() добавлена ​​поддержка этих значений:

  # this should be changed in both places
  (predicted_labels, log_probs, probs, pooled_output) = create_model(
    is_predicting, input_ids, input_mask, segment_ids, label_ids, num_labels)

  # return dictionary of all the values you wanted
  predictions = {
      'log_probabilities': log_probs,
      'probabilities': probs,
      'labels': predicted_labels,
      'pooled_output': pooled_output
  }

Настройте getPrediction() соответствующим образом, и в итоге ваши прогнозы будут выглядеть следующим образом:

('That movie was absolutely awful',
  array([0.99599314, 0.00400678], dtype=float32),  <= Probability
  array([-4.0148855e-03, -5.5197663e+00], dtype=float32), <= Log probability, same as previously
  'Negative', <= Label
  array([ 0.9181199 ,  0.7763732 ,  0.9999883 , -0.93533266, -0.9841384 ,
          0.78126144, -0.9918988 , -0.18764131,  0.9981035 ,  0.99999994,
          0.900716  , -0.99926263, -0.5078789 , -0.99417543, -0.07695035,
          0.9501321 ,  0.75836045,  0.49151263, -0.7886792 ,  0.97505844,
         -0.8931161 , -1.        ,  0.9318583 , -0.60531116, -0.8644371 ,
        ...
        and this is 768-d [CLS] vector (sentence embedding).    

Относительно 2): В моем конце тренировка заняла около 5 минут, а тестирование - около 40 секунд. Очень разумно.

ОБНОВЛЕНИЕ

Для 20-килограммовых образцов потребовалось 12:48 для обучения и 2:07 минут для тестирования.

Для таймингов для 10k образцов 8:40 и 1:07 соответственно.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...