Я пытаюсь вычислить расстояние Вассерштейна между двумя диаграммами постоянства, сгенерированными через библиотеку Python Ripser. Я нашел две интересные функции в Persim: sliced_wasserstein и wasserstein_matching.
Генерация моей диаграммы выглядит так:
data = json.loads(data)
data = pd.DataFrame.from_dict(data)
rips = Rips()
dgms = rips.fit_transform(data)
for i in dgms:
print(type(i))
i.tofile(directory+"diagram.txt")
plot_diagrams(dgms, show=False)
plt.savefig("persistence_diagram.png")
plt.close()
'dgms' - это список, который содержит пустые массивы, поэтому я вывожу их в своей строке 'for'.
Использование моей функции Вассерштейна выглядит следующим образом:
with open(loc) as f:
img1 = np.fromfile(f)
f.close()
with open(loc2) as f:
img2 = np.fromfile(f)
f.close()
persim.sliced_wasserstein(img1, img2)
Я пытался передать wasserstein_matching три вида данных (диаграмма в .png, список dgms и np.array), но все, что я постоянно получаю, это ошибка «IndexError: слишком много индексов для массива».
Поэтому я переключился на sliced_wasserstein, где я получаю такую ошибку:
Traceback (most recent call last):
File "C:/Users/Patka/PycharmProjects/MGR/Mapper.py", line 26, in <module>
persim.sliced_wasserstein(img1, img2)
File "C:\Users\Patka\environmentpython\lib\site-packages\persim\sliced_wasserstein.py", line 53, in sliced_wasserstein
sw += step * cityblock(sorted(V1), sorted(V2))
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Одна странная вещь для меня заключается в том, что когда я печатаю i.shape перед сохранением в файл, я получаю два измерения, например (12,2), но когда я читаю из того же файла с помощью numpy.fromfile (), я получаю кортеж (12,).
У кого-нибудь есть лекарство от этого? Моя конечная цель - вычислить расстояния для множества диаграмм и сгруппировать их, но я застрял при сравнении двух ...