Set environment variables with os.environ before importing the dag_generator module in a test

I am (unfortunately, for me and for the world we live in) still using Python 2.7.

I am trying to test a function in a module. The function needs to load certain environment variables. To avoid an error because they cannot be loaded (they are defined by a separate process), I make sure to provide dummy values for the env variables I need before running my test. I do something like:

os.environ['gcs_bucket'] = 'test_bucket'
os.environ['gcp_region'] = 'test_region'
os.environ['gcp_project'] = 'test_nonlive'
os.environ['gce_zone'] = 'test_zone'
os.environ['firewall_rules_tags'] = 'a,b'
os.environ['subnetwork_uri'] = 'https://test.com'
os.environ['service_account'] = 'test_service_account'

The problem is that if I run this environment-setup code after importing the module I use in my test:

from orchestrator_template.dag_generator import create_dag

then the test fails, saying that the environment variables are not there.

If, however, I set them first and only then import the module, everything works fine.
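That behaviour is expected: dag_generator reads the variables at module level, so the os.environ lookups run at import time, and a missing key raises a KeyError the moment the from ... import line executes. A minimal illustration of the difference (hypothetical module names, not the project code):

# env_at_import.py -- the lookup runs when the module is imported;
# importing it with the variable unset raises KeyError immediately
import os
bucket = os.environ['gcs_bucket']

# env_at_call.py -- the lookup runs only when the function is called;
# the module can be imported in any order
import os

def get_bucket():
    return os.environ['gcs_bucket']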

Obviously, not having the import at the top of my test module bothers me, since it is bad practice. Besides, I cannot push the code like this, because it violates the style checks.

Question: how can I load them after importing the module? More importantly, is there a way to refactor dag_generator.py so that it does not behave this way?
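For the refactoring part, one possible direction (a sketch only; load_config and the extra config argument are hypothetical, not the existing API) is to defer all os.environ reads into a helper that create_dag and the DAG-building loop call, so nothing touches the environment at import time:

# dag_generator.py -- sketch of lazy environment reads
import os

def load_config():
    # Called from create_dag() / the DAG-building loop instead of at import time
    return {
        'bucket_name': os.environ['gcs_bucket'],
        'region': os.environ['gcp_region'],
        'project_id': os.environ['gcp_project'],
        'zone': os.environ['gce_zone'],
        'firewall_rules_tags': os.environ['firewall_rules_tags'].split(','),
        'sub_network_uri': os.environ['subnetwork_uri'],
        'global_opco': os.environ['opco'],
        'service_account': os.environ['service_account'],
    }

def create_dag(dag_id, job, default_args, config=None):
    config = config or load_config()
    # ... build the DAG exactly as before, using config['project_id'],
    # config['region'], and so on instead of the module-level globals.

With the reads deferred like this, the test could keep its import at the top and set (or patch) the variables inside the test; Airflow would still pick up the DAGs, because the module-level loop would only call load_config() when the scheduler actually parses the file, by which time the real variables exist.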

The full test module (WIP) follows below:

import os
from datetime import datetime

from mock import patch

from orchestrator_template.plugins.autoscale_dataproc_operator import DataprocCompCreateClusterOperator
from orchestrator_template.root import root_dir
from orchestrator_template.utils.config_utils import generate_opco_configurations


def test_generate_opco_configurations():

    # Given
    expected_job = {
        'airflow_start_date': '2019-11-01',
        'cluster_config': {
            'cluster_properties': {
                'core:fs.gs.implicit.dir.repair.enable': 'false',
                'spark:spark.Comp.event.repo': 'hdfs',
                'spark:spark.Comp.metrics.repo': 'hdfs'
            },
            'image': 'blue-dp-img-20191203-105432-21',
            'init_action': [
                'gs://europe-west1-composer-dev-f2efa347-bucket/dags/orchestrator_template/scripts/init-action.sh'
            ],
            'labels': {'use_case': 'fixed'},
            'master_config': {'master_machine_type': 'n1-standard-8'},
            'worker_config': {
                'num_preemptible_workers': 1,
                'num_workers': 4,
                'worker_machine_type': 'n1-standard-32'}
        },
        'default_yaml_path': 'scripts/properties/red_agent',
        'email_on_failure': 'False',
        'email_on_retry': 'False',
        'name': 'FixedEventsDaily',
        'opco': 'it',
        'tasks': [{
            'args': [
                '--nodes=FixedLineMDCEventModel,TTOpeningEvent,TTClosingEvent,FixedEventsWideIT',
                '--nodes=FixedJourneyChurnPrediction',
                '--opco_code=it',
                '--start_date=20191101',
                '--end_date=20191101'],
            'task': 'FixedEventsDaily'}]}

    # When
    jobs = generate_opco_configurations(os.path.join(root_dir, "test/configs"))

    # Then
    assert len(jobs) == 1
    assert jobs[0] == expected_job

# NOTICE: don't change the order of execution. The env variables below must be set prior to importing `create_dag`
def set_default_test_vars():
    os.environ['gcs_bucket'] = 'test_bucket'
    os.environ['gcp_region'] = 'test_region'
    os.environ['gcp_project'] = 'test_nonlive'
    os.environ['gce_zone'] = 'test_zone'
    os.environ['firewall_rules_tags'] = 'a,b'
    os.environ['subnetwork_uri'] = 'https://test.com'
    os.environ['opco'] = 'it'
    os.environ['service_account'] = 'test_service_account'


set_default_test_vars()
from orchestrator_template.dag_generator import create_dag


@patch("orchestrator_template.dag_generator.list_files_from_gcp")
def test_initialization_action_property_is_none_if_empty(mock_list_files_from_gcp):
    # Given
    dag_id = 'job_{}_{}'.format("TestDag", '1')
    default_args = {
        'owner': 'airflow',
        'start_date': datetime.strptime('2019-11-01', "%Y-%m-%d"),
        'email_on_failure': "False",
        'email_on_retry': "False",
        'project_id': "vf-dev-ca-live",
        'catchup': 'False'
    }

    # When
    mock_list_files_from_gcp.return_value = []
    jobs = generate_opco_configurations(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs"))
    dag = create_dag(dag_id=dag_id, job=jobs[0], default_args=default_args)

    # Then
    operator = dag.task_dict["TestJobDaily"]  # type: DataprocCompCreateClusterOperator
    assert len(operator.init_actions_uris) != 0

And the dag_generator so far:

import logging
import os
import sys
from datetime import datetime
from functools import reduce

import airflow
from airflow.contrib.hooks.gcs_hook import GoogleCloudStorageHook
from airflow.contrib.operators.dataproc_operator import DataprocClusterDeleteOperator, \
    DataProcPySparkOperator
from airflow.operators.python_operator import PythonOperator
from airflow.utils.trigger_rule import TriggerRule

from orchestrator_template.plugins.autoscale_dataproc_operator import DataprocCompCreateClusterOperator
from orchestrator_template.plugins.cost_calculator_operator import CostCalculatorOperator
from orchestrator_template.plugins.stackdriver_logger_operator import StackdriverLoggerOperator
from orchestrator_template.utils.config_utils import generate_opco_configurations
from orchestrator_template.utils.config_utils import retrieve_yaml_path
from orchestrator_template.utils.config_utils import stringify_dict_values, generate_uid_paths

sys.path.insert(0, os.path.join(os.path.abspath(os.path.dirname(__file__)), os.pardir))

bucket_name = os.environ['gcs_bucket']
region = os.environ['gcp_region']
project_id = os.environ['gcp_project']
zone = os.environ['gce_zone']
firewall_rules_tags = os.environ['firewall_rules_tags'].split(',')
sub_network_uri = os.environ['subnetwork_uri']
global_opco = os.environ['opco']
service_account = os.environ['service_account']

version = 'latest'
dag_version = '1'
CLUSTER_NAME_PREFIX_LENGTH = 36

# Assert the mandatory Airflow Variables
assert global_opco, "Airflow variable [opco] needs to be defined."


def list_files_from_gcp(prefix, file_type):
    # type: (str, str) -> list
    files = []
    hook = GoogleCloudStorageHook()
    common_prefix = os.path.join("cloud-analytics", "code", "red_agent", version)
    list_of_files = hook.list(bucket_name, prefix="{}/{}".format(common_prefix, prefix))
    for name in list_of_files:
        if name.endswith(file_type):
            _file = "gs:" + os.path.join('//', bucket_name, name)
            files.append(_file)
    return files


def log_metric_event_failure():
    logging.error("Failed to copy metrics and events to bucket.")


def create_dag(dag_id, job, default_args):
    # type: (str, str, dict) -> airflow.DAG
    with airflow.DAG(
            dag_id,
            default_args=default_args,
            catchup=False,
            schedule_interval=job.get('schedule_interval', None)) as dag:
        # Google imposes cluster names to follow the pattern (?:[a-z](?:[-a-z0-9]{0,49}[a-z0-9])?)

        cluster_name = job['name'].lower()[:CLUSTER_NAME_PREFIX_LENGTH] + \
                       '-{{execution_date.strftime("%Y%m%d%H%M%S")}}'
        job_cluster_config_master = job['cluster_config']['master_config']
        job_cluster_config_workers = job['cluster_config']['worker_config']
        master_machine_type = job_cluster_config_master['master_machine_type']
        worker_machine_type = job_cluster_config_workers['worker_machine_type']
        num_workers = job_cluster_config_workers['num_workers']
        initialization_action = job['cluster_config'].get('init_action', None)

        t_create_dataproc_cluster = DataprocCompCreateClusterOperator(
            task_id='create_dataproc_cluster',
            cluster_name=cluster_name,
            project_id=project_id,
            service_account=service_account,
            master_machine_type=master_machine_type,
            worker_machine_type=worker_machine_type,
            num_workers=num_workers,
            num_preemptible_workers=job_cluster_config_workers.get('num_preemptible_workers', 0),
            custom_image=job['cluster_config']['image'],
            internal_ip_only=True,
            region=region,
            zone=zone,
            subnetwork_uri=sub_network_uri,
            tags=firewall_rules_tags,
            autoscaling_policy=job['cluster_config'].get('autoscaling_policy', None),
            properties=stringify_dict_values(job['cluster_config'].get('cluster_properties', {})),
            labels=stringify_dict_values(job['cluster_config'].get('labels', {})),
            enable_http_port_access=job['cluster_config'].get('enable_http_port_access', True),
            init_actions_uris=initialization_action,
            enable_optional_components=True  # Enables CONDA path for imports.
        )

        t_run_red_agent_jobs = [DataProcPySparkOperator(
            task_id=task['task'],
            cluster_name=cluster_name,
            main=list(list_files_from_gcp("red_agent/common", "redagent_main.py"))[0],
            pyfiles=list(list_files_from_gcp("artifacts", "zip")),
            dataproc_pyspark_jars=list(list_files_from_gcp("artifacts", "jar")),
            files=list(list_files_from_gcp(retrieve_yaml_path(job), "yaml")),
            dataproc_pyspark_properties=task.get("spark_properties", {}),
            arguments=[task['args']],
            region=region
        ) for task in job['tasks']]

        opco = job.get('opco', global_opco)
        hdfs_path = os.path.join('hdfs://', cluster_name + '-m:8020', 'app')

        t_generate_paths = PythonOperator(
            task_id='generate_paths',
            python_callable=generate_uid_paths,
            provide_context=True,
            templates_dict={
                'uid': "{{ dag_run.conf['uid'] }}"
            },
            op_kwargs={'job': job, 'opco': opco, 'bucket_name': bucket_name},
            dag=dag,
        )

        copy_events_metrics_to_gcs = DataProcPySparkOperator(
            task_id='copy_events_metrics_to_gcs',
            cluster_name=cluster_name,
            main=os.path.join(os.path.dirname(os.path.abspath(__file__)), "utils/copy_files.py"),
            dataproc_pyspark_jars=list(list_files_from_gcp("artifacts", "jar")),
            arguments=['--source={hdfs_path}'.format(hdfs_path=hdfs_path),
                       "--destination=" + os.path.join('gs://', bucket_name,
                                                       "{{ ti.xcom_pull(task_ids='generate_paths')[2] }}"),
                       '--opco=' + opco],
            region=region
        )

        t_push_events_and_metrics_to_stackdriver = StackdriverLoggerOperator(
            task_id='push_events_and_metrics_to_stackdriver',
            bucket=bucket_name,
            provide_context=True,
            templates_dict={
                'path': "{{ ti.xcom_pull(task_ids='generate_paths')[1]}}",
                'job_id': "{{ ti.xcom_pull(task_ids='generate_paths')[0] }}",
                'run_id': "{{ ti.xcom_pull(task_ids='generate_paths')[2] }}"
            },
            project_id=project_id,
            region=region
        )

        t_delete_dataproc_cluster = DataprocClusterDeleteOperator(
            task_id='delete_dataproc_cluster',
            project_id=project_id,
            cluster_name=cluster_name,
            trigger_rule=TriggerRule.ALL_DONE,
            region=region
        )

        default_disk_type = 'pd-standard'
        default_disk_size = 500

        t_cost_calculator = CostCalculatorOperator(
            task_id='cost_calculator',
            master_machine_type=master_machine_type,
            worker_machine_type=worker_machine_type,
            master_disk_type=job_cluster_config_master.get('master_disk_type', default_disk_type),
            master_disk_size=job_cluster_config_master.get('master_disk_size', default_disk_size),
            worker_disk_type=job_cluster_config_workers.get('worker_disk_type', default_disk_type),
            worker_disk_size=job_cluster_config_workers.get('worker_disk_size', default_disk_size),
            num_workers=num_workers,
            log_name='test_dag_cost'
        )

        assert t_run_red_agent_jobs

        t_generate_paths >> t_create_dataproc_cluster >> t_run_red_agent_jobs[0]
        reduce(lambda a, b: a >> b, t_run_red_agent_jobs)
        t_run_red_agent_jobs[-1] >> copy_events_metrics_to_gcs \
            >> t_push_events_and_metrics_to_stackdriver >> t_delete_dataproc_cluster >> t_cost_calculator
    return dag



jobs = generate_opco_configurations(os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs"))

for job in jobs:
    dag_id = 'job_{}_{}'.format(job['name'], dag_version)

    default_args = {
        'owner': 'airflow',
        'start_date': datetime.strptime(job['airflow_start_date'], "%Y-%m-%d"),
        'email_on_failure': job['email_on_failure'],
        'email_on_retry': job['email_on_retry'],
        'project_id': project_id,
        'catchup': job.get('catchup', 'False')
    }
    globals()[dag_id] = create_dag(dag_id=dag_id, job=job, default_args=default_args)

...