KeyError: 0 при обучении с помощью Stellargraph generator.flow - PullRequest
0 голосов
/ 09 июля 2020

Я пробую классификацию узлов с использованием класса gcn библиотеки stellargraph. Итак, я импортировал функции узлов node_feat.csv, метки узлов как node_label.csv и функции ребер как edge_feat.csv. Следуя процедуре, приведенной в https://stellargraph.readthedocs.io/en/stable/demos/node-classification/gcn-node-classification.html для классификации узлов.

!wget -O node_feat.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partvertices_new.csv
!wget -O node_targets.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partvertices.csv
!wget -O edge_data.csv https://github.com/pranavn91/blockchain/blob/master/tx2009partedges.csv
    

Затем я импортирую библиотеку

from stellargraph import StellarDiGraph as sg
import pandas as pd
import os
#import stellargraph as sg
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.layer import GCN

from tensorflow.keras import layers, optimizers, losses, metrics, Model
from sklearn import preprocessing, model_selection
from IPython.display import display, HTML
import matplotlib.pyplot as plt
%matplotlib inline

Затем создаю звездно-ориентированный граф

trans2009 = sg(
    {"users": node_feat}, {"transfer_btc": edge_data}
)
print(trans2009.info())

разделить набор данных

train_subjects, test_subjects = model_selection.train_test_split(
    node_targets, train_size=40, test_size=None
)
val_subjects, test_subjects = model_selection.train_test_split(
    node_targets, train_size=6, test_size=None
)

target_encoding = preprocessing.LabelBinarizer()

train_targets = target_encoding.fit_transform(train_subjects["label"].astype(str))
val_targets = target_encoding.transform(val_subjects["label"].astype(str))
test_targets = target_encoding.transform(test_subjects["label"].astype(str))

generator = FullBatchNodeGenerator(G, method="gcn")

Но для шага, указанного ниже

train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)

Получить keyError: 0

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-31-74331d853607> in <module>
----> 1 train_gen = generator.flow(train_subjects.index, train_targets, shuffle=True)

/opt/conda/lib/python3.7/site-packages/stellargraph/mapper/sampled_node_generators.py in flow(self, node_ids, targets, shuffle, seed)
    139             expected_node_type = None
    140 
--> 141         node_ilocs = self.graph.node_ids_to_ilocs(node_ids)
    142         node_types = self.graph.node_type(node_ilocs, use_ilocs=True)
    143         invalid = node_ilocs[node_types != expected_node_type]

/opt/conda/lib/python3.7/site-packages/stellargraph/core/graph.py in node_ids_to_ilocs(self, nodes)
   1211             Numpy array containing the indices for the requested nodes.
   1212         """
-> 1213         return self._nodes.ids.to_iloc(nodes, strict=True)
   1214 
   1215     def node_ilocs_to_ids(self, node_ilocs):

/opt/conda/lib/python3.7/site-packages/stellargraph/core/element_data.py in to_iloc(self, ids, smaller_type, strict)
     95         internal_ids = self._index.get_indexer(ids)
     96         if strict:
---> 97             self.require_valid(ids, internal_ids)
     98 
     99         # reduce the storage required (especially useful if this is going to be stored rather than

/opt/conda/lib/python3.7/site-packages/stellargraph/core/element_data.py in require_valid(self, query_ids, ilocs)
     75 
     76             if len(missing_values) == 1:
---> 77                 raise KeyError(missing_values[0])
     78 
     79             raise KeyError(missing_values)

KeyError: 0

Как решить?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...