Amazon SageMaker kMeans не будет принимать разреженную матрицу (csr_matrix) в качестве входных данных, какие-либо альтернативы перед использованием плотной матрицы? - PullRequest
0 голосов
/ 16 декабря 2018

Я хочу применить алгоритм kMeans sagemaker к разреженной матрице, полученной с помощью TfidfVectorizer из библиотеки sklearn.

В идеале я хотел бы предоставить входные данные для реализации kageans Sagemaker в качестве разреженной матрицыscipy.sparse.csr.csr_matrix, но когда я это (kmeans.fit(kmeans.record_set(train_data))), я получаю следующую ошибку:

TypeError: must be real number, not csr_matrix

Конечно, если я передам плотную матрицу, алгоритм будет работать (train_data.toarray()) но объем памяти, который ему понадобится, огромен.Какие-нибудь возможные альтернативы, прежде чем я начну использовать суперразмерные экземпляры Amazon?

1 Ответ

0 голосов
/ 19 декабря 2018

Ключ был в SageMaker python SDK.Там вы можете найти функцию, которая преобразует скудную разреженную матрицу в разреженный тензор (write_spmatrix_to_sparse_tensor).

Полный код, который решил проблему без необходимости вставлять в плотную матрицу, следующий:

from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor

tfidf_matrix = tfidf_vectorizer.fit_transform('your_train_data') # output: sparse scipy matrix
sagemaker_bucket = 'your-bucket' 
data_key = 'kmeans_lowlevel/data'
data_location = f"s3://{sagemaker_bucket}/{data_key}"
buf = io.BytesIO()
write_spmatrix_to_sparse_tensor(buf, tfidf_matrix)
buf.seek(0)
boto3.resource('s3').Bucket(sagemaker_bucket).Object(data_key).upload_fileobj(buf)

После этого в конфигурации create_training_params вам нужно будет заполнить поле S3Uri указанным вами местоположением данных для хранения разреженной матрицы в S3:

create_training_params = \
{
    ... # all other params

    "InputDataConfig": [
        {
            "ChannelName": "train",
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": data_location, # YOUR_DATA_LOCATION_GOES_HERE
                    "S3DataDistributionType": "FullyReplicated"
                }
            },
            "CompressionType": "None",
            "RecordWrapperType": "None"
        }
    ]
}
...