Еще один метод, основанный на заполнении и перекате:
def sum_shifted(arr, direction=1):
n = arr.shape[0]
temp = np.zeros((n, 2 * n - 1), dtype=arr.dtype)
temp[:, slice(None, n) if direction == 1 else slice(-n, None)] = arr
for i in range(n):
temp[i, :] = np.roll(temp[i, :], direction * i)
return np.sum(temp, 0)[::direction]
Это дает вам множество вариантов.По быстродействию методы @ Divakar, кажется, имеют преимущество:
Эти графики были созданы с помощью этого сценария с использованием их в качестве тестовых функций:
def sum_shifted(arr, direction=1):
n = arr.shape[0]
temp = np.zeros((n, 2 * n - 1), dtype=arr.dtype)
temp[:, slice(None, n) if direction == 1 else slice(-n, None)] = arr
for i in range(n):
temp[i, :] = np.roll(temp[i, :], direction * i)
return np.sum(temp, 0)[::direction]
def sum_shifted_both(arr):
return sum_shifted(arr, 1), sum_shifted(arr, -1)
def sum_adam(arr):
return (
np.array([np.sum(np.diag(np.fliplr(arr), d)) for d in range(len(arr) - 1, -len(arr), -1)]),
np.array([np.sum(np.diag(arr, d)) for d in range(len(arr) - 1, -len(arr), -1)]))
def sum_divakar(a):
n = len(a)
N = 2*n-1
R = np.arange(N)
r = np.arange(n)
mask = (r[:,None] <= R) & (r[:,None]+n > R)
b_leftdiag = np.zeros(mask.shape,dtype=a.dtype)
b_leftdiag[mask] = a.ravel()
b_rightdiag = np.zeros(mask.shape,dtype=a.dtype)
b_rightdiag[mask[:,::-1]] = a.ravel()
return b_leftdiag.sum(0), b_rightdiag.sum(0)[::-1]
def sum_divakar2(a):
def left_sum(a):
n = len(a)
N = 2*n-1
p = np.zeros((n,n),dtype=a.dtype)
ap = np.concatenate((a,p),axis=1)
return ap.ravel()[:n*N].reshape(n,-1).sum(0)
return left_sum(a), left_sum(a[::-1])[::-1]
и в качестве вспомогательных функций:
def gen_input(n):
return np.arange(n * n).reshape((n, n))
def equal_output(out_a, out_b):
return all(
np.all(a_arr == b_arr)
for a_arr, b_arr in zip(out_a, out_b))
input_sizes=(5, 10, 50, 100, 500, 1000, 5000)
funcs = sum_shifted_both, sum_adam, sum_divakar, sum_divakar2
runtimes, input_sizes, labels, results = benchmark(
funcs, gen_input=gen_input, equal_output=equal_output, input_sizes=input_sizes)
plot_benchmarks(runtimes, input_sizes, labels)