Как правильно реализовать непересекающуюся структуру данных для нахождения охватывающих лесов в Python? - PullRequest
1 голос
/ 31 января 2020

Недавно я пытался реализовать решения вопросов по программированию google kickstater на 2019 г. и попытался внедрить «Вишни Ме» раунда E sh, следуя объяснениям анализа. Вот ссылка на вопрос и анализ. https://codingcompetitions.withgoogle.com/kickstart/round/0000000000050edb/0000000000170721

Вот код, который я реализовал:

t = int(input())
for k in range(1,t+1):
    n, q = map(int,input().split())
    se = list()
    for _ in range(q):
        a,b = map(int,input().split())
        se.append((a,b))
    l = [{x} for x in range(1,n+1)]
    #print(se)
    for s in se:
        i = 0
        while ({s[0]}.isdisjoint(l[i])):
            i += 1
        j = 0
        while ({s[1]}.isdisjoint(l[j])):
            j += 1
        if i!=j:
            l[i].update(l[j])
            l.pop(j)
        #print(l)
    count = q+2*(len(l)-1)
    print('Case #',k,': ',count,sep='')



Здесь проходит пример, но не тесты. Насколько мне известно, это должно быть правильно. Я что-то не так делаю?

Ответы [ 2 ]

0 голосов
/ 31 января 2020

Вы получаете неправильный ответ, потому что вы неправильно рассчитываете счет. Для соединения n узлов в дерево требуется n-1 ребер, и num_clusters-1 из них должны быть красного цвета.

Но если вы это исправите, ваша программа все равно будет очень медленной из-за Ваша реализация дизъюнктных множеств.

К счастью, довольно легко реализовать очень эффективную структуру данных непересекающихся множеств в одном массиве / списке / векторе практически на любом языке программирования. Вот хороший в python. На моем ящике python 2, поэтому мои операторы печати и ввода немного отличаются от ваших:

# Create a disjoint set data structure, with n singletons, numbered 0 to n-1
# This is a simple array where for each item x:
# x > 0 => a set of size x, and x <= 0 => a link to -x

def ds_create(n):
    return [1]*n

# Find the current root set for original singleton index

def ds_find(ds, index):
    val = ds[index]
    if (val > 0):
        return index
    root = ds_find(-val)
    if (val != -root):
        ds[index] = -root # path compression
    return root

# Merge given sets. returns False if they were already merged

def ds_union(ds, a, b):
    aroot = ds_find(ds, a)
    broot = ds_find(ds, b)
    if aroot == broot:
        return False
    # union by size
    if ds[aroot] >= ds[broot]:
        ds[aroot] += ds[broot]
        ds[broot] = -aroot
    else:
        ds[broot] += ds[aroot]
        ds[aroot] = -broot
    return True

# Count root sets

def ds_countRoots(ds):
    return sum(1 for v in ds if v > 0)

#
# CherriesMesh solution
#
numTests = int(raw_input())
for testNum in range(1,numTests+1):
    numNodes, numEdges = map(int,raw_input().split())
    sets = ds_create(numNodes)
    for _ in range(numEdges):
        a,b = map(int,raw_input().split())
        print a,b
        ds_union(sets, a-1, b-1)
    count = numNodes + ds_countRoots(sets) - 2
    print 'Case #{0}: {1}'.format(testNum, count)
0 голосов
/ 31 января 2020

Две проблемы:

  • Ваш алгоритм проверки, связывает ли ребро два непересекающихся множества, и объединяет их, если нет, неэффективен. Алгоритм Union-Find для структуры данных Disjoint-Set более эффективен
  • Окончательное число не зависит от исходного числа черных ребер, поскольку у этих черных ребер могут быть циклы, и поэтому некоторые из них не должны учитываться. Вместо этого посчитайте, сколько ребер всего (независимо от цвета). Поскольку решение представляет собой минимальное остовное дерево, количество ребер равно n-1 . Вычтите из этого количество непересекающихся множеств, которые у вас есть (как вы уже сделали).

Я бы также посоветовал использовать значимые имена переменных. Код намного проще для понимания. Однобуквенные переменные, такие как t, q или s, не очень полезны.

Существует несколько способов реализации функций Union-Find. Здесь я определил класс Node, который имеет эти методы:

# Implementation of Union-Find (Disjoint Set)
class Node:
    def __init__(self):
        self.parent = self
        self.rank = 0

    def find(self):
        if self.parent.parent != self.parent:
            self.parent = self.parent.find()
        return self.parent

    def union(self, other):
        node = self.find()
        other = other.find()
        if node == other:
            return True # was already in same set
        if node.rank > other.rank:
            node, other = other, node
        node.parent = other
        other.rank = max(other.rank, node.rank + 1)
        return False # was not in same set, but now is

testcount = int(input())
for testid in range(1, testcount + 1):
    nodecount, blackcount = map(int, input().split())
    # use Union-Find data structure
    nodes = [Node() for _ in range(nodecount)]
    blackedges = []
    for _ in range(blackcount):
        start, end = map(int, input().split())
        blackedges.append((nodes[start - 1], nodes[end - 1]))

    # Start with assumption that all edges on MST are red:
    sugarcount = nodecount * 2 - 2
    for start, end in blackedges:
        if not start.union(end): # When edge connects two disjoint sets:
            sugarcount -= 1 # Use this black edge instead of red one

    print('Case #{}: {}'.format(testid, sugarcount))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...