У меня есть модуль billing/billing-collect-project-license
, который имеет LicenseStatistics
класс. Которые имеют вызовы Redis, ORMRDS, CE и другие модули, которые используются в этом классе. Ниже приведен класс LicenseStatistics
, где get_allocate_count
- это метод экземпляра, который вызывает ce_obj.get_set_ce_obj
, get_license_id
и многие другие.
Метод get_license_id
вызывает get_customer_details
.
class LicenseStatistics():
"""This class encapsulates all logic related to license usage."""
def __init__(self):
self.logger = LOGGER
# RedisNode List should be in the following format:
# HOST1:PORT1,HOST2:PORT2,HOST3:PORT3 etc
redis_node_list = os.environ.get('REDIS_NODE_LIST', '').split(',')
self.redis_utils = RedisUtils(redis_node_list)
# DB object to read customer details
dbhost = os.environ.get('DBHOST')
dbuser = os.environ.get('DBUSER')
dbpass = os.environ.get('DBPASSWORD')
dbname = os.environ.get('DBNAME')
dbport = os.environ.get('DBPORT')
self.dbutils = ORMDBUtils(dbhost, dbuser, dbpass, dbname, dbport)
self.ce_obj = CE()
self.bill = Billing()
def get_license_id(self, project_id):
"""
@Summary: Get License Id for customer/project from customer table by
project id
@param project_id (string): CE project Id
@return (string): CE License Id which associate with Project.
"""
# Get license ID from RDS
customer_details = self.get_customer_details(project_id)
print("customer_details:", customer_details)
license_id = customer_details["license_id"]
if not license_id:
msg = "No license for project {}".format(project_id)
self.logger.error(msg)
raise InvalidParamException(msg)
print("license_id:", license_id)
return license_id
def get_customer_details(self, project_id):
"""
@Summary: Get Customer/Project details from customer table
@param project_id (string): CE project Id
@return (dictionary): Customer details from customer table.
"""
filters = {"project_id": project_id}
customer_details = self.dbutils.get_items(
table_name=RDSCustomerTable.TABLE_NAME.value,
columns_to_select=["account_id", "license_id"],
filters=filters
)
if not customer_details:
msg = "No customer found for project {}".format(project_id)
self.logger.error(msg)
raise NoCustomerException(msg)
return customer_details[0]
def is_shared_license(self, license_id):
# This function return True or False
pass
def get_project_machines_count(self, project_id, license_id):
# This function return number of used license.
count = 20
return count
def get_license_usage(self, project_id, license_id):
# This function return number of machines used project.
count = 10
return count
def get_allocate_count(self, project_id):
"""
@Summary: Get number of licenses are used by Project.
@param project_id (string): CloudEndure Project Id.
@return (int): Number of license are used in Project.
"""
# set Session get_customer_detailsfrom Redis
status = self.ce_obj.get_set_ce_obj(
project_id=project_id, redis_utils=self.redis_utils
)
print("license_id status--:", status)
if not status:
msg = "Unable to set CEproject {}".format(project_id)
self.logger.critical(msg)
raise InvalidParamException(msg, "project_id", project_id)
print("project_id------------:", project_id)
# Get license Id
license_id = self.get_license_id(project_id)
print("license_id:", license_id)
# Check license is shared
shared_flag = self.is_shared_license(license_id)
if not shared_flag:
# Get license usage
licenses_used = self.get_license_usage(project_id, license_id)
else:
# Get machine account
licenses_used = self.get_project_machines_count(
project_id, license_id
)
return licenses_used
Я пишу юнит-тест для метода get_allocate_count
, я издеваюсь над Redis, ORMRDS, Custom Exception, Logger.
Эта функция вызывает ce_obj.get_set_ce_obj
функцию, которая возвращает True/False
. Я должен смоделировать / исправить возвращаемое значение этой функции успешно.
Но когда вызов переходит к следующему вызову функции, т.е. get_license_id
, вызов переходит в реальный вызов функции и из-за неправильных входных данных. Я не могу патчить / издеваться
Ниже приведен код модульного теста:
import responses
import unittest
from unittest.mock import patch
import os
import sys
cwd_path = os.getcwd()
sys.path.append(cwd_path)
sys.path.append(cwd_path+"/../sam-cfns/code")
sys.path.append(cwd_path+"/../sam-cfns/code/billing")
from unit_tests.common.mocks.env_mock import ENV_VAR
from unit_tests.common.mocks.logger import FakeLogger
from unit_tests.common.mocks.cache_mock import RedisUtilsMock
from unit_tests.common.mocks.ormdb_mock import ORMDBUtilsMockProject
from unit_tests.common.mocks.exceptions_mock import NoCustomerExceptionMock
from unit_tests.common.mocks.exceptions_mock import BillingExceptionMock
from unit_tests.common.mocks.exceptions_mock import InvalidParamExceptionMock
from unit_tests.common.mocks.api_responses import mock_response
from unit_tests.common.examples import ce_machines_data
from unit_tests.common.examples import ce_license_data
from unit_tests.common.examples import ce_data
class BillingTest(unittest.TestCase):
""" Billing TEST class drive from UnitTest """
@patch("billing-collect-project-license.Logger", FakeLogger)
@patch("os.environ", ENV_VAR)
@patch("billing-collect-project-license.RedisUtils", RedisUtilsMock)
@patch("billing-collect-project-license.ORMDBUtils", ORMDBUtilsMockProject)
@patch("exceptions.NoCustomerException", NoCustomerExceptionMock)
@patch("billing.billing_exceptions.BillingException", BillingExceptionMock)
@patch("billing.billing_exceptions.InvalidParamException", InvalidParamExceptionMock)
def __init__(self, *args, **kwargs):
"""Initialization"""
super(BillingTest, self).__init__(*args, **kwargs)
billing_collect_project_license_module = (
__import__("cbr-billing-collect-project-license")
)
self.licenses_stats_obj = (
billing_collect_project_license_module.LicenseStatistics()
)
class BillingCollectProjectLicense(BillingTest):
"""Login Unit Test Cases"""
def __init__(self, *args, **kwargs):
"""Initialization"""
super(BillingCollectProjectLicense, self).__init__(*args, **kwargs)
def setUp(self):
"""Setup for all Test Cases."""
pass
#@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
# "get_project_machines_count")
#@patch("billing.cbr-billing-collect-project-license.LicenseStatistics."
# "get_customer_details")
#@patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
@patch("billing.cbr-billing-collect-project-license.LicenseStatistics.get_license_id")
@patch("cbr.ce.CloudEndure.get_set_ce_obj")
def test_get_allocate_count(self, get_set_ce_obj_mock, get_license_id_mock):
project_id = ce_data.CE_PROJECT_ID
license_id = ce_license_data.LICENSE_ID
get_set_ce_obj_mock.return_value = True
get_license_id_mock.return_value = license_id
# LicenseStatistics_mock.return_value.get_license_id.return_value = license_id
#get_license_id_mock.return_value = license_id
# self.licenses_stats_obj.get_license_id = get_license_id_mock
get_customer_details_mock = {"license_id": license_id}
# is_shared_license_mock.return_value = True
# get_project_machines_count_mock.return_value = 20
resp = self.licenses_stats_obj.get_allocate_count(
project_id
)
self.assertEqual(resp, 20)
if __name__ == '__main__':
unittest.main()
Я не могу исправить функцию get_license_id
из того же класса. Это фактически вызывает функцию get_license_id
и не работает.
Я хочу смоделировать возвращаемое значение функции get_license_id
.
Кто-нибудь поможет мне?
Спасибо.