Вы используете массив numpy? Это работает для меня, когда и данные, и метки являются массивами.
>>> import numpy
>>> a = numpy.arange(24).reshape((4,3,2))
>>> a
array([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[ 6, 7],
[ 8, 9],
[10, 11]],
[[12, 13],
[14, 15],
[16, 17]],
[[18, 19],
[20, 21],
[22, 23]]])
>>> labels = numpy.array([1,0,1, 0])
>>> a[labels > 0]
array([[[ 0, 1],
[ 2, 3],
[ 4, 5]],
[[12, 13],
[14, 15],
[16, 17]]])
>>> a[labels == 0]
array([[[ 6, 7],
[ 8, 9],
[10, 11]],
[[18, 19],
[20, 21],
[22, 23]]])