Проблема здесь:
if(nums[mid] > nums[hi])
Вы не можете использовать это условие безопасно, и изменение его на использование >=
или в сочетании с nums[lo]
, на <
или <=
также не поможет; рассмотрим следующие две последовательности:
{ 1, 1, 1, 2, 1 }
{ 1, 2, 1, 1, 1 }
0 2 4
lo mid hi
С любой комбинацией lo
или hi
с любым из приведенных выше компараторов, одна из двух последовательностей будет оценена неправильно из-за выбора плохой половины! Вы попали именно в такую ситуацию с вашей последовательностью { 1, 1, 3, 1 }
.
Так что, если вы окажетесь в такой ситуации (nums[x]
будет равным для всех lo
, hi
и mid
), вам придется проверить обе половины! Полагаю, два рекурсивных вызова проще всего реализовать. Имейте в виду, что вам не нужно повторяться, если любое из значений трех индексов отличается; если вам нужно выполнить повторение, только для одного из двух промежуточных результатов реальный преемник будет меньше, поэтому этот результат будет окончательным.
Вы можете обнаружить e. г. следующим образом:
if(nums[mid] > nums[hi] || nums[mid > nums[lo])
{
hi = mid - 1;
}
else if(nums[mid] < nums[lo] || nums[mid] < nums[high]
{
lo = mid + 1;
}
else
{
int idxLo = findPivotRecursively(nums, lo, mid - 1);
int idxHi = findPivotRecursively(nums, mid + 1, high);
return nums[idxHi] > nums[(idxHi + 1) % n] ? idxHi : idxLo;
}
При условии, что вы превратили функцию в рекурсивную с подписью
int findPivotRecursively(std::vector const& nums, int high, int low);
Поскольку массив отсортирован, это последнее условие может быть истинным только для idxLo
или idxHigh
; выбранный мной тест приведет к выбору первого элемента вектора, если все элементы равны (nums[idxLo] > nums[(idxLo + 1] ? /*...*/;
выберет последний).
Вспомогательная функция обеспечивает открытый интерфейс:
int findPivot(std::vector const& nums)
{
return findPivotRecursively(nums, 0, nums.size() - 1);
}
С помощью лямбды вы можете полностью скрыть работающую функцию. К сожалению, рекурсивные лямбды не поддерживаются напрямую, поэтому вы должны полагаться на небольшую хитрость, как показано в этом ответе :
int findPivot(std::vector const& nums)
{
size_t n = nums.size();
auto findPivotRecursively = [&nums, &n] (int hi, int lo, auto& fpr) -> int
{
// ...
else
{
int idxLo = fpr(hi, lo, fpr);
// ...
}
// ...
};
return findPivotRecursively(0, nums.size() - 1, findPivotRecursively);
}