DFS для поиска подключенных компонентов
import queue
import itertools
n = 10
def DFS(data, v, x,y,z, component):
q = queue.Queue()
q.put((x,y,z))
while not q.empty():
x,y,z = q.get()
v[x,y,z] = component
l = [[x], [y], [z]]
for i in range(3):
if l[i][0] > 0:
l[i].append(l[i][0]-1)
if l[i][0] < v.shape[1]-1:
l[i].append(l[i][0]+1)
c = list(itertools.product(l[0], l[1], l[2]))
for x,y,z in c:
if v[x,y,z] == 0 and data[x,y,z] == 1:
q.put((x,y,z))
data = np.random.binomial(1, 0.2, n*n*n)
data = data.reshape((n,n,n))
coordinates = np.argwhere(data > 0)
v = np.zeros_like(data)
component = 1
for x,y,z in coordinates:
if v[x,y,z] != 0:
continue
DFS(data, v, x,y,z, component)
component += 1
Основной алгоритм:
- Установить посещение каждой точки = 0 (это означает, что она не является частью какого-либо подключенного
компонента пока нет)
- для всех точек, значение которых == 1
- Если точка не посещена, запустите DFS, начиная с формы
DFP: : Это традиционный алгоритм DFS, использующий очередь. Единственное различие для трехмерного случая дано (x,y,z)
, мы рассчитываем всех действительных соседей, используя itertools.product
В трехмерном случае каждая точка будет иметь 27 соседей, включая себя (3 позиции и 3 возможных значения - то же самое, приращение, уменьшение, то есть 27 путей).
В матрице v
хранятся подключенные компоненты, пронумерованные начиная с 1.
TestCase:
когда данные =
[[[1 1 1]
[1 1 1]
[1 1 1]]
[[0 0 0]
[0 0 0]
[0 0 0]]
[[1 1 1]
[1 1 1]
[1 1 1]]]
Визуализация:
две противоположные стороны - это два разных соединенных компонента
Алгоритм возвращает v
[[[1 1 1]
[1 1 1]
[1 1 1]]
[[0 0 0]
[0 0 0]
[0 0 0]]
[[2 2 2]
[2 2 2]
[2 2 2]]]
что правильно.
Визуализация:
Как видно из визуализации v
зеленый цвет обозначает один подключенный компонент, а синий цвет обозначает другой подключенный компонент.
Код визуализации
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
def plot(data):
fig = plt.figure(figsize=(10,10))
ax = fig.gca(projection='3d')
for i in range(data.shape[0]):
for j in range(data.shape[1]):
ax.scatter([i]*data.shape[0], [j]*data.shape[1],
[i for i in range(data.shape[2])],
c=['r' if i == 0 else 'b' for i in data[i,j]], s=50)
plot(data)
plt.show()
plt.close('all')
plot(v)
plt.show()