Невозможно исправить / смоделировать функции экземпляра, которые были вызваны одной из функций экземпляра, для которых я пишу тестовый блок - PullRequest
0 голосов
/ 22 июня 2019

У меня есть модуль billing/billing-collect-project-license, который имеет LicenseStatistics класс. Которые имеют вызовы Redis, ORMRDS, CE и другие модули, которые используются в этом классе. Ниже приведен класс LicenseStatistics, где get_allocate_count - это метод экземпляра, который вызывает ce_obj.get_set_ce_obj, get_license_id и многие другие.

Метод get_license_id вызывает get_customer_details.

class LicenseStatistics():
"""This class encapsulates all logic related to license usage."""

    def __init__(self):
        self.logger = LOGGER

        # RedisNode List should be in the following format:
        #   HOST1:PORT1,HOST2:PORT2,HOST3:PORT3 etc
        redis_node_list = os.environ.get('REDIS_NODE_LIST', '').split(',')
        self.redis_utils = RedisUtils(redis_node_list)

        # DB object to read customer details
        dbhost = os.environ.get('DBHOST')
        dbuser = os.environ.get('DBUSER')
        dbpass = os.environ.get('DBPASSWORD')
        dbname = os.environ.get('DBNAME')
        dbport = os.environ.get('DBPORT')
        self.dbutils = ORMDBUtils(dbhost, dbuser, dbpass, dbname, dbport)
        self.ce_obj = CE()
        self.bill = Billing()

    def get_license_id(self, project_id):
        """
        @Summary: Get License Id for customer/project from customer table by
        project id
        @param project_id (string): CE project Id
        @return (string): CE License Id which associate with Project.
        """
        # Get license ID from RDS
        customer_details = self.get_customer_details(project_id)
        print("customer_details:", customer_details)
        license_id = customer_details["license_id"]
        if not license_id:
            msg = "No license for project {}".format(project_id)
            self.logger.error(msg)
            raise InvalidParamException(msg)

        print("license_id:", license_id)
        return  license_id

    def get_customer_details(self, project_id):
        """
        @Summary: Get Customer/Project details from customer table
        @param project_id (string): CE project Id
        @return (dictionary): Customer details from customer table.
        """
        filters = {"project_id": project_id}
        customer_details = self.dbutils.get_items(
            table_name=RDSCustomerTable.TABLE_NAME.value,
            columns_to_select=["account_id", "license_id"],
            filters=filters
        )
        if not customer_details:
            msg = "No customer found for project {}".format(project_id)
            self.logger.error(msg)
            raise NoCustomerException(msg)

        return customer_details[0]

    def is_shared_license(self, license_id):

        # This function return True or False  
        pass

    def get_project_machines_count(self, project_id, license_id):
        # This function return number of used license.
        count = 20
        return count

    def get_license_usage(self, project_id, license_id):
        # This function return number of machines used project.
        count = 10
        return count

    def get_allocate_count(self, project_id):
        """
        @Summary: Get number of licenses are used by Project.
        @param project_id (string): CloudEndure Project Id.
        @return (int): Number of license are used in Project.
        """
        # set Session get_customer_detailsfrom Redis
        status = self.ce_obj.get_set_ce_obj(
            project_id=project_id, redis_utils=self.redis_utils
        )
        print("license_id status--:", status)
        if not status:
            msg = "Unable to set CEproject {}".format(project_id)
            self.logger.critical(msg)
            raise InvalidParamException(msg, "project_id", project_id)

        print("project_id------------:", project_id)
        # Get license Id
        license_id = self.get_license_id(project_id)
        print("license_id:", license_id)
        # Check license is shared
        shared_flag = self.is_shared_license(license_id)
        if not shared_flag:
            # Get license usage
            licenses_used = self.get_license_usage(project_id, license_id)
        else:
            # Get machine account
            licenses_used = self.get_project_machines_count(
                project_id, license_id
            )

        return licenses_used

Я пишу юнит-тест для метода get_allocate_count, я издеваюсь над Redis, ORMRDS, Custom Exception, Logger. Эта функция вызывает ce_obj.get_set_ce_obj функцию, которая возвращает True/False. Я должен смоделировать / исправить возвращаемое значение этой функции успешно.
Но когда вызов переходит к следующему вызову функции, т.е. get_license_id, вызов переходит в реальный вызов функции и из-за неправильных входных данных. Я не могу патчить / издеваться

Ниже приведен код модульного теста:

import responses
import unittest
from unittest.mock import patch

import os
import sys

cwd_path = os.getcwd()
sys.path.append(cwd_path)

sys.path.append(cwd_path+"/../sam-cfns/code")
sys.path.append(cwd_path+"/../sam-cfns/code/billing")

