Я делаю несколько моделей LSTM с разными параметрами и хочу сохранить их с уникальными и узнаваемыми именами, но получаю ошибку. Я попробовал все, что "Google" сказал мне, но ничего не получалось. Возможно, я неправильно пишу имя файла, но я новичок в python и не могу сам это исправить.
seq_lens = [4,5,6,7] dropout_rates = [0.1, 0.2, 0.3, 0.4, 0.5] num_hl
= [1,2,3,4] node_hl = [10, 15, 20, 25] learning_rates = [0.01, 0.001, 0.0001, 0.00001]
for current_seq_len in seq_lens:
for current_drop_rate in dropout_rates:
for current_number_hls in num_hl:
for current_node_hls in node_hl:
for current_learning_rate in learning_rates:
current_name_model = "seq_len:" + str(current_seq_len) + "_" + "dr_rate:" + str(current_drop_rate) + "_"
+ "num_hl:" + str(current_number_hls) + "_" + "node_hl:" + str( current_node_hls) + "_" + "learn_rt:" + str(current_learning_rate)
X_train, y_train = create_data(train_dataset, current_seq_len)
layers_size = [1]
for hl in range(3): # +1 because first layer is input layer, not hidden layer
layers_size.extend([current_seq_len])
layers_size.extend([1])
model = Sequential()
model.add(LSTM(input_shape=((X_train.shape[1], X_train.shape[2])),units=layers_size[1], return_sequences=True))
model.add(Dropout(current_drop_rate))
# hidden layers
for hl in range(current_number_hls+1): # +1 because first layer is input layer, not hidden layer
model.add(LSTM(layers_size[hl+1], return_sequences=True, input_shape=((X_train.shape[1], X_train.shape[2]))))
model.add(Dropout(current_drop_rate))
# output layer
model.add(TimeDistributed(Dense(units=layers_size[-1])))
model.add(AveragePooling1D())
model.add(Flatten())
model.add(Dense(units = 1,activation='sigmoid'))
opt = Adam(lr=current_learning_rate)
model.compile(loss='mean_squared_error', optimizer=opt)
model.fit(X_train, y_train, batch_size=32, epochs=200, validation_split=0.1, verbose = False)
makedirs('models')
filename = 'models/model_' + current_name_model + '.h5'
model.save(filename)
print('>Saved %s' % filename)
Я получаю эту ошибку:
OSError Traceback (most recent call last)
<ipython-input-87-74935647b8eb> in <module>
47 # save model
48 filename = 'models/model_' + current_name_model + '.h5'
---> 49 model.save(filename)
50 print('>Saved %s' % filename)
51 #!mkdir -p saved_model
~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras\engine\network.py in save(self, filepath, overwrite, include_optimizer)
1088 raise NotImplementedError
1089 from ..models import save_model
-> 1090 save_model(self, filepath, overwrite, include_optimizer)
1091
1092 def save_weights(self, filepath, overwrite=True):
~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras\engine\saving.py in save_model(model, filepath, overwrite, include_optimizer)
377 opened_new_file = False
378
--> 379 f = h5dict(filepath, mode='w')
380
381 try:
~\AppData\Local\Continuum\anaconda3\lib\site-packages\keras\utils\io_utils.py in __init__(self, path, mode)
184 self._is_file = False
185 elif isinstance(path, str):
--> 186 self.data = h5py.File(path, mode=mode)
187 self._is_file = True
188 elif isinstance(path, dict):
~\AppData\Roaming\Python\Python37\site-packages\h5py\_hl\files.py in __init__(self, name, mode, driver, libver, userblock_size, swmr, rdcc_nslots, rdcc_nbytes, rdcc_w0, track_order, **kwds)
406 fid = make_fid(name, mode, userblock_size,
407 fapl, fcpl=make_fcpl(track_order=track_order),
--> 408 swmr=swmr)
409
410 if isinstance(libver, tuple):
~\AppData\Roaming\Python\Python37\site-packages\h5py\_hl\files.py in make_fid(name, mode, userblock_size, fapl, fcpl, swmr)
177 fid = h5f.create(name, h5f.ACC_EXCL, fapl=fapl, fcpl=fcpl)
178 elif mode == 'w':
--> 179 fid = h5f.create(name, h5f.ACC_TRUNC, fapl=fapl, fcpl=fcpl)
180 elif mode == 'a':
181 # Open in append mode (read/write).
h5py\_objects.pyx in h5py._objects.with_phil.wrapper()
h5py\_objects.pyx in h5py._objects.with_phil.wrapper()
h5py\h5f.pyx in h5py.h5f.create()
OSError: Unable to create file (unable to open file: name = 'models/model_seq_len:4_dr_rate:0.1_num_hl:1_node_hl:10_learn_rt:0.01.h5', errno = 22, error message = 'Invalid argument', flags = 13, o_flags = 302)