Python - не удается смоделировать вызов для унаследованного класса - PullRequest
0 голосов
/ 18 сентября 2018

У меня есть этот основной класс

def main(args):
    if type == train_pipeline_type:
        strategy = TrainPipelineStrategy()
    else:
        strategy = TestPipelineStrategy()
    for table in fetch_table_information_by_region(region):
        split_required = DataUtils.load_from_dict(table, "split_required")
        if split_required:
            strategy.split(spark=spark, table_name=table_name,
                           data_loc=filtered_data_location, partition_column=partition_column,
                           split_output_dir= split_output_dir)
            logger.info("Data Split for table : {} completed".format(table_name))

Мой TrainPipelineStrategy и TestPipelineStrategy выглядят так -

class PipelineTypeStrategy(object):

    def partition_data(self, x):
        # Something

    def prepare_split_data(self, y):
        # Something

    def write_split_data(self, z):
        # Something

    def split(self, p):
        # Something


class TrainPipelineStrategy(PipelineTypeStrategy):
    """"""


class TestPipelineStrategy(PipelineTypeStrategy):

    def write_split_data(self, y):
        # Something else

Мой тестовый пример - мне нужно проверить, сколько раз вызывается splitФункциональность mocking split в методе main.

Вот то, что я пробовал -

@patch('module.PipelineTypeStrategy.TrainPipelineStrategy')
    def test_split_data_main_split_data_call_count(self, fake_train):
        fake_train_functions = mock.Mock()
        fake_train_functions.split.return_value = None
        fake_train.return_value = fake_train_functions
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        assert fake_train_functions.split.call_count == 10

Когда я пытаюсь запустить свой тест, он создает mock, но в конечном итоге вызывает фактическую функцию split,Что я делаю не так?

1 Ответ

0 голосов
/ 19 сентября 2018

Основная проблема с этим кодом заключается в том, что способ установки patch был бы, если бы TrainPipelineStrategy был вложенным классом PipelineTypeStrategy, а TrainPipelineStrategy - это подкласс PipelineTypeStrategy.

Поскольку TrainPipelineStrategy наследуется от PipelineTypeStrategy, у него есть доступ к split напрямую, поэтому вы можете исправлять split без какой-либо ссылки на PipelineTypeStrategy (если вы не хотите специально исправлять версию split, определенную вPipelineTypeStrategy).

Однако, если вы просто хотите высмеивать метод split класса PipelineTypeStrategy, вы должны использовать декоратор patch.object, чтобы высмеивать только split вместо насмешки над всемкласс, так как он немного более чистый.Вот пример:

class TestClass(unittest.TestCase):
    @patch.object(TrainPipelineStrategy, 'split', return_value=None)
    def test_split_data_main_split_data_call_count(self, mock_split):
        test_args = ["", "--x=6"]
        SplitData.main(args=test_args)
        self.assertEqual(mock_split.call_count, 10)
...