from unit_tests.common.mocks.env_mock import ENV_VAR
from unit_tests.common.mocks.logger import FakeLogger
from unit_tests.common.mocks.cache_mock import RedisUtilsMock
from unit_tests.common.mocks.ormdb_mock import ORMDBUtilsMockProject
from unit_tests.common.mocks.exceptions_mock import NoCustomerExceptionMock
from unit_tests.common.mocks.exceptions_mock import BillingExceptionMock
from unit_tests.common.mocks.exceptions_mock import InvalidParamExceptionMock
from unit_tests.common.mocks.api_responses import mock_response
from unit_tests.common.examples import ce_machines_data
from unit_tests.common.examples import ce_license_data
from unit_tests.common.examples import ce_data


class BillingTest(unittest.TestCase):
    """ Billing TEST class drive from UnitTest """

    @patch("billing-collect-project-license.Logger", FakeLogger)
    @patch("os.environ", ENV_VAR)
    @patch("billing-collect-project-license.RedisUtils", RedisUtilsMock)
    @patch("billing-collect-project-license.ORMDBUtils", ORMDBUtilsMockProject)
    @patch("exceptions.NoCustomerException", NoCustomerExceptionMock)
    @patch("billing.billing_exceptions.BillingException", BillingExceptionMock)
    @patch("billing.billing_exceptions.InvalidParamException", InvalidParamExceptionMock)
    def __init__(self, *args, **kwargs):
        """Initialization"""
        super(BillingTest, self).__init__(*args, **kwargs)
        billing_collect_project_license_module = (
            __import__("cbr-billing-collect-project-license")
        )
        self.licenses_stats_obj = (
            billing_collect_project_license_module.LicenseStatistics()
        )

class BillingCollectProjectLicense(BillingTest):
    """Login Unit Test Cases"""
    def __init__(self, *args, **kwargs):
        """Initialization"""
        super(BillingCollectProjectLicense, self).__init__(*args, **kwargs)

    def setUp(self):
        """Setup for all Test Cases."""
        pass


    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
    #       "get_project_machines_count")
    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
    #       "get_customer_details")
    #@patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
    @patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
    @patch("cbr.ce.CloudEndure.get_set_ce_obj")
    def test_get_allocate_count(self, get_set_ce_obj_mock, get_license_id_mock):
        project_id = ce_data.CE_PROJECT_ID
        license_id = ce_license_data.LICENSE_ID
        get_set_ce_obj_mock.return_value = True
        get_license_id_mock.return_value = license_id
      # LicenseStatistics_mock.return_value.get_license_id.return_value = license_id
        #get_license_id_mock.return_value = license_id
       # self.licenses_stats_obj.get_license_id = get_license_id_mock
        get_customer_details_mock = {"license_id": license_id}
     #   is_shared_license_mock.return_value = True
     #   get_project_machines_count_mock.return_value = 20

        resp = self.licenses_stats_obj.get_allocate_count(
            project_id
        )
        self.assertEqual(resp, 20)


if __name__ == '__main__':
    unittest.main()

Я не могу исправить функцию get_license_id из того же класса. Это фактически вызывает функцию get_license_id и не работает. Я хочу смоделировать возвращаемое значение функции get_license_id.

Кто-нибудь поможет мне? Спасибо.

1 Ответ

0 голосов
/ 10 июля 2019

Проблема в том, что я инициализирую его в init , поэтому методы монопатчинга класса LicenseStatistics позже не влияют на уже созданный экземпляр. @ Хефлинг

С помощью Patching от Monkey я смог успешно запустить Test Cases.

пример кода:

def test_get_allocate_count_ok_4(self, ):
    """
    @Summary: Test case for successful response for shared license
    by other unittest method - Monkey Patching
    """
    def get_customer_details_mp(_):
        """Monkey Patching function for get_customer_details"""
        data = [
            {
                "account_id": "abc",
                "project_id": "abc",
                "license_id": "abc",
                "status": "Active"
            }
        ]
        return data
    def get_set_ce_obj_mp(_, _tmp):
        """Monkey Patching function for get_set_ce_obj"""
        return True
    def get_license_id_mp(_):
        """Monkey Patching function for get_license_id"""
        return "abc"
    def is_shared_license_mp(_):
        """Monkey Patching function for is_shared_license"""
        return True
    def get_project_machines_count_mp(_, _license_id):
        """Monkey Patching function for get_project_machines_count"""
        return 5

    project_id = "abc"

    # Monkey Patching
    self.licenses_stats_obj.get_customer_details = get_customer_details_mp
    self.licenses_stats_obj.ce_obj.get_set_ce_obj = get_set_ce_obj_mp
    self.licenses_stats_obj.get_license_id = get_license_id_mp
    self.licenses_stats_obj.is_shared_license = is_shared_license_mp
    self.licenses_stats_obj.get_project_machines_count = (
        get_project_machines_count_mp
    )

    resp = self.licenses_stats_obj.get_allocate_count(project_id)
    self.assertEqual(resp, 5)
...