Как обрабатывать ввод с плавающей запятой и массив в Python? - PullRequest
0 голосов
/ 24 сентября 2019

Я пытаюсь написать функцию, которая принимает либо float, либо массив float, и обрабатывает их обоих с использованием одинаковых строк кода.Например, я хочу вернуть сам float, если это float, и сумму массива float, если это массив.Примерно так

def func(a):
  return np.sum(a)

и оба func(1.2) возвращают 1,2, а func(np.array([1.2,1.3,1.4]) возвращают 3,9.

Ответы [ 4 ]

1 голос
/ 24 сентября 2019

Обычный способ убедиться, что вход является массивом NumPy, это использовать np.asarray():

import numpy as np

def func(a):
  a = np.asarray(a)
  return np.sum(a)

func(1.2)
# 1.2
func([1.2, 3.4])
# 4.6
func(np.array([1.2, 3.4]))
# 4.6

или, если вы хотите получить len() вашего массива, убедитесь, что он равеннаименее одномерный, используйте np.atleast_1d():

def func(a):
  a = np.atleast_1d(a)
  return a.shape[0]

func(1.2)
# 1
func([1.2, 3.4])
# 2
func(np.array([1.2, 3.4]))
# 2
1 голос
/ 24 сентября 2019

Вы можете использовать выравнивание аргументов:

def func(*args):
    # code to handle args
    return sum(args)

Теперь следующие действия имеют такое же поведение:

>>> func(3)
3
>>> func(3, 4, 5)
12
>>> func(*[3, 4, 5])
12
1 голос
/ 24 сентября 2019

Это уже работает, в чем проблема?

import numpy as np
def func(a):
  return np.sum(a)
print(func(np.array([1.2,2.3,3.2])))
print(func(1.2))

Вывод:

6.7
1.2
0 голосов
/ 24 сентября 2019

Вы можете проверить, является ли ввод плавающим, а затем поместить его в список перед обработкой суммы:

def func(a):
    if isinstance(a, float):
        a = [a]
    return np.sum(a)
...