Запустив ваш скрипт с [1,2,3,4,5] и [4,6,8,10,12], я вижу, что оптимизация расходится и выдает «nan» значения для W и b; Я предполагаю, что это проблема, которую вы имеете в виду под "сломанным" результатом.
Основные две проблемы заключаются в том, что размер вашего шага (альфа = 0,1) очень агрессивен и что инициализация по умолчанию для параметров W и B, вероятно, не оптимальна. Это всегда контекстно-зависимая оценка - нет универсального параметра альфа, подходящего для каждой задачи оптимизации, - поэтому вы должны взглянуть на шаги конвергенции. Модифицируя ваш скрипт для печати значений после каждой итерации, вот первые несколько:
ORIGINAL VALUES:
(-1, 1046.1516, array([-5.939405], dtype=float32), array([-4.5157075], dtype=float32))
OPTIMIZING
(0, 1949.8801, array([15.435911], dtype=float32), array([1.5510774], dtype=float32))
(1, 3636.85, array([-13.854544], dtype=float32), array([-6.420686], dtype=float32))
(2, 6785.708, array([26.077063], dtype=float32), array([4.7761774], dtype=float32))
Видите, насколько велики различия в параметрах? Значения параметров в конечном итоге изменяются в больших долях - до 350%! Градиент стоимости наименьших квадратов в этом одномерном случае для параметра W равен
D[cost]/D[W] = 2W/N * sum_i (W x_i + B - y_i)
или эквивалентно
1/W D[cost]/D[W] = 2 (W <x> + B - <y>)
Итак, начальная ошибка (W <x> + B - <y>
) от начальных случайных значений W и B (-5,9 и -4,5) равна -35 ---, а альфа равна 0,1, параметр W будет изменен на дробное количество около -350% (100% * -35 * 0,1). Вот почему W увеличивается от -5,9 до 15,4.
Итак, две проблемы:
- начальные значения W и B кажутся большими. Возможно, вы захотите попробовать другой механизм инициализации. Я не знаю из рук, какая рекомендуемая процедура, но, возможно, tf.global_variables_initializer не лучший в этом случае
- что более важно, ваш параметр обучения альфа слишком велик. Попробуйте меньшее значение, например, 0,001; или попробуйте 0.1 с AdamOptimizer вместо GradientDescentOptimizer. AdamOptimizer должен лучше справляться с большими флуктуациями, которые вы видите, когда W колеблется от -5,9 => 15,4 => -13,8 => 26,1 и т. Д.